ai-edge-torch-nightly 0.3.0.dev20241010__py3-none-any.whl → 0.3.0.dev20241011__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/conversion.py +25 -5
- ai_edge_torch/_convert/converter.py +15 -1
- ai_edge_torch/fx_pass_base.py +9 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241010.dist-info → ai_edge_torch_nightly-0.3.0.dev20241011.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
import logging
|
17
17
|
import os
|
18
|
-
from typing import Any, Optional
|
18
|
+
from typing import Any, Literal, Optional, Union
|
19
19
|
|
20
20
|
from ai_edge_torch import fx_pass_base
|
21
21
|
from ai_edge_torch import lowertools
|
@@ -73,6 +73,7 @@ def _warn_training_modules(signatures: list[signature.Signature]):
|
|
73
73
|
def convert_signatures(
|
74
74
|
signatures: list[signature.Signature],
|
75
75
|
*,
|
76
|
+
strict_export: Union[Literal["auto"], bool] = True,
|
76
77
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
77
78
|
_tfl_converter_flags: Optional[dict[str, Any]],
|
78
79
|
) -> model.TfLiteModel:
|
@@ -81,6 +82,11 @@ def convert_signatures(
|
|
81
82
|
Args:
|
82
83
|
signatures: The list of 'signature.Signature' objects containing PyTorch
|
83
84
|
modules to be converted.
|
85
|
+
strict_export: Experimental `strict` arg for torch.export.export. When
|
86
|
+
enabled, the export function will trace the program through TorchDynamo
|
87
|
+
and ensure the soundness of the exported graph. When
|
88
|
+
strict_export="auto", the function will try to export module in both
|
89
|
+
modes and use the first one succeeds for downstream conversion.
|
84
90
|
quant_config: User-defined quantization method and scheme of the model.
|
85
91
|
_tfl_converter_flags: A nested dictionary allowing setting flags for the
|
86
92
|
underlying tflite converter.
|
@@ -93,10 +99,24 @@ def convert_signatures(
|
|
93
99
|
|
94
100
|
_warn_training_modules(signatures)
|
95
101
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
102
|
+
def export(*args, **kwargs):
|
103
|
+
nonlocal strict_export
|
104
|
+
if strict_export == "auto":
|
105
|
+
try:
|
106
|
+
return torch.export.export(*args, **kwargs, strict=True)
|
107
|
+
except Exception:
|
108
|
+
logging.warning(
|
109
|
+
"torch.export.export(..., strict=True) failed. Retrying with"
|
110
|
+
" strict=False"
|
111
|
+
)
|
112
|
+
return torch.export.export(*args, **kwargs, strict=False)
|
113
|
+
elif not strict_export:
|
114
|
+
return torch.export.export(*args, **kwargs, strict=False)
|
115
|
+
else:
|
116
|
+
return torch.export.export(*args, **kwargs, strict=True)
|
117
|
+
|
118
|
+
exported_programs: torch.export.ExportedProgram = [
|
119
|
+
export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
|
100
120
|
for sig in signatures
|
101
121
|
]
|
102
122
|
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
from typing import Any, Optional, Tuple, Union
|
18
|
+
from typing import Any, Literal, Optional, Tuple, Union
|
19
19
|
|
20
20
|
from ai_edge_torch import model
|
21
21
|
from ai_edge_torch._convert import conversion
|
@@ -102,6 +102,7 @@ class Converter:
|
|
102
102
|
sample_args=None,
|
103
103
|
sample_kwargs=None,
|
104
104
|
*,
|
105
|
+
strict_export: Union[Literal["auto"], bool] = True,
|
105
106
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
106
107
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
107
108
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
@@ -123,6 +124,11 @@ class Converter:
|
|
123
124
|
with prior to conversion.
|
124
125
|
sample_kwargs: Dict of str to tensor by which the torch module will be
|
125
126
|
traced with prior to conversion.
|
127
|
+
strict_export: Experimental `strict` arg for torch.export.export. When
|
128
|
+
enabled, the export function will trace the program through TorchDynamo
|
129
|
+
and ensure the soundness of the exported graph. When
|
130
|
+
strict_export="auto", the function will try to export module in both
|
131
|
+
modes and use the first one succeeds for downstream conversion.
|
126
132
|
quant_config: User-defined quantization method and scheme of the model.
|
127
133
|
dynamic_shapes: Optional dict or tuple that specify dynamic shape
|
128
134
|
specifications for each input in original order. See
|
@@ -162,6 +168,7 @@ class Converter:
|
|
162
168
|
)
|
163
169
|
return conversion.convert_signatures(
|
164
170
|
self._signatures,
|
171
|
+
strict_export=strict_export,
|
165
172
|
quant_config=quant_config,
|
166
173
|
_tfl_converter_flags=_ai_edge_converter_flags,
|
167
174
|
)
|
@@ -205,6 +212,7 @@ def convert(
|
|
205
212
|
sample_args=None,
|
206
213
|
sample_kwargs=None,
|
207
214
|
*,
|
215
|
+
strict_export: Union[Literal["auto"], bool] = True,
|
208
216
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
209
217
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
210
218
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
@@ -217,6 +225,11 @@ def convert(
|
|
217
225
|
prior to conversion.
|
218
226
|
sample_kwargs: Dict of str to tensor by which the torch module will be
|
219
227
|
traced with prior to conversion.
|
228
|
+
strict_export: Experimental `strict` arg for torch.export.export. When
|
229
|
+
enabled, the export function will trace the program through TorchDynamo
|
230
|
+
and ensure the soundness of the exported graph. When strict_export="auto",
|
231
|
+
the function will try to export module in both modes and use the first one
|
232
|
+
succeeds for downstream conversion.
|
220
233
|
quant_config: User-defined quantization method and scheme of the model.
|
221
234
|
dynamic_shapes: Optional dict or tuple that specify dynamic shape
|
222
235
|
specifications for each input in original order. See
|
@@ -242,6 +255,7 @@ def convert(
|
|
242
255
|
module,
|
243
256
|
sample_args,
|
244
257
|
sample_kwargs,
|
258
|
+
strict_export=strict_export,
|
245
259
|
quant_config=quant_config,
|
246
260
|
dynamic_shapes=dynamic_shapes,
|
247
261
|
_ai_edge_converter_flags=_ai_edge_converter_flags,
|
ai_edge_torch/fx_pass_base.py
CHANGED
@@ -95,6 +95,15 @@ class CanonicalizePass(ExportedProgramPassBase):
|
|
95
95
|
}
|
96
96
|
|
97
97
|
def call(self, exported_program: torch.export.ExportedProgram):
|
98
|
+
for node in exported_program.graph.nodes:
|
99
|
+
if node.target == torch.ops.aten.view.default:
|
100
|
+
# Passes or torch.export may generate aten.view nodes not respecting the
|
101
|
+
# tensor memory format. Changes all the aten.view to torch.reshape
|
102
|
+
# for retracing. If the input memory format is already contiguous,
|
103
|
+
# retracing in run_decomposition below would decompose torch.reshape
|
104
|
+
# back to one aten.view.
|
105
|
+
node.target = lambda self, size: torch.reshape(self, size)
|
106
|
+
|
98
107
|
exported_program = exported_program.run_decompositions(
|
99
108
|
self._DUMMY_DECOMP_TABLE
|
100
109
|
)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241011
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -1,13 +1,13 @@
|
|
1
1
|
ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
|
2
2
|
ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
|
-
ai_edge_torch/fx_pass_base.py,sha256=
|
4
|
+
ai_edge_torch/fx_pass_base.py,sha256=ToyPo7MBRUNSZivwfLgAS-CNS5SnlhmZ7i_Pdzv9IJw,4238
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=EdlZ-aVAW07q00kkGVtIhUqLLXc1HnI2lw1V-ymvO_0,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
8
|
+
ai_edge_torch/_convert/conversion.py,sha256=p1u1HJYzCxd7UyZ9_mjGOqJOZO4XeLQzmqdgO_6vL0Y,4755
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
|
-
ai_edge_torch/_convert/converter.py,sha256=
|
10
|
+
ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
|
11
11
|
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
12
12
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=xuRI0WehbUlxLHvuYjj8MeyIKBtcCp10D3E1uD1MRdw,1168
|
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
180
180
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
181
181
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
182
182
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
183
|
-
ai_edge_torch_nightly-0.3.0.
|
184
|
-
ai_edge_torch_nightly-0.3.0.
|
185
|
-
ai_edge_torch_nightly-0.3.0.
|
186
|
-
ai_edge_torch_nightly-0.3.0.
|
187
|
-
ai_edge_torch_nightly-0.3.0.
|
183
|
+
ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
184
|
+
ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/METADATA,sha256=JlypGVS0IGIdrAWambT-SDYN2jKv63-gR0B9xL9VZLw,1897
|
185
|
+
ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
186
|
+
ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
187
|
+
ai_edge_torch_nightly-0.3.0.dev20241011.dist-info/RECORD,,
|
File without changes
|
File without changes
|