ai-edge-torch-nightly 0.3.0.dev20241215__py3-none-any.whl → 0.3.0.dev20241218__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -78,7 +78,8 @@ def convert_signatures(
78
78
  *,
79
79
  strict_export: Union[Literal["auto"], bool] = True,
80
80
  quant_config: Optional[qcfg.QuantConfig] = None,
81
- _tfl_converter_flags: Optional[dict[str, Any]],
81
+ _tfl_converter_flags: Optional[dict[str, Any]] = None,
82
+ _saved_model_dir: Optional[str] = None,
82
83
  ) -> model.TfLiteModel:
83
84
  """Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
84
85
 
@@ -93,6 +94,8 @@ def convert_signatures(
93
94
  quant_config: User-defined quantization method and scheme of the model.
94
95
  _tfl_converter_flags: A nested dictionary allowing setting flags for the
95
96
  underlying tflite converter.
97
+ _saved_model_dir: Directory for the intermediate saved model. If not
98
+ specified, a random temporary directory would be used.
96
99
 
97
100
  Returns:
98
101
  The converted `model.TfLiteModel` object.
@@ -140,6 +143,7 @@ def convert_signatures(
140
143
  signatures,
141
144
  quant_config=quant_config,
142
145
  _tfl_converter_flags=_tfl_converter_flags,
146
+ _saved_model_dir=_saved_model_dir,
143
147
  )
144
148
 
145
149
  return model.TfLiteModel(tflite_model)
@@ -106,6 +106,7 @@ class Converter:
106
106
  quant_config: Optional[qcfg.QuantConfig] = None,
107
107
  dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
108
108
  _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
109
+ _saved_model_dir: Optional[str] = None,
109
110
  ) -> model.TfLiteModel:
110
111
  """Finalizes the conversion and produces an edge model.
111
112
 
@@ -139,6 +140,8 @@ class Converter:
139
140
  of this function and so needs to be treated as such. Please do not rely
140
141
  on this parameter except for local debugging as this can be removed in a
141
142
  future release.
143
+ _saved_model_dir: Directory for the intermediate saved model. If not
144
+ specified, a random temporary directory would be used.
142
145
 
143
146
  Returns:
144
147
  The converted edge model.
@@ -171,6 +174,7 @@ class Converter:
171
174
  strict_export=strict_export,
172
175
  quant_config=quant_config,
173
176
  _tfl_converter_flags=_ai_edge_converter_flags,
177
+ _saved_model_dir=_saved_model_dir,
174
178
  )
175
179
 
176
180
 
@@ -216,6 +220,7 @@ def convert(
216
220
  quant_config: Optional[qcfg.QuantConfig] = None,
217
221
  dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
218
222
  _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
223
+ _saved_model_dir: Optional[str] = None,
219
224
  ) -> model.TfLiteModel:
220
225
  """Converts a PyTorch model to an edge model with a default signature.
221
226
 
@@ -240,6 +245,8 @@ def convert(
240
245
  this function and so needs to be treated as such. Please do not rely on
241
246
  this parameter except for local debugging as this can be removed in a
242
247
  future release.
248
+ _saved_model_dir: Directory for the intermediate saved model. If not
249
+ specified, a random temporary directory would be used.
243
250
 
244
251
  Returns:
245
252
  The converted edge model.
@@ -259,4 +266,5 @@ def convert(
259
266
  quant_config=quant_config,
260
267
  dynamic_shapes=dynamic_shapes,
261
268
  _ai_edge_converter_flags=_ai_edge_converter_flags,
269
+ _saved_model_dir=_saved_model_dir,
262
270
  )
@@ -50,6 +50,7 @@ def exported_programs_to_tflite(
50
50
  *,
51
51
  quant_config: Optional[qcfg.QuantConfig] = None,
52
52
  _tfl_converter_flags: Optional[dict[str, Any]] = None,
53
+ _saved_model_dir: Optional[str] = None
53
54
  ):
54
55
  """Converts a list of ExportedProgram to a TFLite model.
55
56
 
@@ -57,6 +58,8 @@ def exported_programs_to_tflite(
57
58
  exported_programs: A list of ExportedProgram.
58
59
  signatures: A list of Signature.
59
60
  quant_config: A QuantConfig.
61
+ _saved_model_dir: Directory for the intermediate saved model. If not
62
+ specified, a random temporary directory would be used.
60
63
  _tfl_converter_flags: A dict of flags for TFLiteConverter.
61
64
 
62
65
  Returns:
@@ -79,4 +82,5 @@ def exported_programs_to_tflite(
79
82
  signatures,
80
83
  quant_config=quant_config,
81
84
  _tfl_converter_flags=_tfl_converter_flags,
85
+ _saved_model_dir=_saved_model_dir,
82
86
  )
@@ -138,6 +138,7 @@ def merged_bundle_to_tfl_model(
138
138
  *,
139
139
  quant_config: Optional[qcfg.QuantConfig] = None,
140
140
  _tfl_converter_flags: dict = {},
141
+ _saved_model_dir: Optional[str] = None,
141
142
  ):
142
143
  tf_state_dict = merged_bundle.bundles[0].state_dict
143
144
 
@@ -173,6 +174,9 @@ def merged_bundle_to_tfl_model(
173
174
  # We need to temporarily save since TFLite's from_concrete_functions does not
174
175
  # allow providing names for each of the concrete functions.
175
176
  with tempfile.TemporaryDirectory() as temp_dir_path:
177
+ if _saved_model_dir is not None:
178
+ temp_dir_path = _saved_model_dir
179
+
176
180
  tf.saved_model.save(
177
181
  tf_module,
178
182
  temp_dir_path,
@@ -192,6 +192,7 @@ def merged_bundle_to_tfl_model(
192
192
  *,
193
193
  quant_config: Optional[qcfg.QuantConfig] = None,
194
194
  _tfl_converter_flags: dict = {},
195
+ _saved_model_dir: Optional[str] = None,
195
196
  ) -> None:
196
197
  """Converts a StableHLOGraphModule to a tflite model.
197
198
 
@@ -200,6 +201,8 @@ def merged_bundle_to_tfl_model(
200
201
  signatures: List of signatures from which names of the signatures is
201
202
  extracted.
202
203
  quant_config: User-defined quantization method and scheme of the model.
204
+ _saved_model_dir: Directory for the intermediate saved model. If not
205
+ specified, a random temporary directory would be used.
203
206
  _tfl_converter_flags: A nested dictionary allowing setting flags for the
204
207
  underlying tflite converter.
205
208
  """
@@ -246,6 +249,9 @@ def merged_bundle_to_tfl_model(
246
249
  # We need to temporarily save since TFLite's from_concrete_functions does not
247
250
  # allow providing names for each of the concrete functions.
248
251
  with tempfile.TemporaryDirectory() as temp_dir_path:
252
+ if _saved_model_dir is not None:
253
+ temp_dir_path = _saved_model_dir
254
+
249
255
  tf.saved_model.save(
250
256
  tf_module,
251
257
  temp_dir_path,
@@ -304,9 +304,13 @@ def exported_program_to_mlir(
304
304
  )
305
305
 
306
306
  _convert_i64_to_i32(exported_program)
307
+
307
308
  exported_program = _torch_future.safe_run_decompositions(
308
309
  exported_program, lowerings.decompositions()
309
310
  )
311
+
312
+ # Passes below mutate the exported program to a state not executable by torch.
313
+ # Do not call run_decompositions after applying the passes.
310
314
  _convert_q_dq_per_channel_args_to_list(exported_program)
311
315
 
312
316
  with export_utils.create_ir_context() as context, ir.Location.unknown():
@@ -52,10 +52,13 @@ def _uniform_quantized_type(
52
52
  assert isinstance(scale, (list, tuple))
53
53
  assert isinstance(zero_point, (list, tuple))
54
54
 
55
+ scale = list(scale)
56
+ zero_point = list(zero_point)
57
+
55
58
  if len(scale) == 1:
56
- scale *= channel_axis_size
59
+ scale = scale * channel_axis_size
57
60
  if len(zero_point) == 1:
58
- zero_point *= channel_axis_size
61
+ zero_point = zero_point * channel_axis_size
59
62
 
60
63
  assert len(scale) == len(zero_point) == channel_axis_size
61
64
  scale_zp_strs = []
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241215"
16
+ __version__ = "0.3.0.dev20241218"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241215
3
+ Version: 0.3.0.dev20241218
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,11 +3,11 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=K6IQHV_-ygm-XHO2-Za1f4YtOckCWkp3RoVrufaooRk,706
6
+ ai_edge_torch/version.py,sha256=rIACXAIBWhOFz_eTZpMZrMcmcJ5OlWzYMVkLShDZcdM,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
- ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
8
+ ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
10
- ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
10
+ ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
11
11
  ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
12
12
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
13
13
  ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
@@ -160,16 +160,16 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNgh
160
160
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
161
161
  ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
162
162
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
163
- ai_edge_torch/lowertools/_shim.py,sha256=xJIHDSWNoF4PkkT0JkjeJxgguQ9JGEwooJf9xZNkVRU,3058
163
+ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
164
164
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
165
- ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
165
+ ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
166
166
  ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
167
- ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
167
+ ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
168
168
  ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
169
169
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
170
170
  ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
171
171
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
172
- ai_edge_torch/odml_torch/export.py,sha256=dgnNGBVkHBz0brlWALX2hGXpQ4YzCKdwbkF4oAfEu4I,13062
172
+ ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
173
173
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
174
174
  ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
175
175
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -187,7 +187,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
187
187
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
188
188
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
189
189
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
190
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=GEs83mtEjh8GOW_OATI_ur11VKujrOL2xdZeZ0l1HtM,6100
190
+ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
191
191
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
192
192
  ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
193
193
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
@@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
200
200
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
201
201
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
202
202
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
203
- ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
204
- ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/METADATA,sha256=kkijdPdACWUh6ocM7K99XNhICV9dA4uH3KlQZ-R2NFg,1966
205
- ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
206
- ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
207
- ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/RECORD,,
203
+ ai_edge_torch_nightly-0.3.0.dev20241218.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
204
+ ai_edge_torch_nightly-0.3.0.dev20241218.dist-info/METADATA,sha256=HnuexBdSckj1hhhEw7qlQgcfTdD5GgcCmnjiY-r2RHc,1966
205
+ ai_edge_torch_nightly-0.3.0.dev20241218.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
206
+ ai_edge_torch_nightly-0.3.0.dev20241218.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
207
+ ai_edge_torch_nightly-0.3.0.dev20241218.dist-info/RECORD,,