ai-edge-torch-nightly 0.7.0.dev20251017__py3-none-any.whl → 0.7.0.dev20251019__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -51,6 +51,75 @@ def _aten_mul_tensor(lctx, self: ir.Value, other: ir.Value):
51
51
  return stablehlo.multiply(self, other)
52
52
 
53
53
 
54
def _hann_window_impl(
    lctx: LoweringContext,
    size: int,
    periodic: bool,
    dtype: Optional[torch.dtype],
) -> ir.Value:
  """Builds a 1-D Hann window of length `size` as StableHLO ops.

  Computes `w[n] = 0.5 - 0.5 * cos(2 * pi * n / denom)` for `n` in
  `[0, size)`, where `denom` is `size` for a periodic window and `size - 1`
  for a symmetric window.

  Args:
    lctx: The lowering context (not used by this helper).
    size: Window length. Must be non-negative.
    periodic: If True, use denominator `size` (periodic window); otherwise
      use `size - 1` (symmetric window).
    dtype: Requested torch element type; float32 is used when None.

  Returns:
    An `ir.Value` of shape `(size,)` with the requested float element type.

  Raises:
    ValueError: If `dtype` is not a float type or `size` is negative.
  """
  ir_dtype = (
      ir.F32Type.get()
      if dtype is None
      else utils.torch_dtype_to_ir_element_type(dtype)
  )
  if not isinstance(ir_dtype, ir.FloatType):
    raise ValueError("hann_window only supports float dtypes.")

  # Reject negative lengths up front instead of letting them surface as an
  # obscure MLIR shape error from RankedTensorType/IotaOp below.
  if size < 0:
    raise ValueError(
        f"hann_window requires non-negative window_length, got {size}."
    )

  # Degenerate lengths: an empty tensor for 0 and [1.0] for 1. The general
  # formula below would divide by zero for size == 1 when not periodic.
  if size == 0:
    empty_type = ir.RankedTensorType.get((0,), ir_dtype)
    return stablehlo.ConstantOp(
        empty_type,
        ir.DenseElementsAttr.get_empty(empty_type),
    ).result
  if size == 1:
    return utils.splat(1.0, ir_dtype, [1])

  denom = size if periodic else size - 1
  scale = 2.0 * math.pi / denom

  # n = [0, 1, ..., size - 1], built as an i64 iota and cast to the float type.
  i64 = ir.IntegerType.get_signless(64)
  n_i64 = stablehlo.IotaOp(
      ir.RankedTensorType.get((size,), i64),
      iota_dimension=ir.IntegerAttr.get(i64, 0),
  ).result
  n = stablehlo.convert(ir.RankedTensorType.get((size,), ir_dtype), n_i64)

  cos_val = stablehlo.cosine(
      stablehlo.multiply(n, utils.splat(scale, ir_dtype, [size]))
  )
  half = utils.splat(0.5, ir_dtype, [size])
  return stablehlo.subtract(half, stablehlo.multiply(half, cos_val))
97
+
98
+
99
# hann_window(int size, *, ScalarType? dtype=None) -> Tensor
@lower(torch.ops.aten.hann_window.default)
def _aten_hann_window_default(
    lctx: LoweringContext,
    size: int,
    *,
    dtype: Optional[torch.dtype] = None,
) -> ir.Value:
  """Lowers `aten.hann_window.default` by building a periodic window."""
  return _hann_window_impl(lctx, size, periodic=True, dtype=dtype)
108
+
109
+
110
# hann_window.periodic(int size, bool periodic, *, ScalarType? dtype=None)
#   -> Tensor
@lower(torch.ops.aten.hann_window.periodic)
def _aten_hann_window_periodic(
    lctx: LoweringContext,
    size: int,
    periodic: bool,
    *,
    dtype: Optional[torch.dtype] = None,
) -> ir.Value:
  """Lowers `aten.hann_window.periodic`, forwarding the caller's `periodic`."""
  return _hann_window_impl(lctx, size, periodic=periodic, dtype=dtype)
121
+
122
+
54
123
  # cat(Tensor[] tensors, int dim=0) -> Tensor
55
124
  # @lower(torch.ops.aten.cat)
56
125
  def _aten_cat(lctx, tensors: list[ir.Value], dim: int = 1):
@@ -541,3 +541,46 @@ def _aten_topk(
541
541
  return values, indices
542
542
 
543
543
  return jax_lowering(lctx, self, k)
544
+
545
+
546
@registry.lower(torch.ops.aten.multinomial)
def _aten_multinomial(
    lctx: LoweringContext,
    self,
    num_samples,
    replacement=False,
    generator=None,
):
  """Lowers `aten.multinomial` through `jax.random.choice`.

  Args:
    lctx: The lowering context.
    self: Category weights, either 1-D or 2-D (one distribution per row).
      Presumably non-negative as torch requires; not checked here.
    num_samples: Number of indices to draw per distribution.
    replacement: Whether to draw with replacement.
    generator: Ignored (a warning is logged); sampling uses a fixed key.

  Returns:
    Sampled category indices of shape `(num_samples,)` for 1-D input or
    `(rows, num_samples)` for 2-D input, cast to `jnp.int64`.

  Raises:
    ValueError: If `num_samples` exceeds the number of categories while
      `replacement` is False.
  """
  _log_usage(torch.ops.aten.multinomial)

  if generator is not None:
    logging.warning("aten.multinomial lowering ignores `generator`.")

  @jax_bridge.wrap
  def jax_lowering(self, num_samples, replacement):
    num_categories = self.shape[-1]
    # Raise rather than assert so the check survives `python -O`.
    if num_samples > num_categories and not replacement:
      raise ValueError(
          "cannot take a larger sample than population when replacement=False"
      )

    def sample_row(key, probs):
      return jax.random.choice(
          key, num_categories, (num_samples,), replace=replacement, p=probs
      )

    # TODO: Add proper PRNG key handling.
    key = jax.random.PRNGKey(0)
    if self.ndim == 1:
      return sample_row(key, self).astype(jnp.int64)

    # Give every row its own key. Reusing one key makes all rows consume the
    # same random stream, so rows with equal weights draw identical samples.
    # vmap also avoids unrolling one `choice` per row into the graph.
    row_keys = jax.random.split(key, self.shape[0])
    return jax.vmap(sample_row)(row_keys, self).astype(jnp.int64)

  return jax_lowering(lctx, self, num_samples, replacement)
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.7.0.dev20251017"
18
+ __version__ = "0.7.0.dev20251019"
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.7.0.dev20251017
3
+ Version: 0.7.0.dev20251019
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -37,7 +37,17 @@ Requires-Dist: ai-edge-quantizer-nightly
37
37
  Requires-Dist: jax
38
38
  Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
39
39
  Provides-Extra: torch-xla
40
- Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
40
+ Requires-Dist: torch_xla>=2.4.0; extra == "torch-xla"
41
+ Dynamic: classifier
42
+ Dynamic: description
43
+ Dynamic: description-content-type
44
+ Dynamic: home-page
45
+ Dynamic: keywords
46
+ Dynamic: license-file
47
+ Dynamic: provides-extra
48
+ Dynamic: requires-dist
49
+ Dynamic: requires-python
50
+ Dynamic: summary
41
51
 
42
52
  Library that supports converting PyTorch models into a .tflite format, which can
43
53
  then be run with TensorFlow Lite and MediaPipe. This enables applications for
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
- ai_edge_torch/version.py,sha256=s37U1R9i7bjEl8Td8XZ5nmBvKsNMQV5vCFmKLpWXQ-k,806
5
+ ai_edge_torch/version.py,sha256=Bv5NIcJlH0i9JjeUkIhOpW12ztp9GxcN4AO-Avm4Pg8,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -250,11 +250,11 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
250
250
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
251
251
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
252
252
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
253
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=sC4N5-7RS9yKecs97kM9J56enGvsZj1CJo7y79cuzRg,12784
253
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=-zKZtqvOx3COBhjCDtiZWMn5fY-boktkWXjp5Kepiro,14716
254
254
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
255
255
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
256
256
  ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
257
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=6KZc0-SM0LupOnDAgnUjj47aHk7YJ0cC1ulAeWvV5_w,19289
257
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=7XlctokKu8jmXG51ZzdLz5HA7DDeD1bLai7aGUMs008,20457
258
258
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
259
259
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
260
260
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -270,8 +270,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
270
270
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
271
271
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
272
272
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
273
- ai_edge_torch_nightly-0.7.0.dev20251017.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
274
- ai_edge_torch_nightly-0.7.0.dev20251017.dist-info/METADATA,sha256=6izafx1yDIwodhe9vqegygTBcgLDDnXQuqtX1xFzwnY,2074
275
- ai_edge_torch_nightly-0.7.0.dev20251017.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
276
- ai_edge_torch_nightly-0.7.0.dev20251017.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
277
- ai_edge_torch_nightly-0.7.0.dev20251017.dist-info/RECORD,,
273
+ ai_edge_torch_nightly-0.7.0.dev20251019.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
274
+ ai_edge_torch_nightly-0.7.0.dev20251019.dist-info/METADATA,sha256=GWPODeZLuihfxxi2sYjWGFfUB6c8ZQk9F5qjRltamTs,2297
275
+ ai_edge_torch_nightly-0.7.0.dev20251019.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
276
+ ai_edge_torch_nightly-0.7.0.dev20251019.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
277
+ ai_edge_torch_nightly-0.7.0.dev20251019.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.45.1)
2
+ Generator: setuptools (79.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5