ai-edge-torch-nightly 0.3.0.dev20241010__py3-none-any.whl → 0.3.0.dev20241011__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +25 -5
- ai_edge_torch/_convert/converter.py +15 -1
- ai_edge_torch/fx_pass_base.py +9 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py
CHANGED
@@ -15,7 +15,7 @@
 
 import logging
 import os
-from typing import Any, Optional
+from typing import Any, Literal, Optional, Union
 
 from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
@@ -73,6 +73,7 @@ def _warn_training_modules(signatures: list[signature.Signature]):
 def convert_signatures(
     signatures: list[signature.Signature],
     *,
+    strict_export: Union[Literal["auto"], bool] = True,
     quant_config: Optional[qcfg.QuantConfig] = None,
     _tfl_converter_flags: Optional[dict[str, Any]],
 ) -> model.TfLiteModel:
@@ -81,6 +82,11 @@ def convert_signatures(
   Args:
     signatures: The list of 'signature.Signature' objects containing PyTorch
      modules to be converted.
+    strict_export: Experimental `strict` arg for torch.export.export. When
+      enabled, the export function will trace the program through TorchDynamo
+      and ensure the soundness of the exported graph. When
+      strict_export="auto", the function will try to export the module in both
+      modes and use the first one that succeeds for downstream conversion.
     quant_config: User-defined quantization method and scheme of the model.
     _tfl_converter_flags: A nested dictionary allowing setting flags for the
       underlying tflite converter.
@@ -93,10 +99,24 @@
 
   _warn_training_modules(signatures)
 
-
-
-
-
+  def export(*args, **kwargs):
+    nonlocal strict_export
+    if strict_export == "auto":
+      try:
+        return torch.export.export(*args, **kwargs, strict=True)
+      except Exception:
+        logging.warning(
+            "torch.export.export(..., strict=True) failed. Retrying with"
+            " strict=False"
+        )
+        return torch.export.export(*args, **kwargs, strict=False)
+    elif not strict_export:
+      return torch.export.export(*args, **kwargs, strict=False)
+    else:
+      return torch.export.export(*args, **kwargs, strict=True)
+
+  exported_programs: torch.export.ExportedProgram = [
+      export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
       for sig in signatures
   ]
 
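Note: the fallback implemented by the new export() helper above can be reproduced in isolation with stock torch.export APIs. The following is a minimal sketch of that behaviour, assuming a recent PyTorch where torch.export.export accepts the strict keyword; ToyModel and export_with_fallback are illustrative names, not part of ai-edge-torch.

import logging

import torch


class ToyModel(torch.nn.Module):

  def forward(self, x):
    return torch.nn.functional.relu(x) + 1


def export_with_fallback(module, sample_args):
  """Tries strict export first, then falls back to non-strict on failure."""
  try:
    return torch.export.export(module, sample_args, strict=True)
  except Exception:
    logging.warning(
        "torch.export.export(..., strict=True) failed. Retrying with"
        " strict=False"
    )
    return torch.export.export(module, sample_args, strict=False)


exported = export_with_fallback(ToyModel().eval(), (torch.randn(1, 8),))
print(exported.graph_module.code)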
ai_edge_torch/_convert/converter.py
CHANGED
@@ -15,7 +15,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Literal, Optional, Tuple, Union
 
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion
@@ -102,6 +102,7 @@ class Converter:
       sample_args=None,
       sample_kwargs=None,
       *,
+      strict_export: Union[Literal["auto"], bool] = True,
       quant_config: Optional[qcfg.QuantConfig] = None,
       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
       _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -123,6 +124,11 @@ class Converter:
         with prior to conversion.
       sample_kwargs: Dict of str to tensor by which the torch module will be
         traced with prior to conversion.
+      strict_export: Experimental `strict` arg for torch.export.export. When
+        enabled, the export function will trace the program through TorchDynamo
+        and ensure the soundness of the exported graph. When
+        strict_export="auto", the function will try to export the module in
+        both modes and use the first one that succeeds for downstream conversion.
       quant_config: User-defined quantization method and scheme of the model.
       dynamic_shapes: Optional dict or tuple that specify dynamic shape
         specifications for each input in original order. See
@@ -162,6 +168,7 @@ class Converter:
     )
     return conversion.convert_signatures(
         self._signatures,
+        strict_export=strict_export,
         quant_config=quant_config,
         _tfl_converter_flags=_ai_edge_converter_flags,
     )
@@ -205,6 +212,7 @@ def convert(
     sample_args=None,
     sample_kwargs=None,
     *,
+    strict_export: Union[Literal["auto"], bool] = True,
     quant_config: Optional[qcfg.QuantConfig] = None,
     dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
     _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -217,6 +225,11 @@ def convert(
       prior to conversion.
     sample_kwargs: Dict of str to tensor by which the torch module will be
       traced with prior to conversion.
+    strict_export: Experimental `strict` arg for torch.export.export. When
+      enabled, the export function will trace the program through TorchDynamo
+      and ensure the soundness of the exported graph. When strict_export="auto",
+      the function will try to export the module in both modes and use the
+      first one that succeeds for downstream conversion.
     quant_config: User-defined quantization method and scheme of the model.
     dynamic_shapes: Optional dict or tuple that specify dynamic shape
       specifications for each input in original order. See
@@ -242,6 +255,7 @@ def convert(
       module,
       sample_args,
       sample_kwargs,
+      strict_export=strict_export,
       quant_config=quant_config,
       dynamic_shapes=dynamic_shapes,
      _ai_edge_converter_flags=_ai_edge_converter_flags,
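Note: the keyword plumbed through above is exposed on the public ai_edge_torch.convert() entry point. A usage sketch follows; the torchvision model, input shape, and output path are arbitrary examples, and only the strict_export keyword itself comes from the signature in this diff.

import ai_edge_torch
import torch
import torchvision

model = torchvision.models.mobilenet_v2().eval()
sample_args = (torch.randn(1, 3, 224, 224),)

# "auto" tries strict (TorchDynamo-backed) export first and falls back to
# non-strict export if strict tracing fails.
edge_model = ai_edge_torch.convert(model, sample_args, strict_export="auto")
edge_model.export("/tmp/mobilenet_v2.tflite")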
ai_edge_torch/fx_pass_base.py
CHANGED
@@ -95,6 +95,15 @@ class CanonicalizePass(ExportedProgramPassBase):
   }
 
   def call(self, exported_program: torch.export.ExportedProgram):
+    for node in exported_program.graph.nodes:
+      if node.target == torch.ops.aten.view.default:
+        # Passes or torch.export may generate aten.view nodes that do not
+        # respect the tensor memory format. Change all aten.view nodes to
+        # torch.reshape for retracing. If the input memory format is already
+        # contiguous, retracing in run_decompositions below decomposes
+        # torch.reshape back to one aten.view.
+        node.target = lambda self, size: torch.reshape(self, size)
+
     exported_program = exported_program.run_decompositions(
         self._DUMMY_DECOMP_TABLE
     )
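Note: the view-to-reshape retargeting added above can be approximated outside the pass with public torch.export APIs. The sketch below uses a made-up Flatten module and an empty decomposition table as stand-ins; it illustrates the technique rather than reproducing the ai-edge-torch implementation.

import torch


class Flatten(torch.nn.Module):

  def forward(self, x):
    return x.view(x.shape[0], -1)


ep = torch.export.export(Flatten().eval(), (torch.randn(2, 3, 4),))

# Retarget aten.view call_function nodes to torch.reshape so the retrace does
# not bake in a contiguous memory-format assumption.
for node in ep.graph.nodes:
  if node.op == "call_function" and node.target == torch.ops.aten.view.default:
    node.target = lambda self, size: torch.reshape(self, size)

# Retrace; with a contiguous input the reshape is decomposed back into a
# single aten.view, so the canonicalization is a no-op in that case.
ep = ep.run_decompositions({})
print(ep.graph_module.code)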
{ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241010
+Version: 0.3.0.dev20241011
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
 ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
-ai_edge_torch/fx_pass_base.py,sha256=
+ai_edge_torch/fx_pass_base.py,sha256=ToyPo7MBRUNSZivwfLgAS-CNS5SnlhmZ7i_Pdzv9IJw,4238
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=EdlZ-aVAW07q00kkGVtIhUqLLXc1HnI2lw1V-ymvO_0,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=p1u1HJYzCxd7UyZ9_mjGOqJOZO4XeLQzmqdgO_6vL0Y,4755
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
-ai_edge_torch/_convert/converter.py,sha256=
+ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
 ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=xuRI0WehbUlxLHvuYjj8MeyIKBtcCp10D3E1uD1MRdw,1168
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/METADATA,sha256=JlypGVS0IGIdrAWambT-SDYN2jKv63-gR0B9xL9VZLw,1897
+ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/RECORD,,