ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250426__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (43)
  1. ai_edge_torch/_convert/conversion.py +1 -3
  2. ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
  5. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -3
  6. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
  7. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
  8. ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
  9. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
  10. ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
  11. ai_edge_torch/generative/examples/hammer/verify.py +86 -0
  12. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
  13. ai_edge_torch/generative/examples/llama/llama.py +3 -1
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
  15. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
  16. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
  17. ai_edge_torch/generative/examples/phi/phi2.py +1 -1
  18. ai_edge_torch/generative/examples/phi/phi3.py +3 -1
  19. ai_edge_torch/generative/examples/phi/phi4.py +3 -1
  20. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -3
  21. ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
  22. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
  23. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
  24. ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
  27. ai_edge_torch/generative/layers/kv_cache.py +2 -4
  28. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
  29. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
  30. ai_edge_torch/generative/test/test_model_conversion.py +3 -33
  31. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -75
  32. ai_edge_torch/generative/utilities/converter.py +11 -1
  33. ai_edge_torch/generative/utilities/export_config.py +30 -0
  34. ai_edge_torch/model.py +2 -0
  35. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/RECORD +41 -39
  39. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
  40. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
  41. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/LICENSE +0 -0
  42. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/WHEEL +0 -0
  43. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/qwen.py CHANGED
@@ -53,6 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
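This release flips enable_hlfb on for the RMS norm layers of several example models (this file and the SmolLM/TinyLlama files below), letting the norm be marked as a single high-level function boundary composite during lowering. A minimal sketch of the resulting config, using the model_config module these files import as cfg:

import ai_edge_torch.generative.layers.model_config as cfg

# RMS norm with a high-level function boundary (HLFB), mirroring the hunk above.
norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.RMS_NORM,
    epsilon=1e-06,
    enable_hlfb=True,
)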
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -35,6 +35,10 @@ def main(_):
   pytorch_model = smollm.build_model(
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
+
+  export_config = export_cfg.get_from_flags()
+  export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
@@ -42,9 +46,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=export_cfg.ExportConfig(
-          decode_batch_size=_DECODE_BATCH_SIZE.value
-      ),
+      export_config=export_config,
   )
 
 
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py CHANGED
@@ -34,6 +34,9 @@ def main(_):
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
 
+  export_config = export_cfg.get_from_flags()
+  export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
@@ -41,9 +44,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=export_cfg.ExportConfig(
-          decode_batch_size=_DECODE_BATCH_SIZE.value
-      ),
+      export_config=export_config,
   )
 
 
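Both SmolLM converters above now share the same pattern: derive the ExportConfig from the conversion flags first, then apply the script-specific decode batch size on top, so flag-driven settings such as masks and KV-cache layout (see export_config.get_from_flags() later in this diff) still take effect. A condensed sketch, where _DECODE_BATCH_SIZE is the absl flag these scripts define elsewhere:

# Flag-derived settings first, then the script-specific override.
export_config = export_cfg.get_from_flags()
export_config.decode_batch_size = _DECODE_BATCH_SIZE.value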
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=1536,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("tiny_llama")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=5632,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/layers/kv_cache.py CHANGED
@@ -51,10 +51,7 @@ class KVCacheEntry:
       config: model_config.AttentionConfig,
       batch_size: int,
   ) -> List[int]:
-    """Constructs the shape of the key or value cache entry based on
-
-    the specified layout.
-    """
+    """Construct the shape of KV cache entry based on the specified layout."""
     output_shape = []
     for dim_spec in shape_spec:
       if dim_spec is types.TensorDims.BATCH:
@@ -213,6 +210,7 @@ pytree.register_pytree_node(
     serialized_type_name="",
 )
 
+
 def update(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -17,6 +17,8 @@
 import math
 from typing import Optional
 
+from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
@@ -142,3 +144,52 @@ def scaled_dot_product_attention_with_hlfb(
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
   return result
+
+
+def scaled_dot_product_attention_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    softcap: Optional[float] = None,
+):
+  """Scaled dot product attention with transposed key and value.
+
+  Args:
+    query: Query tensor, with shape [B, T, N, H].
+    key: Key tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention_transposed.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  query = query * scale
+
+  assert mask is not None, "Mask should not be None!"
+  t = mask.shape[2]
+
+  logits = bmm_lib.bmm_4d(query, key)
+
+  _, bk, gt, s = logits.shape
+  g = gt // t
+  logits = logits.reshape((bk, g, t, s))
+  if softcap is not None:
+    logits = torch.tanh(logits / softcap)
+    logits = logits * softcap
+
+  padded_logits = logits + mask
+  padded_logits = padded_logits.reshape(1, bk, gt, s)
+  probs = F.softmax(padded_logits, dim=-1).type_as(key)
+  encoded = bmm_lib.bmm_4d(probs, value)
+
+  return encoded  # 1, bk, gt, h
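The new scaled_dot_product_attention_transposed relies on the custom bmm_lib.bmm_4d op. For readers without that op, here is a plain-PyTorch reference sketch of the same math, under the assumption (not stated in this diff) that bmm_4d(a, b) contracts the trailing dimension of both operands, i.e. torch.einsum('bnqd,bnkd->bnqk'):

import math
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_transposed_reference(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    head_size: int,
    mask: torch.Tensor,
    scale: Optional[float] = None,
    softcap: Optional[float] = None,
) -> torch.Tensor:
  """Mirrors the function above, with einsum standing in for bmm_4d."""
  if scale is None:
    scale = 1.0 / math.sqrt(head_size)
  query = query * scale

  t = mask.shape[2]
  # bmm_4d(query, key): contract the trailing head dimension (assumed).
  logits = torch.einsum('bnqd,bnkd->bnqk', query, key)
  _, bk, gt, s = logits.shape
  g = gt // t  # number of query heads grouped onto each KV head
  logits = logits.reshape((bk, g, t, s))
  if softcap is not None:
    logits = torch.tanh(logits / softcap) * softcap
  padded_logits = (logits + mask).reshape(1, bk, gt, s)
  probs = F.softmax(padded_logits, dim=-1).type_as(key)
  # bmm_4d(probs, value): contract the trailing KV-length dimension (assumed).
  return torch.einsum('bnqk,bndk->bnqd', probs, value)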
ai_edge_torch/generative/layers/sdpa_with_kv_update.py CHANGED
@@ -18,9 +18,8 @@
 from typing import Tuple
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
@@ -72,8 +71,7 @@ def _sdpa_with_kv_update_transposed(
   kv = kv_utils_experimental.update(kv, input_pos, key, value)
   key, value = kv.k_cache, kv.v_cache
 
-  sdpa_out = sdpa.scaled_dot_product_attention(
-      kv,
+  sdpa_out = sdpa.scaled_dot_product_attention_transposed(
       query,
       key,
       value,
@@ -105,9 +103,9 @@ def _sdpa_with_kv_update_default(
   key, value = kv.k_cache, kv.v_cache
 
   if enable_hlfb:
-    sdpa_func = sdpa_default.scaled_dot_product_attention_with_hlfb
+    sdpa_func = sdpa.scaled_dot_product_attention_with_hlfb
   else:
-    sdpa_func = sdpa_default.scaled_dot_product_attention
+    sdpa_func = sdpa.scaled_dot_product_attention
   sdpa_out = sdpa_func(
       query,
       key,
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -32,10 +32,8 @@ class TestModelConversion(googletest.TestCase):
 
   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
             model_content=tflite_model,
             experimental_default_delegate_latest_features=True,
         )
@@ -85,44 +83,24 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_transposed(self):
     self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(
         enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
     )
-    interpreter_ = interpreter.InterpreterWithCustomOps(
-        custom_op_registerers=["GenAIOpsRegisterer"],
-        model_content=edge_model.tflite_model(),
-        experimental_default_delegate_latest_features=True,
-    )
+    interpreter = self._interpreter_builder(edge_model.tflite_model())()
 
     # pylint: disable=protected-access
-    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    op_names = [op["op_name"] for op in interpreter._get_ops_details()]
     self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(
@@ -197,19 +175,11 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig_kv_layout_transposed(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
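With the custom GenAI ops gone from these toy-model graphs, the tests construct the stock TFLite interpreter and drop the matching OSS skip markers. The op check in test_toy_model_has_dus_op therefore reduces to the following sketch (with interpreter imported as in this test file, and edge_model produced by the test harness):

interp = interpreter.Interpreter(
    model_content=edge_model.tflite_model(),
    experimental_default_delegate_latest_features=True,
)
op_names = [op["op_name"] for op in interp._get_ops_details()]  # private API, as in the test
assert "DYNAMIC_UPDATE_SLICE" in op_names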
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -20,6 +20,7 @@ from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.deepseek import deepseek
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.hammer import hammer
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.paligemma import decoder
@@ -48,10 +49,8 @@ class TestModelConversion(googletest.TestCase):
 
   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
             model_content=tflite_model,
             experimental_default_delegate_latest_features=True,
         )
@@ -94,110 +93,68 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
     pytorch_model = gemma1.Gemma1(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_llama(self):
     config = llama.get_fake_model_config()
     pytorch_model = llama.Llama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
     # Phi-2 logits are very big, so we need a larger absolute tolerance.
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
     pytorch_model = phi3.Phi3_5Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi4(self):
     config = phi4.get_fake_model_config()
     pytorch_model = phi4.Phi4Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
     pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm2(self):
     config = smollm.get_fake_model_config_v2()
     pytorch_model = smollm.SmolLM2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
     pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_deepseek(self):
     config = deepseek.get_fake_model_config()
     pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  def test_hammer(self):
+    config = hammer.get_fake_model_config()
+    pytorch_model = hammer.Hammer(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
+
+
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
@@ -246,19 +203,11 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma1(self):
     self._test_paligemma_model(
         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma2(self):
     self._test_paligemma_model(
         decoder2.Decoder2,
@@ -267,10 +216,6 @@ class TestModelConversion(googletest.TestCase):
         rtol=1e-5,
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen_vl_model(self):
     config = qwen_vl.get_fake_model_config()
     pytorch_model = qwen_vl.QwenVL(config).eval()
@@ -316,10 +261,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="flaky")
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
     prompt_tokens = torch.from_numpy(
@@ -348,10 +290,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="b/413106901")
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
     # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -390,10 +329,6 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
     # Reduce stddev(scale) of input values to avoid too big output logits which
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -81,7 +81,17 @@ def define_conversion_flags(model_name: str):
       'If set, the model will be converted with the provided list of LoRA'
      ' ranks.',
   )
-
+  flags.DEFINE_bool(
+      'mask_as_input',
+      False,
+      'If true, the mask will be passed in as input. Otherwise, mask will be '
+      'built by the model internally.',
+  )
+  flags.DEFINE_bool(
+      'transpose_kv_cache',
+      False,
+      'If true, the model will be converted with transposed KV cache.',
+  )
   return flags
 
 
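Both flags are read back by export_config.get_from_flags() in the next file. A hypothetical invocation of one of the example converters with them (paths are placeholders; any script calling converter.define_conversion_flags() inherits the flags):

# python convert_to_tflite.py --checkpoint_path=/tmp/ckpt \
#     --output_path=/tmp/out --mask_as_input --transpose_kv_cache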
ai_edge_torch/generative/utilities/export_config.py CHANGED
@@ -14,8 +14,11 @@
 # ==============================================================================
 
 """Config for customizing model export process."""
+
 import dataclasses
 from typing import List, Optional
+
+from absl import flags
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import torch
 
@@ -38,3 +41,30 @@ class ExportConfig:
   kvcache_cls: type = kv_utils.KVCache
   # The batch size of the decode signature.
   decode_batch_size: int = 1
+
+
+def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
+  if isinstance(mask_len, list):
+    return [_build_mask(i, kv_cache_max_len) for i in mask_len]
+
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def get_from_flags() -> ExportConfig:
+  """Builds an export config according to the commandline flags."""
+  export_config = ExportConfig()
+
+  if flags.FLAGS.mask_as_input:
+    export_config.prefill_mask = _build_mask(
+        flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+    )
+    export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
+
+  if flags.FLAGS.transpose_kv_cache:
+    export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
+
+  return export_config
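A worked example of the mask _build_mask produces when the mask is exported as a model input: for a prefill length of 4 against an 8-slot KV cache, each row is 0 up to its own position and -inf afterwards:

import torch

mask = torch.full((4, 8), float('-inf'), dtype=torch.float32)
mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
# mask.shape == (1, 1, 4, 8); row i holds 0.0 at columns j <= i and -inf at
# j > i, i.e. a standard additive causal mask over the KV cache slots.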
ai_edge_torch/model.py CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations
 
 import abc
 import re
+import os
 from typing import Callable
 
 import numpy.typing as npt
@@ -154,6 +155,7 @@ class TfLiteModel(Model):
     Args:
       path: The path to file to which the model is serialized.
     """
+    os.makedirs(os.path.dirname(path), exist_ok=True)
     with open(path, 'wb') as file_handle:
       file_handle.write(self._tflite_model)
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py CHANGED
@@ -34,6 +34,8 @@ fx_infra.decomp.update_pre_lower_decomp(
         torch.ops.aten.replication_pad1d,
         torch.ops.aten.replication_pad2d,
         torch.ops.aten.replication_pad3d,
+        torch.ops.aten.upsample_bilinear2d.vec,
+        torch.ops.aten.upsample_nearest2d.vec,
         torch.ops.aten.addmm,
     ])
 )
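These decompositions pair with the removal of build_interpolate_composite_pass.py (file 39 above): bilinear and nearest 2-D upsampling are now decomposed before lowering instead of being captured as composites by an fx pass. A minimal module that typically traces to the newly registered .vec overloads under torch.export:

import torch
import torch.nn.functional as F


class Upsample2x(torch.nn.Module):
  """Hits aten.upsample_bilinear2d.vec; mode='nearest' hits the nearest op."""

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return F.interpolate(x, scale_factor=2, mode='bilinear')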
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.5.0.dev20250424"
+__version__ = "0.5.0.dev20250426"
{ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250424
+Version: 0.5.0.dev20250426
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI