ai-edge-torch-nightly 0.3.0.dev20241014__py3-none-any.whl → 0.3.0.dev20241017__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/conversion.py +18 -5
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -0
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +16 -0
- ai_edge_torch/fx_pass_base.py +1 -1
- ai_edge_torch/generative/layers/attention_utils.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/pattern.py +5 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241017.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241017.dist-info}/RECORD +14 -13
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241017.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241017.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241017.dist-info}/top_level.txt +0 -0
@@ -35,7 +35,7 @@ def _run_convert_passes(
|
|
35
35
|
exported_program = generative_fx_passes.run_generative_passes(
|
36
36
|
exported_program
|
37
37
|
)
|
38
|
-
|
38
|
+
exported_program = fx_pass_base.run_passes(
|
39
39
|
exported_program,
|
40
40
|
[
|
41
41
|
fx_passes.BuildInterpolateCompositePass(),
|
@@ -44,10 +44,13 @@ def _run_convert_passes(
|
|
44
44
|
fx_passes.CanonicalizePass(),
|
45
45
|
fx_passes.BuildAtenCompositePass(),
|
46
46
|
fx_passes.CanonicalizePass(),
|
47
|
+
fx_passes.RemoveNonUserOutputsPass(),
|
48
|
+
fx_passes.CanonicalizePass(),
|
47
49
|
fx_passes.InjectMlirDebuginfoPass(),
|
48
50
|
fx_passes.CanonicalizePass(),
|
49
51
|
],
|
50
52
|
)
|
53
|
+
return exported_program
|
51
54
|
|
52
55
|
|
53
56
|
def _warn_training_modules(signatures: list[signature.Signature]):
|
@@ -103,17 +106,27 @@ def convert_signatures(
|
|
103
106
|
nonlocal strict_export
|
104
107
|
if strict_export == "auto":
|
105
108
|
try:
|
106
|
-
|
109
|
+
exported_program = torch.export.export(*args, **kwargs, strict=True)
|
107
110
|
except Exception:
|
108
111
|
logging.warning(
|
109
112
|
"torch.export.export(..., strict=True) failed. Retrying with"
|
110
113
|
" strict=False"
|
111
114
|
)
|
112
|
-
|
115
|
+
exported_program = torch.export.export(*args, **kwargs, strict=False)
|
113
116
|
elif not strict_export:
|
114
|
-
|
117
|
+
exported_program = torch.export.export(*args, **kwargs, strict=False)
|
115
118
|
else:
|
116
|
-
|
119
|
+
exported_program = torch.export.export(*args, **kwargs, strict=True)
|
120
|
+
|
121
|
+
if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
|
122
|
+
# Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
|
123
|
+
# stop-gap table which replicates the old behaviour of post-dispatch IR.
|
124
|
+
# This could help ensure the collection of aten ops remaining still as the
|
125
|
+
# implementation of torch.export changes.
|
126
|
+
exported_program = exported_program.run_decompositions(
|
127
|
+
torch._decomp._decomp_table_to_post_autograd_aten()
|
128
|
+
)
|
129
|
+
return exported_program
|
117
130
|
|
118
131
|
exported_programs: torch.export.ExportedProgram = [
|
119
132
|
export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
|
@@ -19,4 +19,5 @@ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAten
|
|
19
19
|
from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
|
20
20
|
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
|
21
21
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
|
22
|
+
from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
|
22
23
|
from ai_edge_torch.fx_pass_base import CanonicalizePass
|
@@ -113,6 +113,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
|
|
113
113
|
]
|
114
114
|
|
115
115
|
def call(self, exported_program: torch.export.ExportedProgram):
|
116
|
+
exported_program = fx_pass_base.run_passes(
|
117
|
+
exported_program, [fx_pass_base.CanonicalizePass()]
|
118
|
+
)
|
116
119
|
exported_program = exported_program.run_decompositions(
|
117
120
|
_INTERPOLATE_DECOMPOSITIONS
|
118
121
|
)
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Pass to remove all non user outputs from exported program."""
|
16
|
+
|
17
|
+
|
18
|
+
from ai_edge_torch import fx_pass_base
|
19
|
+
import torch
|
20
|
+
|
21
|
+
|
22
|
+
class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
|
23
|
+
"""This pass removes all non user outputs from the exported program's output.
|
24
|
+
|
25
|
+
The FX graph may output more tensors/data than what user's original model
|
26
|
+
returns. Those additional outputs include user input mutations, gradient to
|
27
|
+
parameter, etc. Those outputs are not supported by our inference only
|
28
|
+
conversion or runtime. This pass remove all those outputs to ensure the
|
29
|
+
converted models' outputs match what returned from user's model in eval mode.
|
30
|
+
"""
|
31
|
+
|
32
|
+
def call(self, exported_program: torch.export.ExportedProgram):
|
33
|
+
for node in exported_program.graph.nodes:
|
34
|
+
if node.op != "output":
|
35
|
+
continue
|
36
|
+
|
37
|
+
outputs = node.args[0]
|
38
|
+
output_specs = exported_program.graph_signature.output_specs
|
39
|
+
|
40
|
+
new_outputs = []
|
41
|
+
new_output_specs = []
|
42
|
+
for output, spec in zip(outputs, output_specs):
|
43
|
+
if spec.kind == torch.export.graph_signature.OutputKind.USER_OUTPUT:
|
44
|
+
new_outputs.append(output)
|
45
|
+
new_output_specs.append(spec)
|
46
|
+
|
47
|
+
node.args = (tuple(new_outputs),)
|
48
|
+
exported_program.graph_signature.output_specs = new_output_specs
|
49
|
+
|
50
|
+
exported_program.graph_module.graph.lint()
|
51
|
+
exported_program.graph_module.recompile()
|
52
|
+
return fx_pass_base.ExportedProgramPassResult(exported_program, True)
|
@@ -490,6 +490,22 @@ class TestConvert(googletest.TestCase):
|
|
490
490
|
tflite_output = edge_model(**flat_inputs)
|
491
491
|
np.testing.assert_almost_equal(reference_output, tflite_output)
|
492
492
|
|
493
|
+
def test_convert_model_with_input_mutation(self):
|
494
|
+
class SampleModel(nn.Module):
|
495
|
+
|
496
|
+
def forward(self, x):
|
497
|
+
x /= 1
|
498
|
+
x = x + 10
|
499
|
+
return x
|
500
|
+
|
501
|
+
args = (torch.randn(10, 10),)
|
502
|
+
torch_module = SampleModel().eval()
|
503
|
+
edge_model = ai_edge_torch.convert(torch_module, args)
|
504
|
+
|
505
|
+
self.assertTrue(
|
506
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
507
|
+
)
|
508
|
+
|
493
509
|
|
494
510
|
if __name__ == "__main__":
|
495
511
|
googletest.main()
|
ai_edge_torch/fx_pass_base.py
CHANGED
@@ -102,7 +102,7 @@ class CanonicalizePass(ExportedProgramPassBase):
|
|
102
102
|
# for retracing. If the input memory format is already contiguous,
|
103
103
|
# retracing in run_decomposition below would decompose torch.reshape
|
104
104
|
# back to one aten.view.
|
105
|
-
node.target = lambda self, size: torch.reshape(self, size)
|
105
|
+
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
106
106
|
|
107
107
|
exported_program = exported_program.run_decompositions(
|
108
108
|
self._DUMMY_DECOMP_TABLE
|
@@ -107,7 +107,7 @@ def build_sliding_window_mask_cache(
|
|
107
107
|
sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
|
108
108
|
all_ones, window_size - 1
|
109
109
|
)
|
110
|
-
return torch.where(sliding_mask == 1, mask, -
|
110
|
+
return torch.where(sliding_mask == 1, mask, float('-inf'))
|
111
111
|
|
112
112
|
|
113
113
|
def relative_position_bucket(
|
@@ -17,6 +17,7 @@
|
|
17
17
|
import dataclasses
|
18
18
|
from typing import Any, Callable, Optional, Union
|
19
19
|
|
20
|
+
from ai_edge_torch import fx_pass_base
|
20
21
|
from ai_edge_torch.hlfb.mark_pattern import passes
|
21
22
|
import torch
|
22
23
|
from torch.export.graph_signature import TensorArgument
|
@@ -116,6 +117,7 @@ def _find_scalar_attr(
|
|
116
117
|
track_args[tracker.pattern_arg_pos] = source
|
117
118
|
ep = torch.export.export(pattern_module, tuple(track_args))
|
118
119
|
if decomp_table is not None:
|
120
|
+
ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
|
119
121
|
ep = ep.run_decompositions(decomp_table)
|
120
122
|
|
121
123
|
scalar_locs = set()
|
@@ -198,6 +200,9 @@ class Pattern:
|
|
198
200
|
|
199
201
|
exported_program = torch.export.export(module, export_args)
|
200
202
|
if decomp_table is not None:
|
203
|
+
exported_program = fx_pass_base.run_passes(
|
204
|
+
exported_program, [fx_pass_base.CanonicalizePass()]
|
205
|
+
)
|
201
206
|
exported_program = exported_program.run_decompositions(decomp_table)
|
202
207
|
|
203
208
|
self.exported_program = exported_program
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241017
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -1,19 +1,20 @@
|
|
1
1
|
ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
|
2
2
|
ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
|
-
ai_edge_torch/fx_pass_base.py,sha256=
|
4
|
+
ai_edge_torch/fx_pass_base.py,sha256=SrYveglaiA_DXPoRBqSXClWM1q7853I5ujRorq_MV0M,4251
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=IOdHlIlwdynvTEykO7NwuEz-QELjyYKSAJzpF2n52U4,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
8
|
+
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
10
|
ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
|
11
11
|
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
12
12
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
|
-
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
|
13
|
+
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
14
14
|
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
|
15
|
-
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
|
15
|
+
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=XCVqWg_ask0Kb64PED0ZGAODsUuIgfyO2ZJM6aK-TXI,4283
|
16
16
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
|
17
|
+
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
|
17
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
18
19
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
|
19
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
@@ -25,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
25
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
26
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
27
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
28
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
29
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
|
29
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
30
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
31
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -100,7 +101,7 @@ ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjA
|
|
100
101
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
101
102
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
102
103
|
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
103
|
-
ai_edge_torch/generative/layers/attention_utils.py,sha256=
|
104
|
+
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
104
105
|
ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
|
105
106
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
106
107
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
@@ -137,7 +138,7 @@ ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx
|
|
137
138
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
138
139
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
139
140
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
140
|
-
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=
|
141
|
+
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
|
141
142
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
142
143
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
|
143
144
|
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
|
@@ -180,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
180
181
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
181
182
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
182
183
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
183
|
-
ai_edge_torch_nightly-0.3.0.
|
184
|
-
ai_edge_torch_nightly-0.3.0.
|
185
|
-
ai_edge_torch_nightly-0.3.0.
|
186
|
-
ai_edge_torch_nightly-0.3.0.
|
187
|
-
ai_edge_torch_nightly-0.3.0.
|
184
|
+
ai_edge_torch_nightly-0.3.0.dev20241017.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
185
|
+
ai_edge_torch_nightly-0.3.0.dev20241017.dist-info/METADATA,sha256=t4Fa0qYkmDLA0UXbEK9sxB4yHSL8cOIOxLs1dmWjAUQ,1897
|
186
|
+
ai_edge_torch_nightly-0.3.0.dev20241017.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
187
|
+
ai_edge_torch_nightly-0.3.0.dev20241017.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
188
|
+
ai_edge_torch_nightly-0.3.0.dev20241017.dist-info/RECORD,,
|
File without changes
|
File without changes
|