ai-edge-torch-nightly 0.5.0.dev20250422__py3-none-any.whl → 0.5.0.dev20250423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,6 +37,7 @@ def _run_convert_passes(
37
37
  passes = [
38
38
  fx_passes.CastInputsBf16ToF32Pass(),
39
39
  fx_passes.BuildInterpolateCompositePass(),
40
+ fx_passes.CanonicalizePass(),
40
41
  fx_passes.OptimizeLayoutTransposesPass(),
41
42
  fx_passes.CanonicalizePass(),
42
43
  fx_passes.BuildAtenCompositePass(),
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import math
16
+ import operator
16
17
  from typing import Optional, Union
17
18
 
18
19
  from ai_edge_torch.odml_torch import export_utils
@@ -24,6 +25,7 @@ from jax._src.lib.mlir.dialects import hlo as stablehlo
24
25
  import numpy as np
25
26
  import torch
26
27
 
28
+
27
29
  LoweringContext = context.LoweringContext
28
30
  lower = registry.lower
29
31
 
@@ -320,3 +322,22 @@ def _aten_to_copy(
320
322
  ),
321
323
  x,
322
324
  )
325
+
326
+
327
# Schema:
# - aten::sym_size.int(Tensor self, int dim) -> SymInt
@lower(torch.ops.aten.sym_size.int)
def _aten_sym_size_int(lctx, x: ir.Value, dim: int):
  """Lowers `aten.sym_size.int` to `stablehlo.get_dimension_size`.

  Args:
    lctx: The lowering context (unused).
    x: The tensor value whose dimension size is queried.
    dim: Index of the dimension to query. Negative indices count from the
      last dimension, following PyTorch dim-wrapping semantics.

  Returns:
    The value produced by `stablehlo.get_dimension_size` for that dimension.
  """
  if dim < 0:
    # get_dimension_size only accepts a dimension in [0, rank); wrap negative
    # PyTorch-style indices so they do not produce invalid IR.
    dim += len(x.type.shape)
  return stablehlo.get_dimension_size(x, dim)
332
+
333
+
334
# Lowering for the multiplication operator (`*`).
# Handles cases where one operand is an integer (scalar) and the other is a
# tensor, broadcasting the scalar to the tensor's shape before multiplication.
@lower(operator.mul)
def _operator_mul(lctx, self: int | ir.Value, other: int | ir.Value):
  """Lowers `operator.mul` to `stablehlo.multiply`.

  A Python int operand is splatted to the element type and shape of the other
  (MLIR value) operand, so both inputs to `stablehlo.multiply` are values.

  Args:
    lctx: The lowering context (unused).
    self: Left operand, a Python int or an MLIR value.
    other: Right operand, a Python int or an MLIR value.

  Returns:
    The result of `stablehlo.multiply`. If neither operand is an MLIR value,
    the plain Python product is returned instead.
  """
  if not isinstance(self, ir.Value) and not isinstance(other, ir.Value):
    # There is no value to take an element type / shape from, and
    # stablehlo.multiply cannot consume raw Python scalars; fold in Python.
    return self * other
  if isinstance(self, int) and isinstance(other, ir.Value):
    self = utils.splat(self, other.type.element_type, other.type.shape)
  if isinstance(other, int) and isinstance(self, ir.Value):
    other = utils.splat(other, self.type.element_type, self.type.shape)
  return stablehlo.multiply(self, other)
@@ -218,7 +218,6 @@ lower_by_torch_xla2(torch.ops.aten.stack)
218
218
  lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
219
219
  lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
220
220
  lower_by_torch_xla2(torch.ops.aten.sum)
221
- lower_by_torch_xla2(torch.ops.aten.sym_size)
222
221
  lower_by_torch_xla2(torch.ops.aten.t)
223
222
  lower_by_torch_xla2(torch.ops.aten.tan)
224
223
  lower_by_torch_xla2(torch.ops.aten.tanh)
@@ -15,7 +15,7 @@
15
15
  """Torch export utilities for testing."""
16
16
 
17
17
  from collections.abc import Callable
18
- from typing import Any
18
+ from typing import Any, Dict, Sequence
19
19
 
20
20
  import torch
21
21
  from torch.utils import _pytree as pytree
@@ -25,6 +25,7 @@ def export_with_tensor_inputs_only(
25
25
  model: Callable[..., Any],
26
26
  args: tuple[Any, ...],
27
27
  kwargs: dict[str, Any],
28
+ dynamic_shapes: Dict[str, Any] | Sequence[Any] | None = None,
28
29
  ) -> torch.export.ExportedProgram:
29
30
  """Exports a PyTorch model, treating only tensor inputs as export inputs.
30
31
 
@@ -76,8 +77,12 @@ def export_with_tensor_inputs_only(
76
77
 
77
78
  export_args = tuple(export_args)
78
79
  export_kwargs = {}
80
+ # Need to wrap dynamic_shapes in a tuple to match the inputs structure of
81
+ # ModuleWrapper.
82
+ dynamic_shapes = (dynamic_shapes,) if dynamic_shapes else None
79
83
  return torch.export.export(
80
84
  ModuleWrapper(model, args, kwargs).eval(),
81
85
  export_args,
82
86
  export_kwargs,
87
+ dynamic_shapes=dynamic_shapes,
83
88
  )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250422"
16
+ __version__ = "0.5.0.dev20250423"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250422
3
+ Version: 0.5.0.dev20250423
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,9 +2,9 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=8A53NhGxkOePrHhrJE-p0LgdIpk8vT_45m6C94N8uEw,706
5
+ ai_edge_torch/version.py,sha256=DjzQwP8czvLmUu-dJhnWVQJHOuaOqJJKuH2_TOViMvg,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
- ai_edge_torch/_convert/conversion.py,sha256=0gpwEjlTue5RttDerzM5SVOUnY8g16444yL2YIFBx-E,5485
7
+ ai_edge_torch/_convert/conversion.py,sha256=dOr3TUfF0UCvkmlUrMqKvgaN4jh3lJ9XFuO-sHaAmIw,5521
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
9
9
  ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
10
10
  ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
@@ -225,11 +225,11 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
225
225
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
226
226
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
227
227
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
228
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=4syWstepGiw3IKa8O7lciXywY7RFJ7OCWFMU1Lg3h-s,10777
228
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
229
229
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
230
230
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
231
231
  ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
232
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=JRGLXW8EQ1L-vdiVTkD1kb4AnTU05eRwZ7Ke010hZmg,11473
232
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
233
233
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
234
234
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
235
235
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -242,11 +242,11 @@ ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV
242
242
  ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
243
243
  ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
244
244
  ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPGhPbo,833
245
- ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
245
+ ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
246
246
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
247
247
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
248
- ai_edge_torch_nightly-0.5.0.dev20250422.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
249
- ai_edge_torch_nightly-0.5.0.dev20250422.dist-info/METADATA,sha256=9iAyhGknDER60qxgpUn6ZSR6Nr7LALvvHjgOf0WnYtg,2051
250
- ai_edge_torch_nightly-0.5.0.dev20250422.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
251
- ai_edge_torch_nightly-0.5.0.dev20250422.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
252
- ai_edge_torch_nightly-0.5.0.dev20250422.dist-info/RECORD,,
248
+ ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
249
+ ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/METADATA,sha256=PGzcX4WVfFW0wE0TSKLAuRB94iemrNff4L8CL_VUMnQ,2051
250
+ ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
251
+ ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
252
+ ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/RECORD,,