ai-edge-torch-nightly 0.3.0.dev20241216__py3-none-any.whl → 0.3.0.dev20241217__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/conversion.py +5 -1
- ai_edge_torch/_convert/converter.py +8 -0
- ai_edge_torch/lowertools/_shim.py +4 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +4 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +6 -0
- ai_edge_torch/odml_torch/export.py +4 -0
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +5 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/top_level.txt +0 -0
@@ -78,7 +78,8 @@ def convert_signatures(
|
|
78
78
|
*,
|
79
79
|
strict_export: Union[Literal["auto"], bool] = True,
|
80
80
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
81
|
-
_tfl_converter_flags: Optional[dict[str, Any]],
|
81
|
+
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
82
|
+
_saved_model_dir: Optional[str] = None,
|
82
83
|
) -> model.TfLiteModel:
|
83
84
|
"""Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
|
84
85
|
|
@@ -93,6 +94,8 @@ def convert_signatures(
|
|
93
94
|
quant_config: User-defined quantization method and scheme of the model.
|
94
95
|
_tfl_converter_flags: A nested dictionary allowing setting flags for the
|
95
96
|
underlying tflite converter.
|
97
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
98
|
+
specified, a random temporary directory would be used.
|
96
99
|
|
97
100
|
Returns:
|
98
101
|
The converted `model.TfLiteModel` object.
|
@@ -140,6 +143,7 @@ def convert_signatures(
|
|
140
143
|
signatures,
|
141
144
|
quant_config=quant_config,
|
142
145
|
_tfl_converter_flags=_tfl_converter_flags,
|
146
|
+
_saved_model_dir=_saved_model_dir,
|
143
147
|
)
|
144
148
|
|
145
149
|
return model.TfLiteModel(tflite_model)
|
@@ -106,6 +106,7 @@ class Converter:
|
|
106
106
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
107
107
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
108
108
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
109
|
+
_saved_model_dir: Optional[str] = None,
|
109
110
|
) -> model.TfLiteModel:
|
110
111
|
"""Finalizes the conversion and produces an edge model.
|
111
112
|
|
@@ -139,6 +140,8 @@ class Converter:
|
|
139
140
|
of this function and so needs to be treated as such. Please do not rely
|
140
141
|
on this parameter except for local debugging as this can be removed in a
|
141
142
|
future release.
|
143
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
144
|
+
specified, a random temporary directory would be used.
|
142
145
|
|
143
146
|
Returns:
|
144
147
|
The converted edge model.
|
@@ -171,6 +174,7 @@ class Converter:
|
|
171
174
|
strict_export=strict_export,
|
172
175
|
quant_config=quant_config,
|
173
176
|
_tfl_converter_flags=_ai_edge_converter_flags,
|
177
|
+
_saved_model_dir=_saved_model_dir,
|
174
178
|
)
|
175
179
|
|
176
180
|
|
@@ -216,6 +220,7 @@ def convert(
|
|
216
220
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
217
221
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
218
222
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
223
|
+
_saved_model_dir: Optional[str] = None,
|
219
224
|
) -> model.TfLiteModel:
|
220
225
|
"""Converts a PyTorch model to an edge model with a default signature.
|
221
226
|
|
@@ -240,6 +245,8 @@ def convert(
|
|
240
245
|
this function and so needs to be treated as such. Please do not rely on
|
241
246
|
this parameter except for local debugging as this can be removed in a
|
242
247
|
future release.
|
248
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
249
|
+
specified, a random temporary directory would be used.
|
243
250
|
|
244
251
|
Returns:
|
245
252
|
The converted edge model.
|
@@ -259,4 +266,5 @@ def convert(
|
|
259
266
|
quant_config=quant_config,
|
260
267
|
dynamic_shapes=dynamic_shapes,
|
261
268
|
_ai_edge_converter_flags=_ai_edge_converter_flags,
|
269
|
+
_saved_model_dir=_saved_model_dir,
|
262
270
|
)
|
@@ -50,6 +50,7 @@ def exported_programs_to_tflite(
|
|
50
50
|
*,
|
51
51
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
52
52
|
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
53
|
+
_saved_model_dir: Optional[str] = None
|
53
54
|
):
|
54
55
|
"""Converts a list of ExportedProgram to a TFLite model.
|
55
56
|
|
@@ -57,6 +58,8 @@ def exported_programs_to_tflite(
|
|
57
58
|
exported_programs: A list of ExportedProgram.
|
58
59
|
signatures: A list of Signature.
|
59
60
|
quant_config: A QuantConfig.
|
61
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
62
|
+
specified, a random temporary directory would be used.
|
60
63
|
_tfl_converter_flags: A dict of flags for TFLiteConverter.
|
61
64
|
|
62
65
|
Returns:
|
@@ -79,4 +82,5 @@ def exported_programs_to_tflite(
|
|
79
82
|
signatures,
|
80
83
|
quant_config=quant_config,
|
81
84
|
_tfl_converter_flags=_tfl_converter_flags,
|
85
|
+
_saved_model_dir=_saved_model_dir,
|
82
86
|
)
|
@@ -138,6 +138,7 @@ def merged_bundle_to_tfl_model(
|
|
138
138
|
*,
|
139
139
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
140
140
|
_tfl_converter_flags: dict = {},
|
141
|
+
_saved_model_dir: Optional[str] = None,
|
141
142
|
):
|
142
143
|
tf_state_dict = merged_bundle.bundles[0].state_dict
|
143
144
|
|
@@ -173,6 +174,9 @@ def merged_bundle_to_tfl_model(
|
|
173
174
|
# We need to temporarily save since TFLite's from_concrete_functions does not
|
174
175
|
# allow providing names for each of the concrete functions.
|
175
176
|
with tempfile.TemporaryDirectory() as temp_dir_path:
|
177
|
+
if _saved_model_dir is not None:
|
178
|
+
temp_dir_path = _saved_model_dir
|
179
|
+
|
176
180
|
tf.saved_model.save(
|
177
181
|
tf_module,
|
178
182
|
temp_dir_path,
|
@@ -192,6 +192,7 @@ def merged_bundle_to_tfl_model(
|
|
192
192
|
*,
|
193
193
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
194
194
|
_tfl_converter_flags: dict = {},
|
195
|
+
_saved_model_dir: Optional[str] = None,
|
195
196
|
) -> None:
|
196
197
|
"""Converts a StableHLOGraphModule to a tflite model.
|
197
198
|
|
@@ -200,6 +201,8 @@ def merged_bundle_to_tfl_model(
|
|
200
201
|
signatures: List of signatures from which names of the signatures is
|
201
202
|
extracted.
|
202
203
|
quant_config: User-defined quantization method and scheme of the model.
|
204
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
205
|
+
specified, a random temporary directory would be used.
|
203
206
|
_tfl_converter_flags: A nested dictionary allowing setting flags for the
|
204
207
|
underlying tflite converter.
|
205
208
|
"""
|
@@ -246,6 +249,9 @@ def merged_bundle_to_tfl_model(
|
|
246
249
|
# We need to temporarily save since TFLite's from_concrete_functions does not
|
247
250
|
# allow providing names for each of the concrete functions.
|
248
251
|
with tempfile.TemporaryDirectory() as temp_dir_path:
|
252
|
+
if _saved_model_dir is not None:
|
253
|
+
temp_dir_path = _saved_model_dir
|
254
|
+
|
249
255
|
tf.saved_model.save(
|
250
256
|
tf_module,
|
251
257
|
temp_dir_path,
|
@@ -304,9 +304,13 @@ def exported_program_to_mlir(
|
|
304
304
|
)
|
305
305
|
|
306
306
|
_convert_i64_to_i32(exported_program)
|
307
|
+
|
307
308
|
exported_program = _torch_future.safe_run_decompositions(
|
308
309
|
exported_program, lowerings.decompositions()
|
309
310
|
)
|
311
|
+
|
312
|
+
# Passes below mutate the exported program to a state not executable by torch.
|
313
|
+
# Do not call run_decompositions after applying the passes.
|
310
314
|
_convert_q_dq_per_channel_args_to_list(exported_program)
|
311
315
|
|
312
316
|
with export_utils.create_ir_context() as context, ir.Location.unknown():
|
@@ -52,10 +52,13 @@ def _uniform_quantized_type(
|
|
52
52
|
assert isinstance(scale, (list, tuple))
|
53
53
|
assert isinstance(zero_point, (list, tuple))
|
54
54
|
|
55
|
+
scale = list(scale)
|
56
|
+
zero_point = list(zero_point)
|
57
|
+
|
55
58
|
if len(scale) == 1:
|
56
|
-
scale
|
59
|
+
scale = scale * channel_axis_size
|
57
60
|
if len(zero_point) == 1:
|
58
|
-
zero_point
|
61
|
+
zero_point = zero_point * channel_axis_size
|
59
62
|
|
60
63
|
assert len(scale) == len(zero_point) == channel_axis_size
|
61
64
|
scale_zp_strs = []
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241217
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,11 +3,11 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=D9ZffEHjGHmublg0LV01j677usvnse7YN7pG7upZoNw,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
8
|
+
ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
|
-
ai_edge_torch/_convert/converter.py,sha256=
|
10
|
+
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
11
11
|
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
12
12
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
@@ -160,16 +160,16 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNgh
|
|
160
160
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
161
161
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
|
162
162
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
163
|
-
ai_edge_torch/lowertools/_shim.py,sha256=
|
163
|
+
ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
|
164
164
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
165
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
165
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
|
166
166
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
167
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
|
167
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
|
168
168
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
169
169
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
170
170
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
|
171
171
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
172
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
172
|
+
ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
|
173
173
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
174
174
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
|
175
175
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -187,7 +187,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
|
|
187
187
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
188
188
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
189
189
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
190
|
-
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=
|
190
|
+
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
191
191
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
192
192
|
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
193
193
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
@@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
200
200
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
201
201
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
202
202
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
203
|
-
ai_edge_torch_nightly-0.3.0.
|
204
|
-
ai_edge_torch_nightly-0.3.0.
|
205
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/METADATA,sha256=BG9KpWduInLQP8oEHFqxRogKnZecqzm5E536x6RFmcE,1966
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/RECORD,,
|
File without changes
|
File without changes
|