ai-edge-torch-nightly 0.3.0.dev20241226__py3-none-any.whl → 0.3.0.dev20250105__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,9 +14,9 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import logging
17
- import os
18
17
  from typing import Any, Literal, Optional, Union
19
18
 
19
+ import ai_edge_torch
20
20
  from ai_edge_torch import fx_pass_base
21
21
  from ai_edge_torch import lowertools
22
22
  from ai_edge_torch import model
@@ -26,8 +26,6 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
26
26
  from ai_edge_torch.quantize import quant_config as qcfg
27
27
  import torch
28
28
 
29
- os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
30
-
31
29
 
32
30
  def _run_convert_passes(
33
31
  exported_program: torch.export.ExportedProgram,
@@ -35,21 +33,27 @@ def _run_convert_passes(
35
33
  exported_program = generative_fx_passes.run_generative_passes(
36
34
  exported_program
37
35
  )
38
- exported_program = fx_pass_base.run_passes(
39
- exported_program,
40
- [
41
- fx_passes.BuildInterpolateCompositePass(),
42
- fx_passes.CanonicalizePass(),
43
- fx_passes.OptimizeLayoutTransposesPass(),
44
- fx_passes.CanonicalizePass(),
45
- fx_passes.BuildAtenCompositePass(),
46
- fx_passes.CanonicalizePass(),
47
- fx_passes.RemoveNonUserOutputsPass(),
48
- fx_passes.CanonicalizePass(),
49
- fx_passes.InjectMlirDebuginfoPass(),
50
- fx_passes.CanonicalizePass(),
51
- ],
52
- )
36
+
37
+ passes = [
38
+ fx_passes.BuildInterpolateCompositePass(),
39
+ fx_passes.CanonicalizePass(),
40
+ fx_passes.OptimizeLayoutTransposesPass(),
41
+ fx_passes.CanonicalizePass(),
42
+ fx_passes.BuildAtenCompositePass(),
43
+ fx_passes.CanonicalizePass(),
44
+ fx_passes.RemoveNonUserOutputsPass(),
45
+ fx_passes.CanonicalizePass(),
46
+ ]
47
+
48
+ # Debuginfo is not injected automatically by odml_torch. Only inject
49
+ # debuginfo via fx pass when using torch_xla.
50
+ if ai_edge_torch.config.use_torch_xla:
51
+ passes += [
52
+ fx_passes.InjectMlirDebuginfoPass(),
53
+ fx_passes.CanonicalizePass(),
54
+ ]
55
+
56
+ exported_program = fx_pass_base.run_passes(exported_program, passes)
53
57
  return exported_program
54
58
 
55
59
 
@@ -62,6 +62,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
62
62
 
63
63
 
64
64
  class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
65
+ """DEPRECATED: Debuginfo is injected automatically by odml_torch."""
65
66
 
66
67
  def call(self, graph_module: torch.fx.GraphModule):
67
68
  for node in graph_module.graph.nodes:
@@ -27,6 +27,9 @@ if "PJRT_DEVICE" not in os.environ:
27
27
  # https://github.com/google-ai-edge/ai-edge-torch/issues/326
28
28
  os.environ["PJRT_DEVICE"] = "CPU"
29
29
 
30
+ os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
31
+
32
+
30
33
  from ai_edge_torch import model
31
34
  from ai_edge_torch._convert import conversion_utils
32
35
  from ai_edge_torch._convert import signature as signature_module
@@ -202,7 +202,7 @@ class MlirLowered:
202
202
  target_version = stablehlo.get_minimum_version()
203
203
  else:
204
204
  target_version = stablehlo.get_version_from_compatibility_requirement(
205
- stablehlo.StablehloCompatibilityRequirement.WEEK_4
205
+ stablehlo.StablehloCompatibilityRequirement.WEEK_12
206
206
  )
207
207
  module_bytecode = xla_extension.mlir.serialize_portable_artifact(
208
208
  self.module_bytecode, target_version
@@ -222,11 +222,6 @@ class MlirLowered:
222
222
  # Lazy importing TF when execution is needed.
223
223
  return self.tf_function(*args)
224
224
 
225
- def to_flatbuffer(self):
226
- from . import tf_integration
227
-
228
- return tf_integration.mlir_to_flatbuffer(self)
229
-
230
225
 
231
226
  # TODO(b/331481564) Make this a ai_edge_torch FX pass.
232
227
  def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
@@ -12,10 +12,9 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """APIs to convert lowered MLIR from PyTorch to TensorFlow and TFLite artifacts."""
15
+ """APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""
16
16
 
17
17
  import re
18
- import tempfile
19
18
 
20
19
  import tensorflow as tf
21
20
  import torch
@@ -104,20 +103,26 @@ def _extract_call_args(
104
103
  def _wrap_as_tf_func(lowered, tf_state_dict):
105
104
  """Build tf.function from lowered and tf_state_dict."""
106
105
 
107
- def inner(*args):
106
+ version = 6
107
+ if hasattr(tfxla, "call_module_maximum_supported_version"):
108
+ version = tfxla.call_module_maximum_supported_version()
109
+
110
+ def tf_func(*args):
108
111
  t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
109
112
  s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
110
113
  call_args = _extract_call_args(lowered, args, tf_state_dict)
111
114
  return tfxla.call_module(
112
115
  tuple(call_args),
113
- version=5,
116
+ version=version,
114
117
  Tout=t_outs, # dtype information
115
- Sout=s_outs, # Shape information
118
+ Sout=s_outs, # shape information
116
119
  function_list=[],
117
- module=lowered.module_bytecode,
120
+ module=lowered.module_bytecode_vhlo,
121
+ has_token_input_output=False,
122
+ platforms=["CPU"],
118
123
  )
119
124
 
120
- return inner
125
+ return tf_func
121
126
 
122
127
 
123
128
  def _make_input_signatures(
@@ -149,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
149
154
  _wrap_as_tf_func(lowered, tf_state_dict),
150
155
  input_signature=_make_input_signatures(lowered),
151
156
  )
152
-
153
-
154
- def mlir_to_flatbuffer(lowered: export.MlirLowered):
155
- """Convert the MLIR lowered to a TFLite flatbuffer binary."""
156
- tf_state_dict = _build_tf_state_dict(lowered)
157
- signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
158
- tf_signatures = [_make_input_signatures(lowered)]
159
- tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
160
-
161
- tf_module = tf.Module()
162
- tf_module.f = []
163
-
164
- for tf_sig, func in zip(tf_signatures, tf_functions):
165
- tf_module.f.append(
166
- tf.function(
167
- func,
168
- input_signature=tf_sig,
169
- )
170
- )
171
-
172
- tf_module._variables = list(tf_state_dict.values())
173
-
174
- tf_concrete_funcs = [
175
- func.get_concrete_function(*tf_sig)
176
- for func, tf_sig in zip(tf_module.f, tf_signatures)
177
- ]
178
-
179
- # We need to temporarily save since TFLite's from_concrete_functions does not
180
- # allow providing names for each of the concrete functions.
181
- with tempfile.TemporaryDirectory() as temp_dir_path:
182
- tf.saved_model.save(
183
- tf_module,
184
- temp_dir_path,
185
- signatures={
186
- sig_name: tf_concrete_funcs[idx]
187
- for idx, sig_name in enumerate(signature_names)
188
- },
189
- )
190
-
191
- converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
192
- tflite_model = converter.convert()
193
-
194
- return tflite_model
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241226"
16
+ __version__ = "0.3.0.dev20250105"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241226
3
+ Version: 0.3.0.dev20250105
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,9 +3,9 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=E_WDoV5Y1AG1Kq9M0_73bQYoSSnhDCJ7dxLCdKpkJJE,706
6
+ ai_edge_torch/version.py,sha256=rEruohWdKGtxlBLh9SF_NnC4pbAqrOU4MKG598yJRHY,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
- ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
8
+ ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
10
10
  ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
11
11
  ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
@@ -13,7 +13,7 @@ ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDiu
13
13
  ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
14
14
  ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
15
15
  ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
16
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
16
+ ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
17
17
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
18
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
19
19
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
@@ -166,14 +166,14 @@ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5S
166
166
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
167
167
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
168
168
  ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
169
- ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
169
+ ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
170
170
  ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
171
171
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
172
172
  ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
173
173
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
174
- ai_edge_torch/odml_torch/export.py,sha256=QzOPmcNPB7R-KhhPEP0oGVbDRgGPptIxRSoz3S8py9I,13405
174
+ ai_edge_torch/odml_torch/export.py,sha256=sqIMXmxK_qIuVC-_DNJ6wKlIWiXq4_WOCKbSqMRFudg,13293
175
175
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
176
- ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
176
+ ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
177
177
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
178
178
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
179
179
  ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
203
203
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
204
204
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
205
205
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
206
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
207
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/METADATA,sha256=khQQVRgopWndD2IbqOblhMzGAlOzSS6f0SbpP1oZ5xw,1966
208
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
209
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
210
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/RECORD,,
206
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
207
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/METADATA,sha256=d8fPEhT1HG6ZlbX2joNTeIpEQNqth8LduM_W6aQZQn8,1966
208
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
209
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
210
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/RECORD,,