ai-edge-torch-nightly 0.3.0.dev20241224__py3-none-any.whl → 0.3.0.dev20250105__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/conversion.py +22 -18
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +1 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +3 -0
- ai_edge_torch/odml_torch/export.py +1 -6
- ai_edge_torch/odml_torch/tf_integration.py +12 -50
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241224.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241224.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20241224.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241224.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241224.dist-info → ai_edge_torch_nightly-0.3.0.dev20250105.dist-info}/top_level.txt +0 -0
@@ -14,9 +14,9 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
import logging
|
17
|
-
import os
|
18
17
|
from typing import Any, Literal, Optional, Union
|
19
18
|
|
19
|
+
import ai_edge_torch
|
20
20
|
from ai_edge_torch import fx_pass_base
|
21
21
|
from ai_edge_torch import lowertools
|
22
22
|
from ai_edge_torch import model
|
@@ -26,8 +26,6 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
|
|
26
26
|
from ai_edge_torch.quantize import quant_config as qcfg
|
27
27
|
import torch
|
28
28
|
|
29
|
-
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
|
30
|
-
|
31
29
|
|
32
30
|
def _run_convert_passes(
|
33
31
|
exported_program: torch.export.ExportedProgram,
|
@@ -35,21 +33,27 @@ def _run_convert_passes(
|
|
35
33
|
exported_program = generative_fx_passes.run_generative_passes(
|
36
34
|
exported_program
|
37
35
|
)
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
36
|
+
|
37
|
+
passes = [
|
38
|
+
fx_passes.BuildInterpolateCompositePass(),
|
39
|
+
fx_passes.CanonicalizePass(),
|
40
|
+
fx_passes.OptimizeLayoutTransposesPass(),
|
41
|
+
fx_passes.CanonicalizePass(),
|
42
|
+
fx_passes.BuildAtenCompositePass(),
|
43
|
+
fx_passes.CanonicalizePass(),
|
44
|
+
fx_passes.RemoveNonUserOutputsPass(),
|
45
|
+
fx_passes.CanonicalizePass(),
|
46
|
+
]
|
47
|
+
|
48
|
+
# Debuginfo is not injected automatically by odml_torch. Only inject
|
49
|
+
# debuginfo via fx pass when using torch_xla.
|
50
|
+
if ai_edge_torch.config.use_torch_xla:
|
51
|
+
passes += [
|
52
|
+
fx_passes.InjectMlirDebuginfoPass(),
|
53
|
+
fx_passes.CanonicalizePass(),
|
54
|
+
]
|
55
|
+
|
56
|
+
exported_program = fx_pass_base.run_passes(exported_program, passes)
|
53
57
|
return exported_program
|
54
58
|
|
55
59
|
|
@@ -62,6 +62,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
|
|
62
62
|
|
63
63
|
|
64
64
|
class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
|
65
|
+
"""DEPRECATED: Debuginfo is injected automatically by odml_torch."""
|
65
66
|
|
66
67
|
def call(self, graph_module: torch.fx.GraphModule):
|
67
68
|
for node in graph_module.graph.nodes:
|
@@ -27,6 +27,9 @@ if "PJRT_DEVICE" not in os.environ:
|
|
27
27
|
# https://github.com/google-ai-edge/ai-edge-torch/issues/326
|
28
28
|
os.environ["PJRT_DEVICE"] = "CPU"
|
29
29
|
|
30
|
+
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
|
31
|
+
|
32
|
+
|
30
33
|
from ai_edge_torch import model
|
31
34
|
from ai_edge_torch._convert import conversion_utils
|
32
35
|
from ai_edge_torch._convert import signature as signature_module
|
@@ -202,7 +202,7 @@ class MlirLowered:
|
|
202
202
|
target_version = stablehlo.get_minimum_version()
|
203
203
|
else:
|
204
204
|
target_version = stablehlo.get_version_from_compatibility_requirement(
|
205
|
-
stablehlo.StablehloCompatibilityRequirement.
|
205
|
+
stablehlo.StablehloCompatibilityRequirement.WEEK_12
|
206
206
|
)
|
207
207
|
module_bytecode = xla_extension.mlir.serialize_portable_artifact(
|
208
208
|
self.module_bytecode, target_version
|
@@ -222,11 +222,6 @@ class MlirLowered:
|
|
222
222
|
# Lazy importing TF when execution is needed.
|
223
223
|
return self.tf_function(*args)
|
224
224
|
|
225
|
-
def to_flatbuffer(self):
|
226
|
-
from . import tf_integration
|
227
|
-
|
228
|
-
return tf_integration.mlir_to_flatbuffer(self)
|
229
|
-
|
230
225
|
|
231
226
|
# TODO(b/331481564) Make this a ai_edge_torch FX pass.
|
232
227
|
def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
|
@@ -12,10 +12,9 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""APIs to convert lowered MLIR from PyTorch to TensorFlow
|
15
|
+
"""APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""
|
16
16
|
|
17
17
|
import re
|
18
|
-
import tempfile
|
19
18
|
|
20
19
|
import tensorflow as tf
|
21
20
|
import torch
|
@@ -104,20 +103,26 @@ def _extract_call_args(
|
|
104
103
|
def _wrap_as_tf_func(lowered, tf_state_dict):
|
105
104
|
"""Build tf.function from lowered and tf_state_dict."""
|
106
105
|
|
107
|
-
|
106
|
+
version = 6
|
107
|
+
if hasattr(tfxla, "call_module_maximum_supported_version"):
|
108
|
+
version = tfxla.call_module_maximum_supported_version()
|
109
|
+
|
110
|
+
def tf_func(*args):
|
108
111
|
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
|
109
112
|
s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
|
110
113
|
call_args = _extract_call_args(lowered, args, tf_state_dict)
|
111
114
|
return tfxla.call_module(
|
112
115
|
tuple(call_args),
|
113
|
-
version=
|
116
|
+
version=version,
|
114
117
|
Tout=t_outs, # dtype information
|
115
|
-
Sout=s_outs, #
|
118
|
+
Sout=s_outs, # shape information
|
116
119
|
function_list=[],
|
117
|
-
module=lowered.
|
120
|
+
module=lowered.module_bytecode_vhlo,
|
121
|
+
has_token_input_output=False,
|
122
|
+
platforms=["CPU"],
|
118
123
|
)
|
119
124
|
|
120
|
-
return
|
125
|
+
return tf_func
|
121
126
|
|
122
127
|
|
123
128
|
def _make_input_signatures(
|
@@ -149,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
|
|
149
154
|
_wrap_as_tf_func(lowered, tf_state_dict),
|
150
155
|
input_signature=_make_input_signatures(lowered),
|
151
156
|
)
|
152
|
-
|
153
|
-
|
154
|
-
def mlir_to_flatbuffer(lowered: export.MlirLowered):
|
155
|
-
"""Convert the MLIR lowered to a TFLite flatbuffer binary."""
|
156
|
-
tf_state_dict = _build_tf_state_dict(lowered)
|
157
|
-
signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
158
|
-
tf_signatures = [_make_input_signatures(lowered)]
|
159
|
-
tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
|
160
|
-
|
161
|
-
tf_module = tf.Module()
|
162
|
-
tf_module.f = []
|
163
|
-
|
164
|
-
for tf_sig, func in zip(tf_signatures, tf_functions):
|
165
|
-
tf_module.f.append(
|
166
|
-
tf.function(
|
167
|
-
func,
|
168
|
-
input_signature=tf_sig,
|
169
|
-
)
|
170
|
-
)
|
171
|
-
|
172
|
-
tf_module._variables = list(tf_state_dict.values())
|
173
|
-
|
174
|
-
tf_concrete_funcs = [
|
175
|
-
func.get_concrete_function(*tf_sig)
|
176
|
-
for func, tf_sig in zip(tf_module.f, tf_signatures)
|
177
|
-
]
|
178
|
-
|
179
|
-
# We need to temporarily save since TFLite's from_concrete_functions does not
|
180
|
-
# allow providing names for each of the concrete functions.
|
181
|
-
with tempfile.TemporaryDirectory() as temp_dir_path:
|
182
|
-
tf.saved_model.save(
|
183
|
-
tf_module,
|
184
|
-
temp_dir_path,
|
185
|
-
signatures={
|
186
|
-
sig_name: tf_concrete_funcs[idx]
|
187
|
-
for idx, sig_name in enumerate(signature_names)
|
188
|
-
},
|
189
|
-
)
|
190
|
-
|
191
|
-
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
|
192
|
-
tflite_model = converter.convert()
|
193
|
-
|
194
|
-
return tflite_model
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20250105
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,9 +3,9 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=rEruohWdKGtxlBLh9SF_NnC4pbAqrOU4MKG598yJRHY,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
8
|
+
ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
10
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
11
11
|
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
@@ -13,7 +13,7 @@ ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDiu
|
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
14
14
|
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
|
15
15
|
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
|
16
|
-
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=
|
16
|
+
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
|
17
17
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
19
19
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
|
@@ -166,14 +166,14 @@ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5S
|
|
166
166
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
167
167
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
|
168
168
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
169
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
|
169
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
170
170
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
171
171
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
172
172
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
|
173
173
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
174
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
174
|
+
ai_edge_torch/odml_torch/export.py,sha256=sqIMXmxK_qIuVC-_DNJ6wKlIWiXq4_WOCKbSqMRFudg,13293
|
175
175
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
176
|
-
ai_edge_torch/odml_torch/tf_integration.py,sha256=
|
176
|
+
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
177
177
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
178
178
|
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
|
179
179
|
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
|
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
203
203
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
204
204
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
205
205
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.
|
209
|
-
ai_edge_torch_nightly-0.3.0.
|
210
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/METADATA,sha256=d8fPEhT1HG6ZlbX2joNTeIpEQNqth8LduM_W6aQZQn8,1966
|
208
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
209
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
210
|
+
ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/RECORD,,
|
File without changes
|
File without changes
|