ai-edge-torch-nightly 0.3.0.dev20241226__py3-none-any.whl → 0.3.0.dev20250105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +22 -18
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +1 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +3 -0
- ai_edge_torch/odml_torch/export.py +1 -6
- ai_edge_torch/odml_torch/tf_integration.py +12 -50
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/top_level.txt +0 -0
@@ -14,9 +14,9 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
import logging
|
17
|
-
import os
|
18
17
|
from typing import Any, Literal, Optional, Union
|
19
18
|
|
19
|
+
import ai_edge_torch
|
20
20
|
from ai_edge_torch import fx_pass_base
|
21
21
|
from ai_edge_torch import lowertools
|
22
22
|
from ai_edge_torch import model
|
@@ -26,8 +26,6 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
|
|
26
26
|
from ai_edge_torch.quantize import quant_config as qcfg
|
27
27
|
import torch
|
28
28
|
|
29
|
-
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
|
30
|
-
|
31
29
|
|
32
30
|
def _run_convert_passes(
|
33
31
|
exported_program: torch.export.ExportedProgram,
|
@@ -35,21 +33,27 @@ def _run_convert_passes(
|
|
35
33
|
exported_program = generative_fx_passes.run_generative_passes(
|
36
34
|
exported_program
|
37
35
|
)
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
36
|
+
|
37
|
+
passes = [
|
38
|
+
fx_passes.BuildInterpolateCompositePass(),
|
39
|
+
fx_passes.CanonicalizePass(),
|
40
|
+
fx_passes.OptimizeLayoutTransposesPass(),
|
41
|
+
fx_passes.CanonicalizePass(),
|
42
|
+
fx_passes.BuildAtenCompositePass(),
|
43
|
+
fx_passes.CanonicalizePass(),
|
44
|
+
fx_passes.RemoveNonUserOutputsPass(),
|
45
|
+
fx_passes.CanonicalizePass(),
|
46
|
+
]
|
47
|
+
|
48
|
+
# Debuginfo is not injected automatically by odml_torch. Only inject
|
49
|
+
# debuginfo via fx pass when using torch_xla.
|
50
|
+
if ai_edge_torch.config.use_torch_xla:
|
51
|
+
passes += [
|
52
|
+
fx_passes.InjectMlirDebuginfoPass(),
|
53
|
+
fx_passes.CanonicalizePass(),
|
54
|
+
]
|
55
|
+
|
56
|
+
exported_program = fx_pass_base.run_passes(exported_program, passes)
|
53
57
|
return exported_program
|
54
58
|
|
55
59
|
|
@@ -62,6 +62,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
|
|
62
62
|
|
63
63
|
|
64
64
|
class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
|
65
|
+
"""DEPRECATED: Debuginfo is injected automatically by odml_torch."""
|
65
66
|
|
66
67
|
def call(self, graph_module: torch.fx.GraphModule):
|
67
68
|
for node in graph_module.graph.nodes:
|
@@ -27,6 +27,9 @@ if "PJRT_DEVICE" not in os.environ:
|
|
27
27
|
# https://github.com/google-ai-edge/ai-edge-torch/issues/326
|
28
28
|
os.environ["PJRT_DEVICE"] = "CPU"
|
29
29
|
|
30
|
+
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
|
31
|
+
|
32
|
+
|
30
33
|
from ai_edge_torch import model
|
31
34
|
from ai_edge_torch._convert import conversion_utils
|
32
35
|
from ai_edge_torch._convert import signature as signature_module
|
@@ -202,7 +202,7 @@ class MlirLowered:
|
|
202
202
|
target_version = stablehlo.get_minimum_version()
|
203
203
|
else:
|
204
204
|
target_version = stablehlo.get_version_from_compatibility_requirement(
|
205
|
-
stablehlo.StablehloCompatibilityRequirement.
|
205
|
+
stablehlo.StablehloCompatibilityRequirement.WEEK_12
|
206
206
|
)
|
207
207
|
module_bytecode = xla_extension.mlir.serialize_portable_artifact(
|
208
208
|
self.module_bytecode, target_version
|
@@ -222,11 +222,6 @@ class MlirLowered:
|
|
222
222
|
# Lazy importing TF when execution is needed.
|
223
223
|
return self.tf_function(*args)
|
224
224
|
|
225
|
-
def to_flatbuffer(self):
|
226
|
-
from . import tf_integration
|
227
|
-
|
228
|
-
return tf_integration.mlir_to_flatbuffer(self)
|
229
|
-
|
230
225
|
|
231
226
|
# TODO(b/331481564) Make this a ai_edge_torch FX pass.
|
232
227
|
def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
|
@@ -12,10 +12,9 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""APIs to convert lowered MLIR from PyTorch to TensorFlow
|
15
|
+
"""APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""
|
16
16
|
|
17
17
|
import re
|
18
|
-
import tempfile
|
19
18
|
|
20
19
|
import tensorflow as tf
|
21
20
|
import torch
|
@@ -104,20 +103,26 @@ def _extract_call_args(
|
|
104
103
|
def _wrap_as_tf_func(lowered, tf_state_dict):
|
105
104
|
"""Build tf.function from lowered and tf_state_dict."""
|
106
105
|
|
107
|
-
|
106
|
+
version = 6
|
107
|
+
if hasattr(tfxla, "call_module_maximum_supported_version"):
|
108
|
+
version = tfxla.call_module_maximum_supported_version()
|
109
|
+
|
110
|
+
def tf_func(*args):
|
108
111
|
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
|
109
112
|
s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
|
110
113
|
call_args = _extract_call_args(lowered, args, tf_state_dict)
|
111
114
|
return tfxla.call_module(
|
112
115
|
tuple(call_args),
|
113
|
-
version=
|
116
|
+
version=version,
|
114
117
|
Tout=t_outs, # dtype information
|
115
|
-
Sout=s_outs, #
|
118
|
+
Sout=s_outs, # shape information
|
116
119
|
function_list=[],
|
117
|
-
module=lowered.
|
120
|
+
module=lowered.module_bytecode_vhlo,
|
121
|
+
has_token_input_output=False,
|
122
|
+
platforms=["CPU"],
|
118
123
|
)
|
119
124
|
|
120
|
-
return
|
125
|
+
return tf_func
|
121
126
|
|
122
127
|
|
123
128
|
def _make_input_signatures(
|
@@ -149,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
|
|
149
154
|
_wrap_as_tf_func(lowered, tf_state_dict),
|
150
155
|
input_signature=_make_input_signatures(lowered),
|
151
156
|
)
|
152
|
-
|
153
|
-
|
154
|
-
def mlir_to_flatbuffer(lowered: export.MlirLowered):
|
155
|
-
"""Convert the MLIR lowered to a TFLite flatbuffer binary."""
|
156
|
-
tf_state_dict = _build_tf_state_dict(lowered)
|
157
|
-
signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
158
|
-
tf_signatures = [_make_input_signatures(lowered)]
|
159
|
-
tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
|
160
|
-
|
161
|
-
tf_module = tf.Module()
|
162
|
-
tf_module.f = []
|
163
|
-
|
164
|
-
for tf_sig, func in zip(tf_signatures, tf_functions):
|
165
|
-
tf_module.f.append(
|
166
|
-
tf.function(
|
167
|
-
func,
|
168
|
-
input_signature=tf_sig,
|
169
|
-
)
|
170
|
-
)
|
171
|
-
|
172
|
-
tf_module._variables = list(tf_state_dict.values())
|
173
|
-
|
174
|
-
tf_concrete_funcs = [
|
175
|
-
func.get_concrete_function(*tf_sig)
|
176
|
-
for func, tf_sig in zip(tf_module.f, tf_signatures)
|
177
|
-
]
|
178
|
-
|
179
|
-
# We need to temporarily save since TFLite's from_concrete_functions does not
|
180
|
-
# allow providing names for each of the concrete functions.
|
181
|
-
with tempfile.TemporaryDirectory() as temp_dir_path:
|
182
|
-
tf.saved_model.save(
|
183
|
-
tf_module,
|
184
|
-
temp_dir_path,
|
185
|
-
signatures={
|
186
|
-
sig_name: tf_concrete_funcs[idx]
|
187
|
-
for idx, sig_name in enumerate(signature_names)
|
188
|
-
},
|
189
|
-
)
|
190
|
-
|
191
|
-
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
|
192
|
-
tflite_model = converter.convert()
|
193
|
-
|
194
|
-
return tflite_model
|
{ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20241226
|
3
|
+
Version: 0.3.0.dev20250105
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,9 +3,9 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=rEruohWdKGtxlBLh9SF_NnC4pbAqrOU4MKG598yJRHY,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
8
|
+
ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
10
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
11
11
|
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
@@ -13,7 +13,7 @@ ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDiu
|
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
14
14
|
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
|
15
15
|
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
|
16
|
-
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=
|
16
|
+
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
|
17
17
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
19
19
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
|
@@ -166,14 +166,14 @@ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5S
|
|
166
166
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
167
167
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
|
168
168
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
169
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
|
169
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
170
170
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
171
171
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
172
172
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
|
173
173
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
174
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
174
|
+
ai_edge_torch/odml_torch/export.py,sha256=sqIMXmxK_qIuVC-_DNJ6wKlIWiXq4_WOCKbSqMRFudg,13293
|
175
175
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
176
|
-
ai_edge_torch/odml_torch/tf_integration.py,sha256=
|
176
|
+
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
177
177
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
178
178
|
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
|
179
179
|
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
|
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
203
203
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
204
204
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
205
205
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
206
|
-
ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
209
|
-
ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
210
|
-
ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/RECORD,,
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/METADATA,sha256=d8fPEhT1HG6ZlbX2joNTeIpEQNqth8LduM_W6aQZQn8,1966
|
208
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
209
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
210
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/RECORD,,
|
File without changes
|
File without changes
|