ai-edge-torch-nightly 0.5.0.dev20250422__py3-none-any.whl → 0.5.0.dev20250423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +21 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/testing/export.py +6 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250422.dist-info → ai_edge_torch_nightly-0.5.0.dev20250423.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250422.dist-info → ai_edge_torch_nightly-0.5.0.dev20250423.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.5.0.dev20250422.dist-info → ai_edge_torch_nightly-0.5.0.dev20250423.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250422.dist-info → ai_edge_torch_nightly-0.5.0.dev20250423.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250422.dist-info → ai_edge_torch_nightly-0.5.0.dev20250423.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py
CHANGED
@@ -37,6 +37,7 @@ def _run_convert_passes(
   passes = [
       fx_passes.CastInputsBf16ToF32Pass(),
       fx_passes.BuildInterpolateCompositePass(),
+      fx_passes.CanonicalizePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
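The hunk above inserts fx_passes.CanonicalizePass() ahead of fx_passes.OptimizeLayoutTransposesPass(), so canonicalization now runs both before and after the layout-transpose pass. As a rough mental model only (this is not ai-edge-torch's actual pass infrastructure; everything below other than the pass names already shown in the diff is hypothetical), the list is applied sequentially, so an earlier pass rewrites the program before later passes see it:

# Conceptual sketch only -- not the package's real pass runner.
from typing import Callable, Sequence

import torch

PassFn = Callable[[torch.export.ExportedProgram], torch.export.ExportedProgram]


def run_passes(
    program: torch.export.ExportedProgram, passes: Sequence[PassFn]
) -> torch.export.ExportedProgram:
  # Order matters: each pass receives the output of the previous one.
  for p in passes:
    program = p(program)
  return program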
ai_edge_torch/odml_torch/lowerings/_basic.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import math
+import operator
 from typing import Optional, Union
 
 from ai_edge_torch.odml_torch import export_utils
@@ -24,6 +25,7 @@ from jax._src.lib.mlir.dialects import hlo as stablehlo
 import numpy as np
 import torch
 
+
 LoweringContext = context.LoweringContext
 lower = registry.lower
 
@@ -320,3 +322,22 @@ def _aten_to_copy(
       ),
       x,
   )
+
+
+# Schema:
+# - aten::sym_size.int(Tensor self, int dim) -> SymInt
+@lower(torch.ops.aten.sym_size.int)
+def _aten_sym_size_int(lctx, x: ir.Value, dim: int):
+  return stablehlo.get_dimension_size(x, dim)
+
+
+# Lowering for the multiplication operator (`*`).
+# Handles cases where one operand is an integer (scalar) and the other is a
+# tensor, broadcasting the scalar to the tensor's shape before multiplication.
+@lower(operator.mul)
+def _operator_mul(lctx, self: int | ir.Value, other: int | ir.Value):
+  if isinstance(self, int) and isinstance(other, ir.Value):
+    self = utils.splat(self, other.type.element_type, other.type.shape)
+  if isinstance(other, int) and isinstance(self, ir.Value):
+    other = utils.splat(other, self.type.element_type, self.type.shape)
+  return stablehlo.multiply(self, other)
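For context, here is a hedged sketch of the kind of graph these two new lowerings target: exporting a model whose dynamic dimension feeds later shape arithmetic typically produces aten.sym_size.int and builtin operator.mul nodes. The module and shapes below are illustrative only, and the exact node set depends on the PyTorch version.

# Illustrative only: a dynamic-shape export whose graph typically contains
# aten.sym_size.int (reading the symbolic batch size) and an operator.mul on
# that SymInt -- the two ops handled by the lowerings above.
import torch


class ScaleByBatch(torch.nn.Module):

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    batch = x.shape[0]  # traced as aten.sym_size.int under a dynamic dim
    return x.reshape(batch * 2, -1)  # batch * 2 traces as operator.mul


batch_dim = torch.export.Dim("batch")
ep = torch.export.export(
    ScaleByBatch().eval(),
    (torch.randn(4, 8),),
    dynamic_shapes={"x": {0: batch_dim}},
)
print(ep)  # inspect the graph for sym_size.int / mul nodes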
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
CHANGED
@@ -218,7 +218,6 @@ lower_by_torch_xla2(torch.ops.aten.stack)
 lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
 lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
 lower_by_torch_xla2(torch.ops.aten.sum)
-lower_by_torch_xla2(torch.ops.aten.sym_size)
 lower_by_torch_xla2(torch.ops.aten.t)
 lower_by_torch_xla2(torch.ops.aten.tan)
 lower_by_torch_xla2(torch.ops.aten.tanh)
ai_edge_torch/testing/export.py
CHANGED
@@ -15,7 +15,7 @@
 """Torch export utilities for testing."""
 
 from collections.abc import Callable
-from typing import Any
+from typing import Any, Dict, Sequence
 
 import torch
 from torch.utils import _pytree as pytree
@@ -25,6 +25,7 @@ def export_with_tensor_inputs_only(
     model: Callable[..., Any],
     args: tuple[Any, ...],
     kwargs: dict[str, Any],
+    dynamic_shapes: Dict[str, Any] | Sequence[Any] | None = None,
 ) -> torch.export.ExportedProgram:
   """Exports a PyTorch model, treating only tensor inputs as export inputs.
 
@@ -76,8 +77,12 @@ def export_with_tensor_inputs_only(
 
   export_args = tuple(export_args)
   export_kwargs = {}
+  # Need to wrap dynamic_shapes in a tuple to match the inputs structure of
+  # ModuleWrapper.
+  dynamic_shapes = (dynamic_shapes,) if dynamic_shapes else None
   return torch.export.export(
       ModuleWrapper(model, args, kwargs).eval(),
       export_args,
       export_kwargs,
+      dynamic_shapes=dynamic_shapes,
   )
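A minimal usage sketch of the extended helper, assuming it is imported as shown; the model and shapes are illustrative, and the exact nesting expected for dynamic_shapes relative to the wrapped tensor inputs is an assumption based only on the signature above, not on the package's tests.

# Hedged usage sketch -- the dynamic_shapes nesting is an assumption.
import torch
from ai_edge_torch.testing import export as testing_export

model = torch.nn.Linear(8, 4)
batch_dim = torch.export.Dim("batch")

ep = testing_export.export_with_tensor_inputs_only(
    model,
    args=(torch.randn(2, 8),),
    kwargs={},
    # One spec per tensor input; assumed to mirror the structure of `args`.
    dynamic_shapes=[{0: batch_dim}],
)
assert isinstance(ep, torch.export.ExportedProgram)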
{ai_edge_torch_nightly-0.5.0.dev20250422.dist-info → ai_edge_torch_nightly-0.5.0.dev20250423.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250422
+Version: 0.5.0.dev20250423
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250422.dist-info → ai_edge_torch_nightly-0.5.0.dev20250423.dist-info}/RECORD
CHANGED
@@ -2,9 +2,9 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=DjzQwP8czvLmUu-dJhnWVQJHOuaOqJJKuH2_TOViMvg,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=dOr3TUfF0UCvkmlUrMqKvgaN4jh3lJ9XFuO-sHaAmIw,5521
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
@@ -225,11 +225,11 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
 ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
 ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -242,11 +242,11 @@ ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV
 ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
 ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
 ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPGhPbo,833
-ai_edge_torch/testing/export.py,sha256=
+ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/METADATA,sha256=PGzcX4WVfFW0wE0TSKLAuRB94iemrNff4L8CL_VUMnQ,2051
+ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/RECORD,,