ai-edge-torch-nightly 0.3.0.dev20241014__py3-none-any.whl → 0.3.0.dev20241018__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -35,7 +35,7 @@ def _run_convert_passes(
35
35
  exported_program = generative_fx_passes.run_generative_passes(
36
36
  exported_program
37
37
  )
38
- return fx_pass_base.run_passes(
38
+ exported_program = fx_pass_base.run_passes(
39
39
  exported_program,
40
40
  [
41
41
  fx_passes.BuildInterpolateCompositePass(),
@@ -44,10 +44,13 @@ def _run_convert_passes(
44
44
  fx_passes.CanonicalizePass(),
45
45
  fx_passes.BuildAtenCompositePass(),
46
46
  fx_passes.CanonicalizePass(),
47
+ fx_passes.RemoveNonUserOutputsPass(),
48
+ fx_passes.CanonicalizePass(),
47
49
  fx_passes.InjectMlirDebuginfoPass(),
48
50
  fx_passes.CanonicalizePass(),
49
51
  ],
50
52
  )
53
+ return exported_program
51
54
 
52
55
 
53
56
  def _warn_training_modules(signatures: list[signature.Signature]):
@@ -103,17 +106,27 @@ def convert_signatures(
103
106
  nonlocal strict_export
104
107
  if strict_export == "auto":
105
108
  try:
106
- return torch.export.export(*args, **kwargs, strict=True)
109
+ exported_program = torch.export.export(*args, **kwargs, strict=True)
107
110
  except Exception:
108
111
  logging.warning(
109
112
  "torch.export.export(..., strict=True) failed. Retrying with"
110
113
  " strict=False"
111
114
  )
112
- return torch.export.export(*args, **kwargs, strict=False)
115
+ exported_program = torch.export.export(*args, **kwargs, strict=False)
113
116
  elif not strict_export:
114
- return torch.export.export(*args, **kwargs, strict=False)
117
+ exported_program = torch.export.export(*args, **kwargs, strict=False)
115
118
  else:
116
- return torch.export.export(*args, **kwargs, strict=True)
119
+ exported_program = torch.export.export(*args, **kwargs, strict=True)
120
+
121
+ if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
122
+ # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
123
+ # stop-gap table which replicates the old behaviour of post-dispatch IR.
124
+ # This could help ensure the collection of aten ops remaining still as the
125
+ # implementation of torch.export changes.
126
+ exported_program = exported_program.run_decompositions(
127
+ torch._decomp._decomp_table_to_post_autograd_aten()
128
+ )
129
+ return exported_program
117
130
 
118
131
  exported_programs: torch.export.ExportedProgram = [
119
132
  export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
@@ -19,4 +19,5 @@ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAten
19
19
  from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
20
20
  from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
21
21
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
22
+ from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
22
23
  from ai_edge_torch.fx_pass_base import CanonicalizePass
@@ -113,6 +113,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
113
113
  ]
114
114
 
115
115
  def call(self, exported_program: torch.export.ExportedProgram):
116
+ exported_program = fx_pass_base.run_passes(
117
+ exported_program, [fx_pass_base.CanonicalizePass()]
118
+ )
116
119
  exported_program = exported_program.run_decompositions(
117
120
  _INTERPOLATE_DECOMPOSITIONS
118
121
  )
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Pass to remove all non user outputs from exported program."""
16
+
17
+
18
+ from ai_edge_torch import fx_pass_base
19
+ import torch
20
+
21
+
22
+ class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
23
+ """This pass removes all non-user outputs from the exported program's output.
24
+
25
+ The FX graph may output more tensors/data than the user's original model
26
+ returns. These additional outputs include user-input mutations, gradients to
27
+ parameters, etc. Such outputs are not supported by our inference-only
28
+ conversion or runtime. This pass removes all of them so that the converted
29
+ model's outputs match what the user's model returns in eval mode.
30
+ """
31
+
32
+ def call(self, exported_program: torch.export.ExportedProgram):
33
+ for node in exported_program.graph.nodes:  # scan for the graph's "output" node
34
+ if node.op != "output":
35
+ continue
36
+
37
+ outputs = node.args[0]  # sequence of every value the graph returns
38
+ output_specs = exported_program.graph_signature.output_specs  # paired with outputs by position below
39
+
40
+ new_outputs = []
41
+ new_output_specs = []
42
+ for output, spec in zip(outputs, output_specs):  # NOTE(review): zip() silently truncates on a length mismatch; consider asserting len(outputs) == len(output_specs)
43
+ if spec.kind == torch.export.graph_signature.OutputKind.USER_OUTPUT:  # keep USER_OUTPUT entries only
44
+ new_outputs.append(output)
45
+ new_output_specs.append(spec)
46
+
47
+ node.args = (tuple(new_outputs),)  # output node now returns only the kept values
48
+ exported_program.graph_signature.output_specs = new_output_specs  # keep the signature in sync with the rewritten output node
49
+
50
+ exported_program.graph_module.graph.lint()  # sanity-check the edited graph
51
+ exported_program.graph_module.recompile()  # regenerate the module's code from the edited graph
52
+ return fx_pass_base.ExportedProgramPassResult(exported_program, True)  # True presumably flags "program modified" — confirm in fx_pass_base
@@ -490,6 +490,22 @@ class TestConvert(googletest.TestCase):
490
490
  tflite_output = edge_model(**flat_inputs)
491
491
  np.testing.assert_almost_equal(reference_output, tflite_output)
492
492
 
493
+ def test_convert_model_with_input_mutation(self):  # checks conversion of a forward() that mutates its input in place
494
+ class SampleModel(nn.Module):
495
+
496
+ def forward(self, x):
497
+ x /= 1  # in-place op: mutates the caller's input tensor (values unchanged)
498
+ x = x + 10  # out-of-place: rebinds x to a fresh tensor
499
+ return x
500
+
501
+ args = (torch.randn(10, 10),)
502
+ torch_module = SampleModel().eval()
503
+ edge_model = ai_edge_torch.convert(torch_module, args)
504
+
505
+ self.assertTrue(
506
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)  # presumably truthy when TFLite and torch outputs agree
507
+ )
508
+
493
509
 
494
510
  if __name__ == "__main__":
495
511
  googletest.main()
@@ -102,7 +102,7 @@ class CanonicalizePass(ExportedProgramPassBase):
102
102
  # for retracing. If the input memory format is already contiguous,
103
103
  # retracing in run_decomposition below would decompose torch.reshape
104
104
  # back to one aten.view.
105
- node.target = lambda self, size: torch.reshape(self, size)
105
+ node.target = lambda self, size: torch.reshape(self.contiguous(), size)
106
106
 
107
107
  exported_program = exported_program.run_decompositions(
108
108
  self._DUMMY_DECOMP_TABLE
@@ -107,7 +107,7 @@ def build_sliding_window_mask_cache(
107
107
  sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
108
108
  all_ones, window_size - 1
109
109
  )
110
- return torch.where(sliding_mask == 1, mask, -2.3819763e38)
110
+ return torch.where(sliding_mask == 1, mask, float('-inf'))
111
111
 
112
112
 
113
113
  def relative_position_bucket(
@@ -17,6 +17,7 @@
17
17
  import dataclasses
18
18
  from typing import Any, Callable, Optional, Union
19
19
 
20
+ from ai_edge_torch import fx_pass_base
20
21
  from ai_edge_torch.hlfb.mark_pattern import passes
21
22
  import torch
22
23
  from torch.export.graph_signature import TensorArgument
@@ -116,6 +117,7 @@ def _find_scalar_attr(
116
117
  track_args[tracker.pattern_arg_pos] = source
117
118
  ep = torch.export.export(pattern_module, tuple(track_args))
118
119
  if decomp_table is not None:
120
+ ep = fx_pass_base.run_passes(ep, [fx_pass_base.CanonicalizePass()])
119
121
  ep = ep.run_decompositions(decomp_table)
120
122
 
121
123
  scalar_locs = set()
@@ -198,6 +200,9 @@ class Pattern:
198
200
 
199
201
  exported_program = torch.export.export(module, export_args)
200
202
  if decomp_table is not None:
203
+ exported_program = fx_pass_base.run_passes(
204
+ exported_program, [fx_pass_base.CanonicalizePass()]
205
+ )
201
206
  exported_program = exported_program.run_decompositions(decomp_table)
202
207
 
203
208
  self.exported_program = exported_program
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241014"
16
+ __version__ = "0.3.0.dev20241018"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241014
3
+ Version: 0.3.0.dev20241018
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,19 +1,20 @@
1
1
  ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
2
2
  ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/fx_pass_base.py,sha256=ToyPo7MBRUNSZivwfLgAS-CNS5SnlhmZ7i_Pdzv9IJw,4238
4
+ ai_edge_torch/fx_pass_base.py,sha256=SrYveglaiA_DXPoRBqSXClWM1q7853I5ujRorq_MV0M,4251
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=lCXRgmm48WXivUhIJHdj4a9DrbLbyywsh6G4yWE1Bro,706
6
+ ai_edge_torch/version.py,sha256=K5OPFteQXEGRnfR-yFeyGvU5qSbFmO67wiOGgKTvmi0,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
- ai_edge_torch/_convert/conversion.py,sha256=p1u1HJYzCxd7UyZ9_mjGOqJOZO4XeLQzmqdgO_6vL0Y,4755
8
+ ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
10
10
  ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
11
11
  ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
12
12
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
13
- ai_edge_torch/_convert/fx_passes/__init__.py,sha256=xuRI0WehbUlxLHvuYjj8MeyIKBtcCp10D3E1uD1MRdw,1168
13
+ ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
14
14
  ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
15
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=vd2phg5j3Exn6BuGpASe5cU_wY4JV_YcNTssM6Q9k2c,4169
15
+ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=XCVqWg_ask0Kb64PED0ZGAODsUuIgfyO2ZJM6aK-TXI,4283
16
16
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
17
+ ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
17
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
18
19
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
19
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
@@ -25,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
25
26
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
26
27
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
27
28
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
28
- ai_edge_torch/_convert/test/test_convert.py,sha256=40QRxQFNeSRr4dLXJkzG-wKUlvJtsfv62cdvRrmBv5w,15097
29
+ ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
29
30
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
30
31
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
31
32
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -100,7 +101,7 @@ ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjA
100
101
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
101
102
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
102
103
  ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
103
- ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
104
+ ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
104
105
  ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
105
106
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
106
107
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
@@ -137,7 +138,7 @@ ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx
137
138
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
138
139
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
139
140
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
140
- ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=uiYRfzD1T8deCEAGfdAFusRbI41m14zeTt0Lz5lNT3M,9808
141
+ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
141
142
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
142
143
  ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
143
144
  ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
@@ -180,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
180
181
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
181
182
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
182
183
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
183
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/METADATA,sha256=LQb4o0OMij14VvLSvUCYScCnWxxlTAp_MjNtaYI-nd4,1897
185
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
- ai_edge_torch_nightly-0.3.0.dev20241014.dist-info/RECORD,,
184
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/METADATA,sha256=9jDElmb5TsG683OMDoH2ywytzqOAGNM5ImTPKEKyeFc,1897
186
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
+ ai_edge_torch_nightly-0.3.0.dev20241018.dist-info/RECORD,,