ai-edge-torch-nightly 0.7.0.dev20251017__py3-none-any.whl → 0.7.0.dev20251018__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/odml_torch/lowerings/_basic.py +69 -0
 - ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +43 -0
 - ai_edge_torch/version.py +1 -1
 - {ai_edge_torch_nightly-0.7.0.dev20251017.dist-info → ai_edge_torch_nightly-0.7.0.dev20251018.dist-info}/METADATA +13 -3
 - {ai_edge_torch_nightly-0.7.0.dev20251017.dist-info → ai_edge_torch_nightly-0.7.0.dev20251018.dist-info}/RECORD +8 -8
 - {ai_edge_torch_nightly-0.7.0.dev20251017.dist-info → ai_edge_torch_nightly-0.7.0.dev20251018.dist-info}/WHEEL +1 -1
 - {ai_edge_torch_nightly-0.7.0.dev20251017.dist-info → ai_edge_torch_nightly-0.7.0.dev20251018.dist-info/licenses}/LICENSE +0 -0
 - {ai_edge_torch_nightly-0.7.0.dev20251017.dist-info → ai_edge_torch_nightly-0.7.0.dev20251018.dist-info}/top_level.txt +0 -0
 
| 
         @@ -51,6 +51,75 @@ def _aten_mul_tensor(lctx, self: ir.Value, other: ir.Value): 
     | 
|
| 
       51 
51 
     | 
    
         
             
              return stablehlo.multiply(self, other)
         
     | 
| 
       52 
52 
     | 
    
         | 
| 
       53 
53 
     | 
    
         | 
| 
      
 54 
     | 
    
         
            +
            def _hann_window_impl(
         
     | 
| 
      
 55 
     | 
    
         
            +
                lctx: LoweringContext,
         
     | 
| 
      
 56 
     | 
    
         
            +
                size: int,
         
     | 
| 
      
 57 
     | 
    
         
            +
                periodic: bool,
         
     | 
| 
      
 58 
     | 
    
         
            +
                dtype: Optional[torch.dtype],
         
     | 
| 
      
 59 
     | 
    
         
            +
            ) -> ir.Value:
         
     | 
| 
      
 60 
     | 
    
         
            +
              if dtype is None:
         
     | 
| 
      
 61 
     | 
    
         
            +
                ir_dtype = ir.F32Type.get()
         
     | 
| 
      
 62 
     | 
    
         
            +
              else:
         
     | 
| 
      
 63 
     | 
    
         
            +
                ir_dtype = utils.torch_dtype_to_ir_element_type(dtype)
         
     | 
| 
      
 64 
     | 
    
         
            +
             
     | 
| 
      
 65 
     | 
    
         
            +
              if not isinstance(ir_dtype, ir.FloatType):
         
     | 
| 
      
 66 
     | 
    
         
            +
                raise ValueError("hann_window only supports float dtypes.")
         
     | 
| 
      
 67 
     | 
    
         
            +
             
     | 
| 
      
 68 
     | 
    
         
            +
              if size == 0:
         
     | 
| 
      
 69 
     | 
    
         
            +
                return stablehlo.ConstantOp(
         
     | 
| 
      
 70 
     | 
    
         
            +
                    ir.RankedTensorType.get((0,), ir_dtype),
         
     | 
| 
      
 71 
     | 
    
         
            +
                    ir.DenseElementsAttr.get_empty(ir.RankedTensorType.get((0,), ir_dtype)),
         
     | 
| 
      
 72 
     | 
    
         
            +
                ).result
         
     | 
| 
      
 73 
     | 
    
         
            +
              if size == 1:
         
     | 
| 
      
 74 
     | 
    
         
            +
                return utils.splat(1.0, ir_dtype, [1])
         
     | 
| 
      
 75 
     | 
    
         
            +
             
     | 
| 
      
 76 
     | 
    
         
            +
              denom = size if periodic else size - 1
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
              i64 = ir.IntegerType.get_signless(64)
         
     | 
| 
      
 79 
     | 
    
         
            +
              iota_type = ir.RankedTensorType.get((size,), i64)
         
     | 
| 
      
 80 
     | 
    
         
            +
              n_i64 = stablehlo.IotaOp(
         
     | 
| 
      
 81 
     | 
    
         
            +
                  iota_type, iota_dimension=ir.IntegerAttr.get(i64, 0)
         
     | 
| 
      
 82 
     | 
    
         
            +
              ).result
         
     | 
| 
      
 83 
     | 
    
         
            +
             
     | 
| 
      
 84 
     | 
    
         
            +
              n_type = ir.RankedTensorType.get((size,), ir_dtype)
         
     | 
| 
      
 85 
     | 
    
         
            +
              n = stablehlo.convert(n_type, n_i64)
         
     | 
| 
      
 86 
     | 
    
         
            +
             
     | 
| 
      
 87 
     | 
    
         
            +
              pi_val = math.pi
         
     | 
| 
      
 88 
     | 
    
         
            +
              scale = 2.0 * pi_val / denom
         
     | 
| 
      
 89 
     | 
    
         
            +
             
     | 
| 
      
 90 
     | 
    
         
            +
              scale_splat = utils.splat(scale, ir_dtype, [size])
         
     | 
| 
      
 91 
     | 
    
         
            +
              arg_cos = stablehlo.multiply(n, scale_splat)
         
     | 
| 
      
 92 
     | 
    
         
            +
              cos_val = stablehlo.cosine(arg_cos)
         
     | 
| 
      
 93 
     | 
    
         
            +
             
     | 
| 
      
 94 
     | 
    
         
            +
              half_splat = utils.splat(0.5, ir_dtype, [size])
         
     | 
| 
      
 95 
     | 
    
         
            +
              scaled_cos = stablehlo.multiply(half_splat, cos_val)
         
     | 
| 
      
 96 
     | 
    
         
            +
              return stablehlo.subtract(half_splat, scaled_cos)
         
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
             
     | 
| 
      
 99 
     | 
    
         
            +
            # hann_window(int size, *, ScalarType? dtype=None) -> Tensor
            @lower(torch.ops.aten.hann_window.default)
            def _aten_hann_window_default(
                lctx: LoweringContext,
                size: int,
                *,
                dtype: Optional[torch.dtype] = None,
            ) -> ir.Value:
              """Lowers `aten.hann_window.default`.

              This overload has no `periodic` argument, so the shared implementation
              is always called with `periodic=True`.

              Args:
                lctx: Lowering context, forwarded unchanged.
                size: Number of points in the window.
                dtype: Optional torch dtype of the result, forwarded unchanged.

              Returns:
                The value produced by `_hann_window_impl`.
              """
              return _hann_window_impl(lctx, size, periodic=True, dtype=dtype)
         
     | 
| 
      
 108 
     | 
    
         
            +
             
     | 
| 
      
 109 
     | 
    
         
            +
             
     | 
| 
      
 110 
     | 
    
         
            +
            # hann_window.periodic(int size, bool periodic, *, ScalarType? dtype=None) ->
            # Tensor
            @lower(torch.ops.aten.hann_window.periodic)
            def _aten_hann_window_periodic(
                lctx: LoweringContext,
                size: int,
                periodic: bool,
                *,
                dtype: Optional[torch.dtype] = None,
            ) -> ir.Value:
              """Lowers `aten.hann_window.periodic`.

              Every argument is handed to the shared implementation as-is.

              Args:
                lctx: Lowering context, forwarded unchanged.
                size: Number of points in the window.
                periodic: Periodic/symmetric selector, forwarded unchanged.
                dtype: Optional torch dtype of the result, forwarded unchanged.

              Returns:
                The value produced by `_hann_window_impl`.
              """
              return _hann_window_impl(
                  lctx, size, periodic=periodic, dtype=dtype
              )
         
     | 
| 
      
 121 
     | 
    
         
            +
             
     | 
| 
      
 122 
     | 
    
         
            +
             
     | 
| 
       54 
123 
     | 
    
         
             
            # cat(Tensor[] tensors, int dim=0) -> Tensor
         
     | 
| 
       55 
124 
     | 
    
         
             
            # @lower(torch.ops.aten.cat)
         
     | 
| 
       56 
125 
     | 
    
         
             
            def _aten_cat(lctx, tensors: list[ir.Value], dim: int = 1):
         
     | 
| 
         @@ -541,3 +541,46 @@ def _aten_topk( 
     | 
|
| 
       541 
541 
     | 
    
         
             
                return values, indices
         
     | 
| 
       542 
542 
     | 
    
         | 
| 
       543 
543 
     | 
    
         
             
              return jax_lowering(lctx, self, k)
         
     | 
| 
      
 544 
     | 
    
         
            +
             
     | 
| 
      
 545 
     | 
    
         
            +
             
     | 
| 
      
 546 
     | 
    
         
            +
            @registry.lower(torch.ops.aten.multinomial)
            def _aten_multinomial(
                lctx: LoweringContext,
                self,
                num_samples,
                replacement=False,
                generator=None,
            ):
              """Lowers `aten.multinomial` through `jax.random.choice`.

              Sampling uses the fixed key `jax.random.PRNGKey(0)`, so the lowered
              result is deterministic. `generator` is not honoured; a warning is
              logged when one is supplied.

              Args:
                lctx: Lowering context, passed to the wrapped JAX lowering.
                self: Weights to sample from. The last axis is the population; a
                  1-D input is sampled directly, otherwise each index along axis 0
                  is sampled as its own row.
                num_samples: Number of indices to draw per row.
                replacement: Whether indices are drawn with replacement.
                generator: Ignored (see above).

              Returns:
                The result of the wrapped JAX lowering: sampled indices cast to
                `jnp.int64`.
              """
              _log_usage(torch.ops.aten.multinomial)

              @jax_bridge.wrap
              def jax_lowering(self, num_samples, replacement):
                if generator is not None:
                  logging.warning("aten.multinomial lowering ignores `generator`.")

                population = self.shape[-1]
                assert (
                    num_samples <= population or replacement
                ), "cannot take a larger sample than population when replacement=False"

                # TODO: Add proper PRNG key handling.
                key = jax.random.PRNGKey(0)
                if self.ndim == 1:
                  return jax.random.choice(
                      key, population, (num_samples,), replace=replacement, p=self
                  ).astype(jnp.int64)

                # Give every row its own subkey. Reusing one key for all rows would
                # draw each row from the same random stream, so rows with equal
                # weights would always yield identical samples.
                num_rows = self.shape[0]
                row_keys = jax.random.split(key, num_rows)
                return jnp.array(
                    [
                        jax.random.choice(
                            row_keys[i],
                            population,
                            (num_samples,),
                            replace=replacement,
                            p=self[i, :],
                        )
                        for i in range(num_rows)
                    ],
                    dtype=jnp.int64,
                )

              return jax_lowering(lctx, self, num_samples, replacement)
         
     | 
    
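        ai_edge_torch/version.py
    CHANGED
    (The version.py hunk itself was not captured here; the hunks that follow show the dist-info METADATA and RECORD files.)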
    
    
| 
         @@ -1,6 +1,6 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            Metadata-Version: 2. 
     | 
| 
      
 1 
     | 
    
         
            +
            Metadata-Version: 2.4
         
     | 
| 
       2 
2 
     | 
    
         
             
            Name: ai-edge-torch-nightly
         
     | 
| 
       3 
     | 
    
         
            -
            Version: 0.7.0.dev20251017
     | 
| 
      
 3 
     | 
    
         
            +
            Version: 0.7.0.dev20251018
         
     | 
| 
       4 
4 
     | 
    
         
             
            Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
         
     | 
| 
       5 
5 
     | 
    
         
             
            Home-page: https://github.com/google-ai-edge/ai-edge-torch
         
     | 
| 
       6 
6 
     | 
    
         
             
            Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
         
     | 
| 
         @@ -37,7 +37,17 @@ Requires-Dist: ai-edge-quantizer-nightly 
     | 
|
| 
       37 
37 
     | 
    
         
             
            Requires-Dist: jax
         
     | 
| 
       38 
38 
     | 
    
         
             
            Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
         
     | 
| 
       39 
39 
     | 
    
         
             
            Provides-Extra: torch-xla
         
     | 
| 
       40 
     | 
    
         
            -
            Requires-Dist:  
     | 
| 
      
 40 
     | 
    
         
            +
            Requires-Dist: torch_xla>=2.4.0; extra == "torch-xla"
         
     | 
| 
      
 41 
     | 
    
         
            +
            Dynamic: classifier
         
     | 
| 
      
 42 
     | 
    
         
            +
            Dynamic: description
         
     | 
| 
      
 43 
     | 
    
         
            +
            Dynamic: description-content-type
         
     | 
| 
      
 44 
     | 
    
         
            +
            Dynamic: home-page
         
     | 
| 
      
 45 
     | 
    
         
            +
            Dynamic: keywords
         
     | 
| 
      
 46 
     | 
    
         
            +
            Dynamic: license-file
         
     | 
| 
      
 47 
     | 
    
         
            +
            Dynamic: provides-extra
         
     | 
| 
      
 48 
     | 
    
         
            +
            Dynamic: requires-dist
         
     | 
| 
      
 49 
     | 
    
         
            +
            Dynamic: requires-python
         
     | 
| 
      
 50 
     | 
    
         
            +
            Dynamic: summary
         
     | 
| 
       41 
51 
     | 
    
         | 
| 
       42 
52 
     | 
    
         
             
            Library that supports converting PyTorch models into a .tflite format, which can
         
     | 
| 
       43 
53 
     | 
    
         
             
            then be run with TensorFlow Lite and MediaPipe.  This enables applications for
         
     | 
| 
         @@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129 
     | 
|
| 
       2 
2 
     | 
    
         
             
            ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
         
     | 
| 
       3 
3 
     | 
    
         
             
            ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
         
     | 
| 
       4 
4 
     | 
    
         
             
            ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
         
     | 
| 
       5 
     | 
    
         
            -
            ai_edge_torch/version.py,sha256= 
     | 
| 
      
 5 
     | 
    
         
            +
            ai_edge_torch/version.py,sha256=07ZM69GTj0eqLoL4igeb-C1Ezp_B86je2hnvTkDH3Ho,806
         
     | 
| 
       6 
6 
     | 
    
         
             
            ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       7 
7 
     | 
    
         
             
            ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
         
     | 
| 
       8 
8 
     | 
    
         
             
            ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
         
     | 
| 
         @@ -250,11 +250,11 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi 
     | 
|
| 
       250 
250 
     | 
    
         
             
            ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
         
     | 
| 
       251 
251 
     | 
    
         
             
            ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
         
     | 
| 
       252 
252 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
         
     | 
| 
       253 
     | 
    
         
            -
            ai_edge_torch/odml_torch/lowerings/_basic.py,sha256 
     | 
| 
      
 253 
     | 
    
         
            +
            ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=-zKZtqvOx3COBhjCDtiZWMn5fY-boktkWXjp5Kepiro,14716
         
     | 
| 
       254 
254 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
         
     | 
| 
       255 
255 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
         
     | 
| 
       256 
256 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
         
     | 
| 
       257 
     | 
    
         
            -
            ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256= 
     | 
| 
      
 257 
     | 
    
         
            +
            ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=7XlctokKu8jmXG51ZzdLz5HA7DDeD1bLai7aGUMs008,20457
         
     | 
| 
       258 
258 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
         
     | 
| 
       259 
259 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
         
     | 
| 
       260 
260 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
         
     | 
| 
         @@ -270,8 +270,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG 
     | 
|
| 
       270 
270 
     | 
    
         
             
            ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
         
     | 
| 
       271 
271 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
         
     | 
| 
       272 
272 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
         
     | 
| 
       273 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0. 
     | 
| 
       274 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0. 
     | 
| 
       275 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0. 
     | 
| 
       276 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0. 
     | 
| 
       277 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0. 
     | 
| 
      
 273 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251018.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
         
     | 
| 
      
 274 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251018.dist-info/METADATA,sha256=upHQ9ZpIm3SXQXAyfLoCWjW6fjiHFxZI5dfaRV_yZ8c,2297
         
     | 
| 
      
 275 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251018.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
         
     | 
| 
      
 276 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251018.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
         
     | 
| 
      
 277 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251018.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     |