ai-edge-torch-nightly 0.3.0.dev20241010__py3-none-any.whl → 0.3.0.dev20241011__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -15,7 +15,7 @@
15
15
 
16
16
  import logging
17
17
  import os
18
- from typing import Any, Optional
18
+ from typing import Any, Literal, Optional, Union
19
19
 
20
20
  from ai_edge_torch import fx_pass_base
21
21
  from ai_edge_torch import lowertools
@@ -73,6 +73,7 @@ def _warn_training_modules(signatures: list[signature.Signature]):
73
73
  def convert_signatures(
74
74
  signatures: list[signature.Signature],
75
75
  *,
76
+ strict_export: Union[Literal["auto"], bool] = True,
76
77
  quant_config: Optional[qcfg.QuantConfig] = None,
77
78
  _tfl_converter_flags: Optional[dict[str, Any]],
78
79
  ) -> model.TfLiteModel:
@@ -81,6 +82,11 @@ def convert_signatures(
81
82
  Args:
82
83
  signatures: The list of 'signature.Signature' objects containing PyTorch
83
84
  modules to be converted.
85
+ strict_export: Experimental `strict` arg for torch.export.export. When
86
+ enabled, the export function will trace the program through TorchDynamo
87
+ and ensure the soundness of the exported graph. When
88
+ strict_export="auto", the function will try to export module in both
89
+ modes and use the first one succeeds for downstream conversion.
84
90
  quant_config: User-defined quantization method and scheme of the model.
85
91
  _tfl_converter_flags: A nested dictionary allowing setting flags for the
86
92
  underlying tflite converter.
@@ -93,10 +99,24 @@ def convert_signatures(
93
99
 
94
100
  _warn_training_modules(signatures)
95
101
 
96
- exported_programs: torch.export.torch.export.ExportedProgram = [
97
- torch.export.export(
98
- sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
99
- )
102
+ def export(*args, **kwargs):
103
+ nonlocal strict_export
104
+ if strict_export == "auto":
105
+ try:
106
+ return torch.export.export(*args, **kwargs, strict=True)
107
+ except Exception:
108
+ logging.warning(
109
+ "torch.export.export(..., strict=True) failed. Retrying with"
110
+ " strict=False"
111
+ )
112
+ return torch.export.export(*args, **kwargs, strict=False)
113
+ elif not strict_export:
114
+ return torch.export.export(*args, **kwargs, strict=False)
115
+ else:
116
+ return torch.export.export(*args, **kwargs, strict=True)
117
+
118
+ exported_programs: torch.export.ExportedProgram = [
119
+ export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
100
120
  for sig in signatures
101
121
  ]
102
122
 
@@ -15,7 +15,7 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- from typing import Any, Optional, Tuple, Union
18
+ from typing import Any, Literal, Optional, Tuple, Union
19
19
 
20
20
  from ai_edge_torch import model
21
21
  from ai_edge_torch._convert import conversion
@@ -102,6 +102,7 @@ class Converter:
102
102
  sample_args=None,
103
103
  sample_kwargs=None,
104
104
  *,
105
+ strict_export: Union[Literal["auto"], bool] = True,
105
106
  quant_config: Optional[qcfg.QuantConfig] = None,
106
107
  dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
107
108
  _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -123,6 +124,11 @@ class Converter:
123
124
  with prior to conversion.
124
125
  sample_kwargs: Dict of str to tensor by which the torch module will be
125
126
  traced with prior to conversion.
127
+ strict_export: Experimental `strict` arg for torch.export.export. When
128
+ enabled, the export function will trace the program through TorchDynamo
129
+ and ensure the soundness of the exported graph. When
130
+ strict_export="auto", the function will try to export module in both
131
+ modes and use the first one succeeds for downstream conversion.
126
132
  quant_config: User-defined quantization method and scheme of the model.
127
133
  dynamic_shapes: Optional dict or tuple that specify dynamic shape
128
134
  specifications for each input in original order. See
@@ -162,6 +168,7 @@ class Converter:
162
168
  )
163
169
  return conversion.convert_signatures(
164
170
  self._signatures,
171
+ strict_export=strict_export,
165
172
  quant_config=quant_config,
166
173
  _tfl_converter_flags=_ai_edge_converter_flags,
167
174
  )
@@ -205,6 +212,7 @@ def convert(
205
212
  sample_args=None,
206
213
  sample_kwargs=None,
207
214
  *,
215
+ strict_export: Union[Literal["auto"], bool] = True,
208
216
  quant_config: Optional[qcfg.QuantConfig] = None,
209
217
  dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
210
218
  _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -217,6 +225,11 @@ def convert(
217
225
  prior to conversion.
218
226
  sample_kwargs: Dict of str to tensor by which the torch module will be
219
227
  traced with prior to conversion.
228
+ strict_export: Experimental `strict` arg for torch.export.export. When
229
+ enabled, the export function will trace the program through TorchDynamo
230
+ and ensure the soundness of the exported graph. When strict_export="auto",
231
+ the function will try to export module in both modes and use the first one
232
+ succeeds for downstream conversion.
220
233
  quant_config: User-defined quantization method and scheme of the model.
221
234
  dynamic_shapes: Optional dict or tuple that specify dynamic shape
222
235
  specifications for each input in original order. See
@@ -242,6 +255,7 @@ def convert(
242
255
  module,
243
256
  sample_args,
244
257
  sample_kwargs,
258
+ strict_export=strict_export,
245
259
  quant_config=quant_config,
246
260
  dynamic_shapes=dynamic_shapes,
247
261
  _ai_edge_converter_flags=_ai_edge_converter_flags,
@@ -95,6 +95,15 @@ class CanonicalizePass(ExportedProgramPassBase):
95
95
  }
96
96
 
97
97
  def call(self, exported_program: torch.export.ExportedProgram):
98
+ for node in exported_program.graph.nodes:
99
+ if node.target == torch.ops.aten.view.default:
100
+ # Passes or torch.export may generate aten.view nodes not respecting the
101
+ # tensor memory format. Changes all the aten.view to torch.reshape
102
+ # for retracing. If the input memory format is already contiguous,
103
+ # retracing in run_decomposition below would decompose torch.reshape
104
+ # back to one aten.view.
105
+ node.target = lambda self, size: torch.reshape(self, size)
106
+
98
107
  exported_program = exported_program.run_decompositions(
99
108
  self._DUMMY_DECOMP_TABLE
100
109
  )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241010"
16
+ __version__ = "0.3.0.dev20241011"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241010
3
+ Version: 0.3.0.dev20241011
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,13 +1,13 @@
1
1
  ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
2
2
  ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
4
+ ai_edge_torch/fx_pass_base.py,sha256=ToyPo7MBRUNSZivwfLgAS-CNS5SnlhmZ7i_Pdzv9IJw,4238
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=bIO_IFfZmFf0EzBSomi1SB_NI_Q9y1vmOEGOd23r1GE,706
6
+ ai_edge_torch/version.py,sha256=EdlZ-aVAW07q00kkGVtIhUqLLXc1HnI2lw1V-ymvO_0,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
- ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
8
+ ai_edge_torch/_convert/conversion.py,sha256=p1u1HJYzCxd7UyZ9_mjGOqJOZO4XeLQzmqdgO_6vL0Y,4755
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
10
- ai_edge_torch/_convert/converter.py,sha256=ezmaATnQi7NWDo37LUb-hEXtZSmT7_AT6vqXC6Fcq1o,8615
10
+ ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
11
11
  ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
12
12
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
13
13
  ai_edge_torch/_convert/fx_passes/__init__.py,sha256=xuRI0WehbUlxLHvuYjj8MeyIKBtcCp10D3E1uD1MRdw,1168
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
180
180
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
181
181
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
182
182
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
183
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/METADATA,sha256=GHeN1JauoPCMl-nkwNdqrSCv_GXFlErnAjrSVhymN9E,1897
185
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/RECORD,,
183
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/METADATA,sha256=JlypGVS0IGIdrAWambT-SDYN2jKv63-gR0B9xL9VZLw,1897
185
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/RECORD,,