ai-edge-torch-nightly 0.3.0.dev20241216__py3-none-any.whl → 0.3.0.dev20241217__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +5 -1
- ai_edge_torch/_convert/converter.py +8 -0
- ai_edge_torch/lowertools/_shim.py +4 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +4 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +6 -0
- ai_edge_torch/odml_torch/export.py +4 -0
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +5 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241217.dist-info}/top_level.txt +0 -0
@@ -78,7 +78,8 @@ def convert_signatures(
|
|
78
78
|
*,
|
79
79
|
strict_export: Union[Literal["auto"], bool] = True,
|
80
80
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
81
|
-
_tfl_converter_flags: Optional[dict[str, Any]],
|
81
|
+
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
82
|
+
_saved_model_dir: Optional[str] = None,
|
82
83
|
) -> model.TfLiteModel:
|
83
84
|
"""Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
|
84
85
|
|
@@ -93,6 +94,8 @@ def convert_signatures(
|
|
93
94
|
quant_config: User-defined quantization method and scheme of the model.
|
94
95
|
_tfl_converter_flags: A nested dictionary allowing setting flags for the
|
95
96
|
underlying tflite converter.
|
97
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
98
|
+
specified, a random temporary directory would be used.
|
96
99
|
|
97
100
|
Returns:
|
98
101
|
The converted `model.TfLiteModel` object.
|
@@ -140,6 +143,7 @@ def convert_signatures(
|
|
140
143
|
signatures,
|
141
144
|
quant_config=quant_config,
|
142
145
|
_tfl_converter_flags=_tfl_converter_flags,
|
146
|
+
_saved_model_dir=_saved_model_dir,
|
143
147
|
)
|
144
148
|
|
145
149
|
return model.TfLiteModel(tflite_model)
|
@@ -106,6 +106,7 @@ class Converter:
|
|
106
106
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
107
107
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
108
108
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
109
|
+
_saved_model_dir: Optional[str] = None,
|
109
110
|
) -> model.TfLiteModel:
|
110
111
|
"""Finalizes the conversion and produces an edge model.
|
111
112
|
|
@@ -139,6 +140,8 @@ class Converter:
|
|
139
140
|
of this function and so needs to be treated as such. Please do not rely
|
140
141
|
on this parameter except for local debugging as this can be removed in a
|
141
142
|
future release.
|
143
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
144
|
+
specified, a random temporary directory would be used.
|
142
145
|
|
143
146
|
Returns:
|
144
147
|
The converted edge model.
|
@@ -171,6 +174,7 @@ class Converter:
|
|
171
174
|
strict_export=strict_export,
|
172
175
|
quant_config=quant_config,
|
173
176
|
_tfl_converter_flags=_ai_edge_converter_flags,
|
177
|
+
_saved_model_dir=_saved_model_dir,
|
174
178
|
)
|
175
179
|
|
176
180
|
|
@@ -216,6 +220,7 @@ def convert(
|
|
216
220
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
217
221
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
218
222
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
223
|
+
_saved_model_dir: Optional[str] = None,
|
219
224
|
) -> model.TfLiteModel:
|
220
225
|
"""Converts a PyTorch model to an edge model with a default signature.
|
221
226
|
|
@@ -240,6 +245,8 @@ def convert(
|
|
240
245
|
this function and so needs to be treated as such. Please do not rely on
|
241
246
|
this parameter except for local debugging as this can be removed in a
|
242
247
|
future release.
|
248
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
249
|
+
specified, a random temporary directory would be used.
|
243
250
|
|
244
251
|
Returns:
|
245
252
|
The converted edge model.
|
@@ -259,4 +266,5 @@ def convert(
|
|
259
266
|
quant_config=quant_config,
|
260
267
|
dynamic_shapes=dynamic_shapes,
|
261
268
|
_ai_edge_converter_flags=_ai_edge_converter_flags,
|
269
|
+
_saved_model_dir=_saved_model_dir,
|
262
270
|
)
|
@@ -50,6 +50,7 @@ def exported_programs_to_tflite(
|
|
50
50
|
*,
|
51
51
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
52
52
|
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
53
|
+
_saved_model_dir: Optional[str] = None
|
53
54
|
):
|
54
55
|
"""Converts a list of ExportedProgram to a TFLite model.
|
55
56
|
|
@@ -57,6 +58,8 @@ def exported_programs_to_tflite(
|
|
57
58
|
exported_programs: A list of ExportedProgram.
|
58
59
|
signatures: A list of Signature.
|
59
60
|
quant_config: A QuantConfig.
|
61
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
62
|
+
specified, a random temporary directory would be used.
|
60
63
|
_tfl_converter_flags: A dict of flags for TFLiteConverter.
|
61
64
|
|
62
65
|
Returns:
|
@@ -79,4 +82,5 @@ def exported_programs_to_tflite(
|
|
79
82
|
signatures,
|
80
83
|
quant_config=quant_config,
|
81
84
|
_tfl_converter_flags=_tfl_converter_flags,
|
85
|
+
_saved_model_dir=_saved_model_dir,
|
82
86
|
)
|
@@ -138,6 +138,7 @@ def merged_bundle_to_tfl_model(
|
|
138
138
|
*,
|
139
139
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
140
140
|
_tfl_converter_flags: dict = {},
|
141
|
+
_saved_model_dir: Optional[str] = None,
|
141
142
|
):
|
142
143
|
tf_state_dict = merged_bundle.bundles[0].state_dict
|
143
144
|
|
@@ -173,6 +174,9 @@ def merged_bundle_to_tfl_model(
|
|
173
174
|
# We need to temporarily save since TFLite's from_concrete_functions does not
|
174
175
|
# allow providing names for each of the concrete functions.
|
175
176
|
with tempfile.TemporaryDirectory() as temp_dir_path:
|
177
|
+
if _saved_model_dir is not None:
|
178
|
+
temp_dir_path = _saved_model_dir
|
179
|
+
|
176
180
|
tf.saved_model.save(
|
177
181
|
tf_module,
|
178
182
|
temp_dir_path,
|
@@ -192,6 +192,7 @@ def merged_bundle_to_tfl_model(
|
|
192
192
|
*,
|
193
193
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
194
194
|
_tfl_converter_flags: dict = {},
|
195
|
+
_saved_model_dir: Optional[str] = None,
|
195
196
|
) -> None:
|
196
197
|
"""Converts a StableHLOGraphModule to a tflite model.
|
197
198
|
|
@@ -200,6 +201,8 @@ def merged_bundle_to_tfl_model(
|
|
200
201
|
signatures: List of signatures from which names of the signatures is
|
201
202
|
extracted.
|
202
203
|
quant_config: User-defined quantization method and scheme of the model.
|
204
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
205
|
+
specified, a random temporary directory would be used.
|
203
206
|
_tfl_converter_flags: A nested dictionary allowing setting flags for the
|
204
207
|
underlying tflite converter.
|
205
208
|
"""
|
@@ -246,6 +249,9 @@ def merged_bundle_to_tfl_model(
|
|
246
249
|
# We need to temporarily save since TFLite's from_concrete_functions does not
|
247
250
|
# allow providing names for each of the concrete functions.
|
248
251
|
with tempfile.TemporaryDirectory() as temp_dir_path:
|
252
|
+
if _saved_model_dir is not None:
|
253
|
+
temp_dir_path = _saved_model_dir
|
254
|
+
|
249
255
|
tf.saved_model.save(
|
250
256
|
tf_module,
|
251
257
|
temp_dir_path,
|
@@ -304,9 +304,13 @@ def exported_program_to_mlir(
|
|
304
304
|
)
|
305
305
|
|
306
306
|
_convert_i64_to_i32(exported_program)
|
307
|
+
|
307
308
|
exported_program = _torch_future.safe_run_decompositions(
|
308
309
|
exported_program, lowerings.decompositions()
|
309
310
|
)
|
311
|
+
|
312
|
+
# Passes below mutate the exported program to a state not executable by torch.
|
313
|
+
# Do not call run_decompositions after applying the passes.
|
310
314
|
_convert_q_dq_per_channel_args_to_list(exported_program)
|
311
315
|
|
312
316
|
with export_utils.create_ir_context() as context, ir.Location.unknown():
|
@@ -52,10 +52,13 @@ def _uniform_quantized_type(
|
|
52
52
|
assert isinstance(scale, (list, tuple))
|
53
53
|
assert isinstance(zero_point, (list, tuple))
|
54
54
|
|
55
|
+
scale = list(scale)
|
56
|
+
zero_point = list(zero_point)
|
57
|
+
|
55
58
|
if len(scale) == 1:
|
56
|
-
scale
|
59
|
+
scale = scale * channel_axis_size
|
57
60
|
if len(zero_point) == 1:
|
58
|
-
zero_point
|
61
|
+
zero_point = zero_point * channel_axis_size
|
59
62
|
|
60
63
|
assert len(scale) == len(zero_point) == channel_axis_size
|
61
64
|
scale_zp_strs = []
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20241216
|
3
|
+
Version: 0.3.0.dev20241217
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,11 +3,11 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=D9ZffEHjGHmublg0LV01j677usvnse7YN7pG7upZoNw,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
8
|
+
ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
|
-
ai_edge_torch/_convert/converter.py,sha256=
|
10
|
+
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
11
11
|
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
12
12
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
@@ -160,16 +160,16 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNgh
|
|
160
160
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
161
161
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
|
162
162
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
163
|
-
ai_edge_torch/lowertools/_shim.py,sha256=
|
163
|
+
ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
|
164
164
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
165
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
165
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
|
166
166
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
167
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
|
167
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
|
168
168
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
169
169
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
170
170
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
|
171
171
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
172
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
172
|
+
ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
|
173
173
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
174
174
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
|
175
175
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -187,7 +187,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
|
|
187
187
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
188
188
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
189
189
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
190
|
-
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=
|
190
|
+
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
191
191
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
192
192
|
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
193
193
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
@@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
200
200
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
201
201
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
202
202
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
203
|
-
ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
204
|
-
ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/METADATA,sha256=
|
205
|
-
ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
206
|
-
ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
207
|
-
ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/RECORD,,
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/METADATA,sha256=BG9KpWduInLQP8oEHFqxRogKnZecqzm5E536x6RFmcE,1966
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241217.dist-info/RECORD,,
|
File without changes
|
File without changes
|