ai-edge-torch-nightly 0.3.0.dev20241014__py3-none-any.whl → 0.3.0.dev20241018__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,7 +35,7 @@ def _run_convert_passes(
35
35
  exported_program = generative_fx_passes.run_generative_passes(
36
36
  exported_program
37
37
  )
38
- return fx_pass_base.run_passes(
38
+ exported_program = fx_pass_base.run_passes(
39
39
  exported_program,
40
40
  [
41
41
  fx_passes.BuildInterpolateCompositePass(),
@@ -44,10 +44,13 @@ def _run_convert_passes(
44
44
  fx_passes.CanonicalizePass(),
45
45
  fx_passes.BuildAtenCompositePass(),
46
46
  fx_passes.CanonicalizePass(),
47
+ fx_passes.RemoveNonUserOutputsPass(),
48
+ fx_passes.CanonicalizePass(),
47
49
  fx_passes.InjectMlirDebuginfoPass(),
48
50
  fx_passes.CanonicalizePass(),
49
51
  ],
50
52
  )
53
+ return exported_program
51
54
 
52
55
 
53
56
  def _warn_training_modules(signatures: list[signature.Signature]):
@@ -103,17 +106,27 @@ def convert_signatures(
103
106
  nonlocal strict_export
104
107
  if strict_export == "auto":
105
108
  try:
106
- return torch.export.export(*args, **kwargs, strict=True)
109
+ exported_program = torch.export.export(*args, **kwargs, strict=True)
107
110
  except Exception:
108
111
  logging.warning(
109
112
  "torch.export.export(..., strict=True) failed. Retrying with"
110
113
  " strict=False"
111
114
  )
112
- return torch.export.export(*args, **kwargs, strict=False)
115
+ exported_program = torch.export.export(*args, **kwargs, strict=False)
113
116
  elif not strict_export:
114
- return torch.export.export(*args, **kwargs, strict=False)
117
+ exported_program = torch.export.export(*args, **kwargs, strict=False)
115
118
  else:
116
- return torch.export.export(*args, **kwargs, strict=True)
119
+ exported_program = torch.export.export(*args, **kwargs, strict=True)
120
+
121
+ if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
122
+ # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
123
+ # stop-gap table which replicates the old behaviour of post-dispatch IR.
124
+ # This could help ensure the collection of aten ops remaining still as the
125
+ # implementation of torch.export changes.
126
+ exported_program = exported_program.run_decompositions(
127
+ torch._decomp._decomp_table_to_post_autograd_aten()
128
+ )
129
+ return exported_program
117
130
 
118
131
  exported_programs: torch.export.ExportedProgram = [
119
132
  export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
@@ -19,4 +19,5 @@ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAten
19
19
  from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
20
20
  from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
21
21
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
22
+ from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
22
23
  from ai_edge_torch.fx_pass_base import CanonicalizePass
@@ -113,6 +113,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
113
113
  ]
114
114
 
115
115
  def call(self, exported_program: torch.export.ExportedProgram):
116
+ exported_program = fx_pass_base.run_passes(
117
+ exported_program, [fx_pass_base.CanonicalizePass()]
118
+ )
116
119
  exported_program = exported_program.run_decompositions(
117
120
  _INTERPOLATE_DECOMPOSITIONS
118
121
  )
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Pass to remove all non user outputs from exported program."""
16
+
17
+
18
+ from ai_edge_torch import fx_pass_base
19
+ import torch
20
+
21
+
22
+ class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
23
+ """This pass removes all non user outputs from the exported program's output.
24
+
25
+ The FX graph may output more tensors/data than what user's original model
26
+ returns. Those additional outputs include user input mutations, gradient to
27
+ parameter, etc. Those outputs are not supported by our inference only
28
+ conversion or runtime. This pass remove all those outputs to ensure the
29
+ converted models' outputs match what returned from user's model in eval mode.
30
+ """
31
+
32
+ def call(self, exported_program: torch.export.ExportedProgram):
33
+ for node in exported_program.graph.nodes:
34
+ if node.op != "output":
35
+ continue
36
+
37
+ outputs = node.args[0]
38
+ output_specs = exported_program.graph_signature.output_specs
39
+
40
+ new_outputs = []
41
+ new_output_specs = []
42
+ for output, spec in zip(outputs, output_specs):
43
+ if spec.kind == torch.export.graph_signature.OutputKind.USER_OUTPUT:
44
+ new_outputs.append(output)
45
+ new_output_specs.append(spec)
46
+
47
+ node.args = (tuple(new_outputs),)
48
+ exported_program.graph_signature.output_specs = new_output_specs
49
+
50
+ exported_program.graph_module.graph.lint()
51
+ exported_program.graph_module.recompile()
52
+ return fx_pass_base.ExportedProgramPassResult(exported_program, True)
@@ -490,6 +490,22 @@ class TestConvert(googletest.TestCase):
490
490
  tflite_output = edge_model(**flat_inputs)
491
491
  np.testing.assert_almost_equal(reference_output, tflite_output)
492
492
 
493
+ def test_convert_model_with_input_mutation(self):
494
+ class SampleModel(nn.Module):
495
+
496
+ def forward(self, x):
497
+ x /= 1
498
+ x = x + 10
499
+ return x
500
+
501
+ args = (torch.randn(10, 10),)
502
+ torch_module = SampleModel().eval()
503
+ edge_model = ai_edge_torch.convert(torch_module, args)
504
+
505
+ self.assertTrue(
506
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
507
+ )
508
+
493
509
 
494
510
  if __name__ == "__main__":
495
511
  googletest.main()
@@ -102,7 +102,7 @@ class CanonicalizePass(ExportedProgramPassBase):
102
102
  # for retracing. If the input memory format is already contiguous,
103
103
  # retracing in run_decomposition below would decompose torch.reshape
104
104
  # back to one aten.view.
105
- node.target = lambda self, size: torch.reshape(self, size)
105
+ node.target = lambda self, size: torch.reshape(self.contiguous(), size)
106
106
 
107
107
  exported_program = exported_program.run_decompositions(
108
108
  self._DUMMY_DECOMP_TABLE
@@ -107,7 +107,7 @@ def build_sliding_window_mask_cache(
107
107
  sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
108
108
  all_ones, window_size - 1
109
109
  )
110
- return torch.where(sliding_mask == 1, mask, -2.3819763e38)
110
+ return torch.where(sliding_mask == 1, mask, float('-inf'))
111
111
 
112
112
 
113
113
  def relative_position_bucket(
@@ -17,6 +17,7 @@
17
17
  import dataclasses
18
18
  from typing import Any, Callable, Optional, Union
19
19
 
20
+ from ai_edge_torch import fx_pass_base
20
21
  from ai_edge_torch.hlfb.mark_pattern import passes
21
22
  import torch
22
23
  from torch.export.graph_signature import TensorArgument
@@ -116,6 +117,7 @@ def _find_scalar_attr(
116
117
  track_args[tracker.pattern_arg_pos] = source
117
118
  ep = torch.export.export(pattern_module, tuple(track_args))
118
119
  if decomp_table is not None:
120
+ ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
119
121
  ep = ep.run_decompositions(decomp_table)
120
122
 
121
123
  scalar_locs = set()
@@ -198,6 +200,9 @@ class Pattern:
198
200
 
199
201
  exported_program = torch.export.export(module, export_args)
200
202
  if decomp_table is not None:
203
+ exported_program = fx_pass_base.run_passes(
204
+ exported_program, [fx_pass_base.CanonicalizePass()]
205
+ )
201
206
  exported_program = exported_program.run_decompositions(decomp_table)
202
207
 
203
208
  self.exported_program = exported_program
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241014"
16
+ __version__ = "0.3.0.dev20241018"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241014
3
+ Version: 0.3.0.dev20241018
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,19 +1,20 @@
1
1
  ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
2
2
  ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/fx_pass_base.py,sha256=ToyPo7MBRUNSZivwfLgAS-CNS5SnlhmZ7i_Pdzv9IJw,4238
4
+ ai_edge_torch/fx_pass_base.py,sha256=SrYveglaiA_DXPoRBqSXClWM1q7853I5ujRorq_MV0M,4251
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=lCXRgmm48WXivUhIJHdj4a9DrbLbyywsh6G4yWE1Bro,706
6
+ ai_edge_torch/version.py,sha256=K5OPFteQXEGRnfR-yFeyGvU5qSbFmO67wiOGgKTvmi0,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
- ai_edge_torch/_convert/conversion.py,sha256=p1u1HJYzCxd7UyZ9_mjGOqJOZO4XeLQzmqdgO_6vL0Y,4755
8
+ ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
10
10
  ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
11
11
  ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
12
12
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
13
- ai_edge_torch/_convert/fx_passes/__init__.py,sha256=xuRI0WehbUlxLHvuYjj8MeyIKBtcCp10D3E1uD1MRdw,1168
13
+ ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
14
14
  ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
15
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=vd2phg5j3Exn6BuGpASe5cU_wY4JV_YcNTssM6Q9k2c,4169
15
+ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=XCVqWg_ask0Kb64PED0ZGAODsUuIgfyO2ZJM6aK-TXI,4283
16
16
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
17
+ ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
17
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
18
19
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
19
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
@@ -25,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
25
26
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
26
27
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
27
28
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
28
- ai_edge_torch/_convert/test/test_convert.py,sha256=40QRxQFNeSRr4dLXJkzG-wKUlvJtsfv62cdvRrmBv5w,15097
29
+ ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
29
30
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
30
31
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
31
32
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -100,7 +101,7 @@ ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjA
100
101
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
101
102
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
102
103
  ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
103
- ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
104
+ ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
104
105
  ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
105
106
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
106
107
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
@@ -137,7 +138,7 @@ ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx
137
138
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
138
139
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
139
140
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
140
- ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=uiYRfzD1T8deCEAGfdAFusRbI41m14zeTt0Lz5lNT3M,9808
141
+ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
141
142
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
142
143
  ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
143
144
  ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
@@ -180,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
180
181
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
181
182
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
182
183
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
183
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/METADATA,sha256=LQb4o0OMij14VvLSvUCYScCnWxxlTAp_MjNtaYI-nd4,1897
185
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/RECORD,,
184
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/METADATA,sha256=9jDElmb5TsG683OMDoH2ywytzqOAGNM5ImTPKEKyeFc,1897
186
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/RECORD,,