ai-edge-torch-nightly 0.3.0.dev20241010__py3-none-any.whl → 0.3.0.dev20241011__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@
15
15
 
16
16
  import logging
17
17
  import os
18
- from typing import Any, Optional
18
+ from typing import Any, Literal, Optional, Union
19
19
 
20
20
  from ai_edge_torch import fx_pass_base
21
21
  from ai_edge_torch import lowertools
@@ -73,6 +73,7 @@ def _warn_training_modules(signatures: list[signature.Signature]):
73
73
  def convert_signatures(
74
74
  signatures: list[signature.Signature],
75
75
  *,
76
+ strict_export: Union[Literal["auto"], bool] = True,
76
77
  quant_config: Optional[qcfg.QuantConfig] = None,
77
78
  _tfl_converter_flags: Optional[dict[str, Any]],
78
79
  ) -> model.TfLiteModel:
@@ -81,6 +82,11 @@ def convert_signatures(
81
82
  Args:
82
83
  signatures: The list of 'signature.Signature' objects containing PyTorch
83
84
  modules to be converted.
85
+ strict_export: Experimental `strict` arg for torch.export.export. When
86
+ enabled, the export function will trace the program through TorchDynamo
87
+ and ensure the soundness of the exported graph. When
88
+ strict_export="auto", the function will try to export module in both
89
+ modes and use the first one succeeds for downstream conversion.
84
90
  quant_config: User-defined quantization method and scheme of the model.
85
91
  _tfl_converter_flags: A nested dictionary allowing setting flags for the
86
92
  underlying tflite converter.
@@ -93,10 +99,24 @@ def convert_signatures(
93
99
 
94
100
  _warn_training_modules(signatures)
95
101
 
96
- exported_programs: torch.export.torch.export.ExportedProgram = [
97
- torch.export.export(
98
- sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
99
- )
102
+ def export(*args, **kwargs):
103
+ nonlocal strict_export
104
+ if strict_export == "auto":
105
+ try:
106
+ return torch.export.export(*args, **kwargs, strict=True)
107
+ except Exception:
108
+ logging.warning(
109
+ "torch.export.export(..., strict=True) failed. Retrying with"
110
+ " strict=False"
111
+ )
112
+ return torch.export.export(*args, **kwargs, strict=False)
113
+ elif not strict_export:
114
+ return torch.export.export(*args, **kwargs, strict=False)
115
+ else:
116
+ return torch.export.export(*args, **kwargs, strict=True)
117
+
118
+ exported_programs: torch.export.ExportedProgram = [
119
+ export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
100
120
  for sig in signatures
101
121
  ]
102
122
 
@@ -15,7 +15,7 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- from typing import Any, Optional, Tuple, Union
18
+ from typing import Any, Literal, Optional, Tuple, Union
19
19
 
20
20
  from ai_edge_torch import model
21
21
  from ai_edge_torch._convert import conversion
@@ -102,6 +102,7 @@ class Converter:
102
102
  sample_args=None,
103
103
  sample_kwargs=None,
104
104
  *,
105
+ strict_export: Union[Literal["auto"], bool] = True,
105
106
  quant_config: Optional[qcfg.QuantConfig] = None,
106
107
  dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
107
108
  _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -123,6 +124,11 @@ class Converter:
123
124
  with prior to conversion.
124
125
  sample_kwargs: Dict of str to tensor by which the torch module will be
125
126
  traced with prior to conversion.
127
+ strict_export: Experimental `strict` arg for torch.export.export. When
128
+ enabled, the export function will trace the program through TorchDynamo
129
+ and ensure the soundness of the exported graph. When
130
+ strict_export="auto", the function will try to export module in both
131
+ modes and use the first one succeeds for downstream conversion.
126
132
  quant_config: User-defined quantization method and scheme of the model.
127
133
  dynamic_shapes: Optional dict or tuple that specify dynamic shape
128
134
  specifications for each input in original order. See
@@ -162,6 +168,7 @@ class Converter:
162
168
  )
163
169
  return conversion.convert_signatures(
164
170
  self._signatures,
171
+ strict_export=strict_export,
165
172
  quant_config=quant_config,
166
173
  _tfl_converter_flags=_ai_edge_converter_flags,
167
174
  )
@@ -205,6 +212,7 @@ def convert(
205
212
  sample_args=None,
206
213
  sample_kwargs=None,
207
214
  *,
215
+ strict_export: Union[Literal["auto"], bool] = True,
208
216
  quant_config: Optional[qcfg.QuantConfig] = None,
209
217
  dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
210
218
  _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -217,6 +225,11 @@ def convert(
217
225
  prior to conversion.
218
226
  sample_kwargs: Dict of str to tensor by which the torch module will be
219
227
  traced with prior to conversion.
228
+ strict_export: Experimental `strict` arg for torch.export.export. When
229
+ enabled, the export function will trace the program through TorchDynamo
230
+ and ensure the soundness of the exported graph. When strict_export="auto",
231
+ the function will try to export module in both modes and use the first one
232
+ succeeds for downstream conversion.
220
233
  quant_config: User-defined quantization method and scheme of the model.
221
234
  dynamic_shapes: Optional dict or tuple that specify dynamic shape
222
235
  specifications for each input in original order. See
@@ -242,6 +255,7 @@ def convert(
242
255
  module,
243
256
  sample_args,
244
257
  sample_kwargs,
258
+ strict_export=strict_export,
245
259
  quant_config=quant_config,
246
260
  dynamic_shapes=dynamic_shapes,
247
261
  _ai_edge_converter_flags=_ai_edge_converter_flags,
@@ -95,6 +95,15 @@ class CanonicalizePass(ExportedProgramPassBase):
95
95
  }
96
96
 
97
97
  def call(self, exported_program: torch.export.ExportedProgram):
98
+ for node in exported_program.graph.nodes:
99
+ if node.target == torch.ops.aten.view.default:
100
+ # Passes or torch.export may generate aten.view nodes not respecting the
101
+ # tensor memory format. Changes all the aten.view to torch.reshape
102
+ # for retracing. If the input memory format is already contiguous,
103
+ # retracing in run_decomposition below would decompose torch.reshape
104
+ # back to one aten.view.
105
+ node.target = lambda self, size: torch.reshape(self, size)
106
+
98
107
  exported_program = exported_program.run_decompositions(
99
108
  self._DUMMY_DECOMP_TABLE
100
109
  )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241010"
16
+ __version__ = "0.3.0.dev20241011"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241010
3
+ Version: 0.3.0.dev20241011
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,13 +1,13 @@
1
1
  ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
2
2
  ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
4
+ ai_edge_torch/fx_pass_base.py,sha256=ToyPo7MBRUNSZivwfLgAS-CNS5SnlhmZ7i_Pdzv9IJw,4238
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=bIO_IFfZmFf0EzBSomi1SB_NI_Q9y1vmOEGOd23r1GE,706
6
+ ai_edge_torch/version.py,sha256=EdlZ-aVAW07q00kkGVtIhUqLLXc1HnI2lw1V-ymvO_0,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
- ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
8
+ ai_edge_torch/_convert/conversion.py,sha256=p1u1HJYzCxd7UyZ9_mjGOqJOZO4XeLQzmqdgO_6vL0Y,4755
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
10
- ai_edge_torch/_convert/converter.py,sha256=ezmaATnQi7NWDo37LUb-hEXtZSmT7_AT6vqXC6Fcq1o,8615
10
+ ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
11
11
  ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
12
12
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
13
13
  ai_edge_torch/_convert/fx_passes/__init__.py,sha256=xuRI0WehbUlxLHvuYjj8MeyIKBtcCp10D3E1uD1MRdw,1168
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
180
180
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
181
181
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
182
182
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
183
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/METADATA,sha256=GHeN1JauoPCMl-nkwNdqrSCv_GXFlErnAjrSVhymN9E,1897
185
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
- ai_edge_torch_nightly-0.3.0.dev20241010.dist-info/RECORD,,
183
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/METADATA,sha256=JlypGVS0IGIdrAWambT-SDYN2jKv63-gR0B9xL9VZLw,1897
185
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
+ ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/RECORD,,