ai-edge-torch-nightly 0.7.0.dev20250916__py3-none-any.whl → 0.7.0.dev20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. (The original page links to more details; that link is not preserved in this text.)
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +34 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250916.dist-info → ai_edge_torch_nightly-0.7.0.dev20250917.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250916.dist-info → ai_edge_torch_nightly-0.7.0.dev20250917.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.7.0.dev20250916.dist-info → ai_edge_torch_nightly-0.7.0.dev20250917.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20250916.dist-info → ai_edge_torch_nightly-0.7.0.dev20250917.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20250916.dist-info → ai_edge_torch_nightly-0.7.0.dev20250917.dist-info}/top_level.txt +0 -0
|
@@ -18,6 +18,7 @@ import logging
|
|
|
18
18
|
from ai_edge_torch.odml_torch import jax_bridge
|
|
19
19
|
from ai_edge_torch.odml_torch.lowerings import context
|
|
20
20
|
from ai_edge_torch.odml_torch.lowerings import registry
|
|
21
|
+
import jax
|
|
21
22
|
import jax.numpy as jnp
|
|
22
23
|
from jax._src.lib.mlir import ir
|
|
23
24
|
import torch
|
|
@@ -218,7 +219,6 @@ lower_by_torch_xla2(torch.ops.aten.tensor_split.sections)
|
|
|
218
219
|
lower_by_torch_xla2(torch.ops.aten.to.device)
|
|
219
220
|
lower_by_torch_xla2(torch.ops.aten.to.device)
|
|
220
221
|
lower_by_torch_xla2(torch.ops.aten.to.dtype)
|
|
221
|
-
lower_by_torch_xla2(torch.ops.aten.topk)
|
|
222
222
|
lower_by_torch_xla2(torch.ops.aten.transpose)
|
|
223
223
|
lower_by_torch_xla2(torch.ops.aten.transpose_copy)
|
|
224
224
|
lower_by_torch_xla2(torch.ops.aten.triu)
|
|
@@ -508,3 +508,36 @@ def _aten_einsum_default(
|
|
|
508
508
|
return jnp.einsum(equation, *operands, optimize="optimal")
|
|
509
509
|
|
|
510
510
|
return jax_lowering(lctx, tuple(tensors))
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
@registry.lower(torch.ops.aten.topk)
|
|
514
|
+
def _aten_topk(
|
|
515
|
+
lctx: LoweringContext, self, k, dim=-1, largest=True, sorted=True
|
|
516
|
+
):
|
|
517
|
+
_log_usage(torch.ops.aten.topk)
|
|
518
|
+
|
|
519
|
+
if not sorted:
|
|
520
|
+
logging.warning(
|
|
521
|
+
"aten.topk lowering ignores `sorted=False` and always returns sorted"
|
|
522
|
+
" results."
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
@jax_bridge.wrap
|
|
526
|
+
def jax_lowering(self, k):
|
|
527
|
+
if not largest:
|
|
528
|
+
self = -self
|
|
529
|
+
# jax.lax.top_k always sorts and operates on the last dimension.
|
|
530
|
+
move_dim_to_last = dim != -1 and dim != self.ndim - 1
|
|
531
|
+
if move_dim_to_last:
|
|
532
|
+
input_tensor = jnp.moveaxis(self, dim, -1)
|
|
533
|
+
else:
|
|
534
|
+
input_tensor = self
|
|
535
|
+
values, indices = jax.lax.top_k(input_tensor, k)
|
|
536
|
+
if move_dim_to_last:
|
|
537
|
+
values = jnp.moveaxis(values, -1, dim)
|
|
538
|
+
indices = jnp.moveaxis(indices, -1, dim)
|
|
539
|
+
if not largest:
|
|
540
|
+
values = -values
|
|
541
|
+
return values, indices
|
|
542
|
+
|
|
543
|
+
return jax_lowering(lctx, self, k)
|
ai_edge_torch/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.7.0.dev20250916
|
|
3
|
+
Version: 0.7.0.dev20250917
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
|
5
|
-
ai_edge_torch/version.py,sha256=
|
|
5
|
+
ai_edge_torch/version.py,sha256=RJAGmMeczkThv7SjCLM7f3fY57qXGMf2nQcJkjQqyB4,806
|
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
|
@@ -253,7 +253,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=sC4N5-7RS9yKecs97kM9J56enGvs
|
|
|
253
253
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
|
254
254
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
|
255
255
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
|
|
256
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
|
256
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=6KZc0-SM0LupOnDAgnUjj47aHk7YJ0cC1ulAeWvV5_w,19289
|
|
257
257
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
|
258
258
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
|
259
259
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
|
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
|
269
269
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
|
270
270
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
271
271
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
|
272
|
-
ai_edge_torch_nightly-0.7.0.dev20250916.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
273
|
-
ai_edge_torch_nightly-0.7.0.dev20250916.dist-info/METADATA,sha256=
|
|
274
|
-
ai_edge_torch_nightly-0.7.0.dev20250916.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
|
275
|
-
ai_edge_torch_nightly-0.7.0.dev20250916.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
276
|
-
ai_edge_torch_nightly-0.7.0.dev20250916.dist-info/RECORD,,
|
|
272
|
+
ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
273
|
+
ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/METADATA,sha256=-mF2daWYgvauqMwRg8-wNaVQZ2hGtTRBrjPql34gNIU,2074
|
|
274
|
+
ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
|
275
|
+
ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
276
|
+
ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|