ai-edge-torch-nightly 0.7.0.dev20250915__py3-none-any.whl → 0.7.0.dev20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the package's registry page for more details.

@@ -18,6 +18,7 @@ import logging
18
18
  from ai_edge_torch.odml_torch import jax_bridge
19
19
  from ai_edge_torch.odml_torch.lowerings import context
20
20
  from ai_edge_torch.odml_torch.lowerings import registry
21
+ import jax
21
22
  import jax.numpy as jnp
22
23
  from jax._src.lib.mlir import ir
23
24
  import torch
@@ -218,7 +219,6 @@ lower_by_torch_xla2(torch.ops.aten.tensor_split.sections)
218
219
  lower_by_torch_xla2(torch.ops.aten.to.device)
219
220
  lower_by_torch_xla2(torch.ops.aten.to.device)
220
221
  lower_by_torch_xla2(torch.ops.aten.to.dtype)
221
- lower_by_torch_xla2(torch.ops.aten.topk)
222
222
  lower_by_torch_xla2(torch.ops.aten.transpose)
223
223
  lower_by_torch_xla2(torch.ops.aten.transpose_copy)
224
224
  lower_by_torch_xla2(torch.ops.aten.triu)
@@ -508,3 +508,36 @@ def _aten_einsum_default(
508
508
  return jnp.einsum(equation, *operands, optimize="optimal")
509
509
 
510
510
  return jax_lowering(lctx, tuple(tensors))
511
+
512
+
513
@registry.lower(torch.ops.aten.topk)
def _aten_topk(
    lctx: LoweringContext, self, k, dim=-1, largest=True, sorted=True
):
  """Lowers `aten.topk` through `jax.lax.top_k` via the JAX bridge.

  Args:
    lctx: The lowering context, forwarded to the `jax_bridge.wrap`-ed function.
    self: The input tensor, forwarded unchanged to the wrapped JAX function
      (presumably an MLIR value -- confirm against jax_bridge).
    k: Number of elements to select along `dim`.
    dim: Dimension to select along; negative values count from the end.
    largest: If True select the k largest elements, otherwise the k smallest.
    sorted: Ignored; results are always sorted (a warning is logged when this
      is False).

  Returns:
    The `(values, indices)` pair returned by the wrapped JAX function.
  """
  _log_usage(torch.ops.aten.topk)

  if not sorted:
    logging.warning(
        "aten.topk lowering ignores `sorted=False` and always returns sorted"
        " results."
    )

  def _flip_order(x):
    # Order-reversing map that is its own inverse. Applying it before and
    # after `jax.lax.top_k` turns "k largest" into "k smallest" and restores
    # the original values.
    if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
      # `-x` wraps for unsigned dtypes (uint8: 1 -> 255) and scrambles the
      # order; `max - x` always stays within the dtype's range.
      return jnp.array(jnp.iinfo(x.dtype).max, dtype=x.dtype) - x
    if jnp.issubdtype(x.dtype, jnp.signedinteger):
      # `-x` overflows at the dtype minimum (int8: -(-128) == -128), which
      # would rank the smallest value last; `-1 - x` never overflows.
      return jnp.array(-1, dtype=x.dtype) - x
    # Floating point: plain negation is exact and order-reversing.
    # NOTE(review): NaN ordering under this trick has not been checked
    # against PyTorch's topk semantics.
    return -x

  @jax_bridge.wrap
  def jax_lowering(x, k):
    if not largest:
      x = _flip_order(x)
    # jax.lax.top_k always sorts and operates on the last dimension.
    move_dim_to_last = dim != -1 and dim != x.ndim - 1
    if move_dim_to_last:
      x = jnp.moveaxis(x, dim, -1)
    values, indices = jax.lax.top_k(x, k)
    if move_dim_to_last:
      values = jnp.moveaxis(values, -1, dim)
      indices = jnp.moveaxis(indices, -1, dim)
    if not largest:
      # Undo the order flip so the returned values match the input.
      values = _flip_order(values)
    return values, indices

  return jax_lowering(lctx, self, k)
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.7.0.dev20250915"
18
+ __version__ = "0.7.0.dev20250917"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.7.0.dev20250915
3
+ Version: 0.7.0.dev20250917
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
- ai_edge_torch/version.py,sha256=sj61xhc7PcSll3exrFNmhToMA_lBgYwQ1IMBgahQ9rQ,806
5
+ ai_edge_torch/version.py,sha256=RJAGmMeczkThv7SjCLM7f3fY57qXGMf2nQcJkjQqyB4,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -253,7 +253,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=sC4N5-7RS9yKecs97kM9J56enGvs
253
253
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
254
254
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
255
255
  ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
256
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Hth8F1SwPVEympR1_T6sEsCmmRXtoO7zatPR1uRLFbY,18411
256
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=6KZc0-SM0LupOnDAgnUjj47aHk7YJ0cC1ulAeWvV5_w,19289
257
257
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
258
258
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
259
259
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
269
269
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
270
270
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
271
271
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
272
- ai_edge_torch_nightly-0.7.0.dev20250915.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
273
- ai_edge_torch_nightly-0.7.0.dev20250915.dist-info/METADATA,sha256=jSrydAnNyQUAsVL_K2iue8o-seM1IevQdaK5NwDNlAc,2074
274
- ai_edge_torch_nightly-0.7.0.dev20250915.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
275
- ai_edge_torch_nightly-0.7.0.dev20250915.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
276
- ai_edge_torch_nightly-0.7.0.dev20250915.dist-info/RECORD,,
272
+ ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
273
+ ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/METADATA,sha256=-mF2daWYgvauqMwRg8-wNaVQZ2hGtTRBrjPql34gNIU,2074
274
+ ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
275
+ ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
276
+ ai_edge_torch_nightly-0.7.0.dev20250917.dist-info/RECORD,,