ai-edge-torch-nightly 0.3.0.dev20241014__py3-none-any.whl → 0.3.0.dev20241018__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- ai_edge_torch/_convert/conversion.py +18 -5
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -0
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +16 -0
- ai_edge_torch/fx_pass_base.py +1 -1
- ai_edge_torch/generative/layers/attention_utils.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/pattern.py +5 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241018.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241018.dist-info}/RECORD +14 -13
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241018.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241018.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241018.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py CHANGED
@@ -35,7 +35,7 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  return fx_pass_base.run_passes(
+  exported_program = fx_pass_base.run_passes(
       exported_program,
       [
           fx_passes.BuildInterpolateCompositePass(),
@@ -44,10 +44,13 @@ def _run_convert_passes(
           fx_passes.CanonicalizePass(),
           fx_passes.BuildAtenCompositePass(),
           fx_passes.CanonicalizePass(),
+          fx_passes.RemoveNonUserOutputsPass(),
+          fx_passes.CanonicalizePass(),
           fx_passes.InjectMlirDebuginfoPass(),
           fx_passes.CanonicalizePass(),
       ],
   )
+  return exported_program


 def _warn_training_modules(signatures: list[signature.Signature]):
@@ -103,17 +106,27 @@ def convert_signatures(
     nonlocal strict_export
     if strict_export == "auto":
       try:
-        return torch.export.export(*args, **kwargs, strict=True)
+        exported_program = torch.export.export(*args, **kwargs, strict=True)
       except Exception:
         logging.warning(
             "torch.export.export(..., strict=True) failed. Retrying with"
             " strict=False"
         )
-        return torch.export.export(*args, **kwargs, strict=False)
+        exported_program = torch.export.export(*args, **kwargs, strict=False)
     elif not strict_export:
-      return torch.export.export(*args, **kwargs, strict=False)
+      exported_program = torch.export.export(*args, **kwargs, strict=False)
     else:
-      return torch.export.export(*args, **kwargs, strict=True)
+      exported_program = torch.export.export(*args, **kwargs, strict=True)
+
+    if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
+      # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
+      # stop-gap table which replicates the old behaviour of post-dispatch IR.
+      # This could help ensure the collection of aten ops remaining still as the
+      # implementation of torch.export changes.
+      exported_program = exported_program.run_decompositions(
+          torch._decomp._decomp_table_to_post_autograd_aten()
+      )
+    return exported_program

   exported_programs: torch.export.ExportedProgram = [
       export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
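
Note: the change above makes the export helper try strict export first, fall back to non-strict export on failure, and then, on torch builds that provide it, re-run decompositions with the stop-gap post-autograd table. The snippet below is a minimal standalone sketch of that pattern, not code from the package; TinyModel and export_with_fallback are illustrative names, and a torch version with torch.export is assumed.

# Minimal sketch of the strict-export fallback plus the guarded decomposition
# step. TinyModel and export_with_fallback are illustrative, not package code.
import logging

import torch


class TinyModel(torch.nn.Module):

  def forward(self, x):
    return x * 2.0


def export_with_fallback(module, args):
  """Try strict export first; fall back to non-strict export on failure."""
  try:
    ep = torch.export.export(module, args, strict=True)
  except Exception:
    logging.warning("strict=True export failed; retrying with strict=False")
    ep = torch.export.export(module, args, strict=False)

  # `_decomp_table_to_post_autograd_aten` only exists in newer torch releases,
  # so guard with hasattr to keep older versions working.
  if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
    ep = ep.run_decompositions(
        torch._decomp._decomp_table_to_post_autograd_aten()
    )
  return ep


ep = export_with_fallback(TinyModel().eval(), (torch.randn(4),))
print(type(ep))  # an ExportedProgram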

ai_edge_torch/_convert/fx_passes/__init__.py CHANGED
@@ -19,4 +19,5 @@ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAten
 from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
 from ai_edge_torch.fx_pass_base import CanonicalizePass

ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -113,6 +113,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
   ]

   def call(self, exported_program: torch.export.ExportedProgram):
+    exported_program = fx_pass_base.run_passes(
+        exported_program, [fx_pass_base.CanonicalizePass()]
+    )
     exported_program = exported_program.run_decompositions(
         _INTERPOLATE_DECOMPOSITIONS
     )

ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py ADDED
@@ -0,0 +1,52 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pass to remove all non user outputs from exported program."""
+
+
+from ai_edge_torch import fx_pass_base
+import torch
+
+
+class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
+  """This pass removes all non user outputs from the exported program's output.
+
+  The FX graph may output more tensors/data than what user's original model
+  returns. Those additional outputs include user input mutations, gradient to
+  parameter, etc. Those outputs are not supported by our inference only
+  conversion or runtime. This pass remove all those outputs to ensure the
+  converted models' outputs match what returned from user's model in eval mode.
+  """
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    for node in exported_program.graph.nodes:
+      if node.op != "output":
+        continue
+
+      outputs = node.args[0]
+      output_specs = exported_program.graph_signature.output_specs
+
+      new_outputs = []
+      new_output_specs = []
+      for output, spec in zip(outputs, output_specs):
+        if spec.kind == torch.export.graph_signature.OutputKind.USER_OUTPUT:
+          new_outputs.append(output)
+          new_output_specs.append(spec)
+
+      node.args = (tuple(new_outputs),)
+      exported_program.graph_signature.output_specs = new_output_specs
+
+    exported_program.graph_module.graph.lint()
+    exported_program.graph_module.recompile()
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
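
Note: as the docstring above describes, torch.export can record extra, non-user outputs such as user-input mutations. The sketch below is illustrative only (MutatingModel is not from the package); it shows how an in-place mutation of an input surfaces in the exported program's graph signature, and how keeping only USER_OUTPUT specs mirrors what RemoveNonUserOutputsPass does.

# Illustrative sketch: a forward() that mutates its input in place makes
# torch.export record an additional, non-user output in the graph signature.
import torch
from torch.export.graph_signature import OutputKind


class MutatingModel(torch.nn.Module):

  def forward(self, x):
    x /= 1  # in-place mutation of a user input
    return x + 10


ep = torch.export.export(MutatingModel().eval(), (torch.randn(10, 10),))
for spec in ep.graph_signature.output_specs:
  print(spec.kind)  # expect a USER_INPUT_MUTATION entry alongside USER_OUTPUT

# Filtering to USER_OUTPUT specs is the core of what the new pass keeps.
user_output_specs = [
    spec
    for spec in ep.graph_signature.output_specs
    if spec.kind == OutputKind.USER_OUTPUT
]
print(len(user_output_specs))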

ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -490,6 +490,22 @@ class TestConvert(googletest.TestCase):
     tflite_output = edge_model(**flat_inputs)
     np.testing.assert_almost_equal(reference_output, tflite_output)

+  def test_convert_model_with_input_mutation(self):
+    class SampleModel(nn.Module):
+
+      def forward(self, x):
+        x /= 1
+        x = x + 10
+        return x
+
+    args = (torch.randn(10, 10),)
+    torch_module = SampleModel().eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(edge_model, torch_module, args)
+    )
+

 if __name__ == "__main__":
   googletest.main()

ai_edge_torch/fx_pass_base.py CHANGED
@@ -102,7 +102,7 @@ class CanonicalizePass(ExportedProgramPassBase):
         # for retracing. If the input memory format is already contiguous,
         # retracing in run_decomposition below would decompose torch.reshape
         # back to one aten.view.
-        node.target = lambda self, size: torch.reshape(self, size)
+        node.target = lambda self, size: torch.reshape(self.contiguous(), size)

     exported_program = exported_program.run_decompositions(
         self._DUMMY_DECOMP_TABLE
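
Note: the comment in the hunk above concerns retracing, which decomposes torch.reshape into aten.view; view requires stride-compatible (effectively contiguous) input, so the lambda now calls .contiguous() first. The snippet below is a standalone illustration of that failure mode with arbitrary shapes; it is not code from the package.

# A channels_last tensor has strides that view() cannot flatten directly,
# while the same data works fine once made contiguous.
import torch

t = torch.randn(1, 3, 4, 4).to(memory_format=torch.channels_last)

try:
  t.view(-1)  # fails: incompatible size/stride for a plain view
except RuntimeError as e:
  print("view failed:", e)

print(t.contiguous().view(-1).shape)  # torch.Size([48])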

ai_edge_torch/generative/layers/attention_utils.py CHANGED
@@ -107,7 +107,7 @@ def build_sliding_window_mask_cache(
   sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
       all_ones, window_size - 1
   )
-  return torch.where(sliding_mask == 1, mask, -
+  return torch.where(sliding_mask == 1, mask, float('-inf'))


 def relative_position_bucket(
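
Note: this change fills positions outside the sliding window with float('-inf') when building the mask. The snippet below is an illustrative sketch with arbitrary sizes, using a zero base mask in place of the library's causal mask, to show the triu/tril band construction and why -inf entries receive zero attention weight after softmax.

# Arbitrary sizes; `mask` is all zeros here purely for illustration.
import torch

size, window_size = 6, 3
mask = torch.zeros((size, size))
all_ones = torch.ones((size, size))
sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
    all_ones, window_size - 1
)
banded = torch.where(sliding_mask == 1, mask, float("-inf"))
print(banded)
print(torch.softmax(banded, dim=-1))  # -inf positions get exactly zero weight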

ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 from typing import Any, Callable, Optional, Union

+from ai_edge_torch import fx_pass_base
 from ai_edge_torch.hlfb.mark_pattern import passes
 import torch
 from torch.export.graph_signature import TensorArgument
@@ -116,6 +117,7 @@ def _find_scalar_attr(
   track_args[tracker.pattern_arg_pos] = source
   ep = torch.export.export(pattern_module, tuple(track_args))
   if decomp_table is not None:
+    ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
     ep = ep.run_decompositions(decomp_table)

   scalar_locs = set()
@@ -198,6 +200,9 @@ class Pattern:

     exported_program = torch.export.export(module, export_args)
     if decomp_table is not None:
+      exported_program = fx_pass_base.run_passes(
+          exported_program, [fx_pass_base.CanonicalizePass()]
+      )
       exported_program = exported_program.run_decompositions(decomp_table)

     self.exported_program = exported_program

ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241018.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241014
+Version: 0.3.0.dev20241018
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20241014.dist-info → ai_edge_torch_nightly-0.3.0.dev20241018.dist-info}/RECORD CHANGED
@@ -1,19 +1,20 @@
 ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
 ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
-ai_edge_torch/fx_pass_base.py,sha256=
+ai_edge_torch/fx_pass_base.py,sha256=SrYveglaiA_DXPoRBqSXClWM1q7853I5ujRorq_MV0M,4251
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=K5OPFteQXEGRnfR-yFeyGvU5qSbFmO67wiOGgKTvmi0,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
 ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
-ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
+ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=XCVqWg_ask0Kb64PED0ZGAODsUuIgfyO2ZJM6aK-TXI,4283
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
+ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
@@ -25,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=
+ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -100,7 +101,7 @@ ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjA
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
-ai_edge_torch/generative/layers/attention_utils.py,sha256=
+ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
@@ -137,7 +138,7 @@ ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
-ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=
+ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
 ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
@@ -180,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/METADATA,sha256=9jDElmb5TsG683OMDoH2ywytzqOAGNM5ImTPKEKyeFc,1897
+ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/RECORD,,