ai-edge-torch-nightly 0.3.0.dev20241226__py3-none-any.whl → 0.3.0.dev20250105__py3-none-any.whl

ai_edge_torch/_convert/conversion.py CHANGED
@@ -14,9 +14,9 @@
  # ==============================================================================

  import logging
- import os
  from typing import Any, Literal, Optional, Union

+ import ai_edge_torch
  from ai_edge_torch import fx_pass_base
  from ai_edge_torch import lowertools
  from ai_edge_torch import model
@@ -26,8 +26,6 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
  from ai_edge_torch.quantize import quant_config as qcfg
  import torch

- os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
-

  def _run_convert_passes(
      exported_program: torch.export.ExportedProgram,
@@ -35,21 +33,27 @@ def _run_convert_passes(
    exported_program = generative_fx_passes.run_generative_passes(
        exported_program
    )
-   exported_program = fx_pass_base.run_passes(
-       exported_program,
-       [
-           fx_passes.BuildInterpolateCompositePass(),
-           fx_passes.CanonicalizePass(),
-           fx_passes.OptimizeLayoutTransposesPass(),
-           fx_passes.CanonicalizePass(),
-           fx_passes.BuildAtenCompositePass(),
-           fx_passes.CanonicalizePass(),
-           fx_passes.RemoveNonUserOutputsPass(),
-           fx_passes.CanonicalizePass(),
-           fx_passes.InjectMlirDebuginfoPass(),
-           fx_passes.CanonicalizePass(),
-       ],
-   )
+
+   passes = [
+       fx_passes.BuildInterpolateCompositePass(),
+       fx_passes.CanonicalizePass(),
+       fx_passes.OptimizeLayoutTransposesPass(),
+       fx_passes.CanonicalizePass(),
+       fx_passes.BuildAtenCompositePass(),
+       fx_passes.CanonicalizePass(),
+       fx_passes.RemoveNonUserOutputsPass(),
+       fx_passes.CanonicalizePass(),
+   ]
+
+   # Debuginfo is not injected automatically by odml_torch. Only inject
+   # debuginfo via fx pass when using torch_xla.
+   if ai_edge_torch.config.use_torch_xla:
+     passes += [
+         fx_passes.InjectMlirDebuginfoPass(),
+         fx_passes.CanonicalizePass(),
+     ]
+
+   exported_program = fx_pass_base.run_passes(exported_program, passes)
    return exported_program


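For readers assembling the change above, a minimal sketch of the reworked pipeline; it uses only names visible in this diff and is not the library's exact code:

# Sketch only: the pass list is built unconditionally, and the debuginfo pass
# is appended only on the torch_xla path (odml_torch attaches debuginfo itself).
import torch
import ai_edge_torch
from ai_edge_torch import fx_pass_base
from ai_edge_torch._convert import fx_passes

def run_convert_passes(exported_program: torch.export.ExportedProgram):
  passes = [
      fx_passes.BuildAtenCompositePass(),
      fx_passes.CanonicalizePass(),
      fx_passes.RemoveNonUserOutputsPass(),
      fx_passes.CanonicalizePass(),
  ]
  if ai_edge_torch.config.use_torch_xla:
    passes += [fx_passes.InjectMlirDebuginfoPass(), fx_passes.CanonicalizePass()]
  return fx_pass_base.run_passes(exported_program, passes)
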
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py CHANGED
@@ -62,6 +62,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):


  class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
+   """DEPRECATED: Debuginfo is injected automatically by odml_torch."""

    def call(self, graph_module: torch.fx.GraphModule):
      for node in graph_module.graph.nodes:
ai_edge_torch/lowertools/torch_xla_utils.py CHANGED
@@ -27,6 +27,9 @@ if "PJRT_DEVICE" not in os.environ:
    # https://github.com/google-ai-edge/ai-edge-torch/issues/326
    os.environ["PJRT_DEVICE"] = "CPU"

+ os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
+
+
  from ai_edge_torch import model
  from ai_edge_torch._convert import conversion_utils
  from ai_edge_torch._convert import signature as signature_module
ai_edge_torch/odml_torch/export.py CHANGED
@@ -202,7 +202,7 @@ class MlirLowered:
        target_version = stablehlo.get_minimum_version()
      else:
        target_version = stablehlo.get_version_from_compatibility_requirement(
-           stablehlo.StablehloCompatibilityRequirement.WEEK_4
+           stablehlo.StablehloCompatibilityRequirement.WEEK_12
        )
      module_bytecode = xla_extension.mlir.serialize_portable_artifact(
          self.module_bytecode, target_version
@@ -222,11 +222,6 @@ class MlirLowered:
      # Lazy importing TF when execution is needed.
      return self.tf_function(*args)

-   def to_flatbuffer(self):
-     from . import tf_integration
-
-     return tf_integration.mlir_to_flatbuffer(self)
-

  # TODO(b/331481564) Make this a ai_edge_torch FX pass.
  def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
ai_edge_torch/odml_torch/tf_integration.py CHANGED
@@ -12,10 +12,9 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- """APIs to convert lowered MLIR from PyTorch to TensorFlow and TFLite artifacts."""
+ """APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""

  import re
- import tempfile

  import tensorflow as tf
  import torch
@@ -104,20 +103,26 @@ def _extract_call_args(
  def _wrap_as_tf_func(lowered, tf_state_dict):
    """Build tf.function from lowered and tf_state_dict."""

-   def inner(*args):
+   version = 6
+   if hasattr(tfxla, "call_module_maximum_supported_version"):
+     version = tfxla.call_module_maximum_supported_version()
+
+   def tf_func(*args):
      t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
      s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
      call_args = _extract_call_args(lowered, args, tf_state_dict)
      return tfxla.call_module(
          tuple(call_args),
-         version=5,
+         version=version,
          Tout=t_outs,  # dtype information
-         Sout=s_outs,  # Shape information
+         Sout=s_outs,  # shape information
          function_list=[],
-         module=lowered.module_bytecode,
+         module=lowered.module_bytecode_vhlo,
+         has_token_input_output=False,
+         platforms=["CPU"],
      )

-   return inner
+   return tf_func


  def _make_input_signatures(
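A minimal usage sketch of the wrapper above (illustrative only; `lowered` is assumed to be an export.MlirLowered produced by the odml_torch export path, and mlir_to_tf_function is the public helper defined further down in this file):

from ai_edge_torch.odml_torch import tf_integration

def run_lowered_module(lowered, *example_args):
  # Wraps _wrap_as_tf_func into a tf.function and executes the serialized
  # VHLO module through tfxla.call_module at the negotiated version.
  tf_fn = tf_integration.mlir_to_tf_function(lowered)
  return tf_fn(*example_args)
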
@@ -149,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
        _wrap_as_tf_func(lowered, tf_state_dict),
        input_signature=_make_input_signatures(lowered),
    )
-
-
- def mlir_to_flatbuffer(lowered: export.MlirLowered):
-   """Convert the MLIR lowered to a TFLite flatbuffer binary."""
-   tf_state_dict = _build_tf_state_dict(lowered)
-   signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
-   tf_signatures = [_make_input_signatures(lowered)]
-   tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
-
-   tf_module = tf.Module()
-   tf_module.f = []
-
-   for tf_sig, func in zip(tf_signatures, tf_functions):
-     tf_module.f.append(
-         tf.function(
-             func,
-             input_signature=tf_sig,
-         )
-     )
-
-   tf_module._variables = list(tf_state_dict.values())
-
-   tf_concrete_funcs = [
-       func.get_concrete_function(*tf_sig)
-       for func, tf_sig in zip(tf_module.f, tf_signatures)
-   ]
-
-   # We need to temporarily save since TFLite's from_concrete_functions does not
-   # allow providing names for each of the concrete functions.
-   with tempfile.TemporaryDirectory() as temp_dir_path:
-     tf.saved_model.save(
-         tf_module,
-         temp_dir_path,
-         signatures={
-             sig_name: tf_concrete_funcs[idx]
-             for idx, sig_name in enumerate(signature_names)
-         },
-     )
-
-     converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
-     tflite_model = converter.convert()
-
-   return tflite_model
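With mlir_to_flatbuffer removed, a TFLite flatbuffer can still be produced from the remaining tf.function path; the sketch below is an illustration using standard TensorFlow Lite converter APIs, not the library's replacement implementation, and assumes `lowered` is an export.MlirLowered:

import tensorflow as tf
from ai_edge_torch.odml_torch import tf_integration

def lowered_to_flatbuffer(lowered) -> bytes:
  tf_fn = tf_integration.mlir_to_tf_function(lowered)
  # The tf.function carries an input_signature, so a concrete function can be
  # obtained without example arguments.
  concrete_fn = tf_fn.get_concrete_function()
  converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
  return converter.convert()
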
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20241226"
+ __version__ = "0.3.0.dev20250105"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20241226
+ Version: 0.3.0.dev20250105
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,9 +3,9 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=E_WDoV5Y1AG1Kq9M0_73bQYoSSnhDCJ7dxLCdKpkJJE,706
+ ai_edge_torch/version.py,sha256=rEruohWdKGtxlBLh9SF_NnC4pbAqrOU4MKG598yJRHY,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
+ ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
  ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
  ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
@@ -13,7 +13,7 @@ ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDiu
  ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
  ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
  ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
+ ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
@@ -166,14 +166,14 @@ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5S
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
  ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
- ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
+ ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
  ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
  ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
- ai_edge_torch/odml_torch/export.py,sha256=QzOPmcNPB7R-KhhPEP0oGVbDRgGPptIxRSoz3S8py9I,13405
+ ai_edge_torch/odml_torch/export.py,sha256=sqIMXmxK_qIuVC-_DNJ6wKlIWiXq4_WOCKbSqMRFudg,13293
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
- ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
+ ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
  ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/METADATA,sha256=khQQVRgopWndD2IbqOblhMzGAlOzSS6f0SbpP1oZ5xw,1966
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/METADATA,sha256=d8fPEhT1HG6ZlbX2joNTeIpEQNqth8LduM_W6aQZQn8,1966
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/RECORD,,