ai-edge-torch-nightly 0.7.0.dev20251020__py3-none-any.whl → 0.7.0.dev20251022__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/odml_torch/lowerings/_basic.py +90 -4
 - ai_edge_torch/version.py +1 -1
 - {ai_edge_torch_nightly-0.7.0.dev20251020.dist-info → ai_edge_torch_nightly-0.7.0.dev20251022.dist-info}/METADATA +3 -13
 - {ai_edge_torch_nightly-0.7.0.dev20251020.dist-info → ai_edge_torch_nightly-0.7.0.dev20251022.dist-info}/RECORD +7 -7
 - {ai_edge_torch_nightly-0.7.0.dev20251020.dist-info → ai_edge_torch_nightly-0.7.0.dev20251022.dist-info}/WHEEL +1 -1
 - {ai_edge_torch_nightly-0.7.0.dev20251020.dist-info/licenses → ai_edge_torch_nightly-0.7.0.dev20251022.dist-info}/LICENSE +0 -0
 - {ai_edge_torch_nightly-0.7.0.dev20251020.dist-info → ai_edge_torch_nightly-0.7.0.dev20251022.dist-info}/top_level.txt +0 -0
 
| 
         @@ -12,6 +12,7 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            import logging
         
     | 
| 
       15 
16 
     | 
    
         
             
            import math
         
     | 
| 
       16 
17 
     | 
    
         
             
            import operator
         
     | 
| 
       17 
18 
     | 
    
         
             
            from typing import Optional, Union
         
     | 
| 
         @@ -101,9 +102,12 @@ def _hann_window_impl( 
     | 
|
| 
       101 
102 
     | 
    
         
             
            def _aten_hann_window_default(
         
     | 
| 
       102 
103 
     | 
    
         
             
                lctx: LoweringContext,
         
     | 
| 
       103 
104 
     | 
    
         
             
                size: int,
         
     | 
| 
       104 
     | 
    
         
            -
                 
     | 
| 
       105 
     | 
    
         
            -
                dtype: Optional[torch.dtype] = None,
         
     | 
| 
      
 105 
     | 
    
         
            +
                **kwargs,
         
     | 
| 
       106 
106 
     | 
    
         
             
            ) -> ir.Value:
         
     | 
| 
      
 107 
     | 
    
         
            +
              dtype = kwargs.pop("dtype", None)
         
     | 
| 
      
 108 
     | 
    
         
            +
              layout = kwargs.pop("layout", torch.strided)
         
     | 
| 
      
 109 
     | 
    
         
            +
              if layout != torch.strided:
         
     | 
| 
      
 110 
     | 
    
         
            +
                logging.warning("hann_window only supports torch.strided layout.")
         
     | 
| 
       107 
111 
     | 
    
         
             
              return _hann_window_impl(lctx, size, True, dtype)
         
     | 
| 
       108 
112 
     | 
    
         | 
| 
       109 
113 
     | 
    
         | 
| 
         @@ -114,9 +118,12 @@ def _aten_hann_window_periodic( 
     | 
|
| 
       114 
118 
     | 
    
         
             
                lctx: LoweringContext,
         
     | 
| 
       115 
119 
     | 
    
         
             
                size: int,
         
     | 
| 
       116 
120 
     | 
    
         
             
                periodic: bool,
         
     | 
| 
       117 
     | 
    
         
            -
                 
     | 
| 
       118 
     | 
    
         
            -
                dtype: Optional[torch.dtype] = None,
         
     | 
| 
      
 121 
     | 
    
         
            +
                **kwargs,
         
     | 
| 
       119 
122 
     | 
    
         
             
            ) -> ir.Value:
         
     | 
| 
      
 123 
     | 
    
         
            +
              dtype = kwargs.pop("dtype", None)
         
     | 
| 
      
 124 
     | 
    
         
            +
              layout = kwargs.pop("layout", torch.strided)
         
     | 
| 
      
 125 
     | 
    
         
            +
              if layout != torch.strided:
         
     | 
| 
      
 126 
     | 
    
         
            +
                logging.warning("hann_window only supports torch.strided layout.")
         
     | 
| 
       120 
127 
     | 
    
         
             
              return _hann_window_impl(lctx, size, periodic, dtype)
         
     | 
| 
       121 
128 
     | 
    
         | 
| 
       122 
129 
     | 
    
         | 
| 
         @@ -318,6 +325,85 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0): 
     | 
|
| 
       318 
325 
     | 
    
         
             
              return stablehlo.concatenate(non_empty_tensors, dim)
         
     | 
| 
       319 
326 
     | 
    
         | 
| 
       320 
327 
     | 
    
         | 
| 
      
 328 
     | 
    
         
            +
            # Schema:
         
     | 
| 
      
 329 
     | 
    
         
            +
            #   - aten::unfold(Tensor self, int dim, int size, int step) -> Tensor
         
     | 
| 
      
 330 
     | 
    
         
            +
            # Torch Reference:
         
     | 
| 
      
 331 
     | 
    
         
            +
            #   - https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
         
     | 
| 
      
 332 
     | 
    
         
            +
            @lower(torch.ops.aten.unfold.default)
         
     | 
| 
      
 333 
     | 
    
         
            +
            def _aten_unfold(lctx, x: ir.Value, dim: int, size: int, step: int):
         
     | 
| 
      
 334 
     | 
    
         
            +
              x_shape = x.type.shape
         
     | 
| 
      
 335 
     | 
    
         
            +
              rank = len(x_shape)
         
     | 
| 
      
 336 
     | 
    
         
            +
              if dim < 0:
         
     | 
| 
      
 337 
     | 
    
         
            +
                dim += rank
         
     | 
| 
      
 338 
     | 
    
         
            +
             
     | 
| 
      
 339 
     | 
    
         
            +
              num_windows = (x_shape[dim] - size) // step + 1
         
     | 
| 
      
 340 
     | 
    
         
            +
              batch_shape = list(x_shape[:dim]) + [num_windows] + list(x_shape[dim + 1 :])
         
     | 
| 
      
 341 
     | 
    
         
            +
             
     | 
| 
      
 342 
     | 
    
         
            +
              # Create start_indices for gather.
         
     | 
| 
      
 343 
     | 
    
         
            +
              # The shape of start_indices will be batch_shape + [rank].
         
     | 
| 
      
 344 
     | 
    
         
            +
              # start_indices[b_0,...,b_{rank-1}] will be [p_0,...,p_{rank-1}] where
         
     | 
| 
      
 345 
     | 
    
         
            +
              # p_j = b_j for j != dim and p_dim = b_dim * step.
         
     | 
| 
      
 346 
     | 
    
         
            +
              indices_parts = []
         
     | 
| 
      
 347 
     | 
    
         
            +
              i64 = ir.IntegerType.get_signless(64)
         
     | 
| 
      
 348 
     | 
    
         
            +
              for i in range(rank):
         
     | 
| 
      
 349 
     | 
    
         
            +
                bshape = [1] * rank
         
     | 
| 
      
 350 
     | 
    
         
            +
                bshape[i] = batch_shape[i]
         
     | 
| 
      
 351 
     | 
    
         
            +
                dim_len = batch_shape[i]
         
     | 
| 
      
 352 
     | 
    
         
            +
             
     | 
| 
      
 353 
     | 
    
         
            +
                iota = stablehlo.IotaOp(
         
     | 
| 
      
 354 
     | 
    
         
            +
                    ir.RankedTensorType.get([dim_len], i64),
         
     | 
| 
      
 355 
     | 
    
         
            +
                    iota_dimension=ir.IntegerAttr.get(i64, 0),
         
     | 
| 
      
 356 
     | 
    
         
            +
                ).result
         
     | 
| 
      
 357 
     | 
    
         
            +
                if i == dim:
         
     | 
| 
      
 358 
     | 
    
         
            +
                  iota = stablehlo.multiply(iota, utils.splat(step, i64, [dim_len]))
         
     | 
| 
      
 359 
     | 
    
         
            +
             
     | 
| 
      
 360 
     | 
    
         
            +
                iota_reshaped = stablehlo.reshape(
         
     | 
| 
      
 361 
     | 
    
         
            +
                    ir.RankedTensorType.get(bshape, i64), iota
         
     | 
| 
      
 362 
     | 
    
         
            +
                )
         
     | 
| 
      
 363 
     | 
    
         
            +
                indices_parts.append(
         
     | 
| 
      
 364 
     | 
    
         
            +
                    stablehlo.broadcast_in_dim(
         
     | 
| 
      
 365 
     | 
    
         
            +
                        ir.RankedTensorType.get(batch_shape, i64),
         
     | 
| 
      
 366 
     | 
    
         
            +
                        iota_reshaped,
         
     | 
| 
      
 367 
     | 
    
         
            +
                        ir.DenseI64ArrayAttr.get(list(range(rank))),
         
     | 
| 
      
 368 
     | 
    
         
            +
                    )
         
     | 
| 
      
 369 
     | 
    
         
            +
                )
         
     | 
| 
      
 370 
     | 
    
         
            +
             
     | 
| 
      
 371 
     | 
    
         
            +
              # For each dimension i, indices_parts[i] contains the i-th coordinate
         
     | 
| 
      
 372 
     | 
    
         
            +
              # of start_indices. We unsqueeze each part to shape batch_shape + [1]
         
     | 
| 
      
 373 
     | 
    
         
            +
              # and concatenate along the new dimension to produce start_indices of
         
     | 
| 
      
 374 
     | 
    
         
            +
              # shape batch_shape + [rank].
         
     | 
| 
      
 375 
     | 
    
         
            +
              unsqueezed_parts = [
         
     | 
| 
      
 376 
     | 
    
         
            +
                  stablehlo.reshape(ir.RankedTensorType.get(batch_shape + [1], i64), part)
         
     | 
| 
      
 377 
     | 
    
         
            +
                  for part in indices_parts
         
     | 
| 
      
 378 
     | 
    
         
            +
              ]
         
     | 
| 
      
 379 
     | 
    
         
            +
              start_indices = stablehlo.concatenate(
         
     | 
| 
      
 380 
     | 
    
         
            +
                  unsqueezed_parts, ir.IntegerAttr.get(i64, rank)
         
     | 
| 
      
 381 
     | 
    
         
            +
              )
         
     | 
| 
      
 382 
     | 
    
         
            +
             
     | 
| 
      
 383 
     | 
    
         
            +
              slice_sizes_list = [1] * rank
         
     | 
| 
      
 384 
     | 
    
         
            +
              slice_sizes_list[dim] = size
         
     | 
| 
      
 385 
     | 
    
         
            +
              slice_sizes = ir.DenseI64ArrayAttr.get(slice_sizes_list)
         
     | 
| 
      
 386 
     | 
    
         
            +
             
     | 
| 
      
 387 
     | 
    
         
            +
              collapsed_slice_dims_list = [i for i in range(rank) if i != dim]
         
     | 
| 
      
 388 
     | 
    
         
            +
             
     | 
| 
      
 389 
     | 
    
         
            +
              dnums = stablehlo.GatherDimensionNumbers.get(
         
     | 
| 
      
 390 
     | 
    
         
            +
                  offset_dims=[rank],
         
     | 
| 
      
 391 
     | 
    
         
            +
                  collapsed_slice_dims=collapsed_slice_dims_list,
         
     | 
| 
      
 392 
     | 
    
         
            +
                  operand_batching_dims=[],
         
     | 
| 
      
 393 
     | 
    
         
            +
                  start_indices_batching_dims=[],
         
     | 
| 
      
 394 
     | 
    
         
            +
                  start_index_map=list(range(rank)),
         
     | 
| 
      
 395 
     | 
    
         
            +
                  index_vector_dim=rank,
         
     | 
| 
      
 396 
     | 
    
         
            +
              )
         
     | 
| 
      
 397 
     | 
    
         
            +
             
     | 
| 
      
 398 
     | 
    
         
            +
              return stablehlo.gather(
         
     | 
| 
      
 399 
     | 
    
         
            +
                  x,
         
     | 
| 
      
 400 
     | 
    
         
            +
                  start_indices,
         
     | 
| 
      
 401 
     | 
    
         
            +
                  dnums,
         
     | 
| 
      
 402 
     | 
    
         
            +
                  slice_sizes,
         
     | 
| 
      
 403 
     | 
    
         
            +
                  indices_are_sorted=ir.BoolAttr.get(False),
         
     | 
| 
      
 404 
     | 
    
         
            +
              )
         
     | 
| 
      
 405 
     | 
    
         
            +
             
     | 
| 
      
 406 
     | 
    
         
            +
             
     | 
| 
       321 
407 
     | 
    
         
             
            # Schema:
         
     | 
| 
       322 
408 
     | 
    
         
             
            #   - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
         
     | 
| 
       323 
409 
     | 
    
         
             
            #       start=None, SymInt? end=None, SymInt step=1) -> Tensor
         
     | 
    
        ai_edge_torch/version.py
    CHANGED
    
    
| 
         @@ -1,6 +1,6 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            Metadata-Version: 2. 
     | 
| 
      
 1 
     | 
    
         
            +
            Metadata-Version: 2.1
         
     | 
| 
       2 
2 
     | 
    
         
             
            Name: ai-edge-torch-nightly
         
     | 
| 
       3 
     | 
    
         
            -
            Version: 0.7.0.dev20251020
     | 
| 
      
 3 
     | 
    
         
            +
            Version: 0.7.0.dev20251022
         
     | 
| 
       4 
4 
     | 
    
         
             
            Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
         
     | 
| 
       5 
5 
     | 
    
         
             
            Home-page: https://github.com/google-ai-edge/ai-edge-torch
         
     | 
| 
       6 
6 
     | 
    
         
             
            Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
         
     | 
| 
         @@ -37,17 +37,7 @@ Requires-Dist: ai-edge-quantizer-nightly 
     | 
|
| 
       37 
37 
     | 
    
         
             
            Requires-Dist: jax
         
     | 
| 
       38 
38 
     | 
    
         
             
            Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
         
     | 
| 
       39 
39 
     | 
    
         
             
            Provides-Extra: torch-xla
         
     | 
| 
       40 
     | 
    
         
            -
            Requires-Dist:  
     | 
| 
       41 
     | 
    
         
            -
            Dynamic: classifier
         
     | 
| 
       42 
     | 
    
         
            -
            Dynamic: description
         
     | 
| 
       43 
     | 
    
         
            -
            Dynamic: description-content-type
         
     | 
| 
       44 
     | 
    
         
            -
            Dynamic: home-page
         
     | 
| 
       45 
     | 
    
         
            -
            Dynamic: keywords
         
     | 
| 
       46 
     | 
    
         
            -
            Dynamic: license-file
         
     | 
| 
       47 
     | 
    
         
            -
            Dynamic: provides-extra
         
     | 
| 
       48 
     | 
    
         
            -
            Dynamic: requires-dist
         
     | 
| 
       49 
     | 
    
         
            -
            Dynamic: requires-python
         
     | 
| 
       50 
     | 
    
         
            -
            Dynamic: summary
         
     | 
| 
      
 40 
     | 
    
         
            +
            Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
         
     | 
| 
       51 
41 
     | 
    
         | 
| 
       52 
42 
     | 
    
         
             
            Library that supports converting PyTorch models into a .tflite format, which can
         
     | 
| 
       53 
43 
     | 
    
         
             
            then be run with TensorFlow Lite and MediaPipe.  This enables applications for
         
     | 
| 
         @@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129 
     | 
|
| 
       2 
2 
     | 
    
         
             
            ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
         
     | 
| 
       3 
3 
     | 
    
         
             
            ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
         
     | 
| 
       4 
4 
     | 
    
         
             
            ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
         
     | 
| 
       5 
     | 
    
         
            -
            ai_edge_torch/version.py,sha256= 
     | 
| 
      
 5 
     | 
    
         
            +
            ai_edge_torch/version.py,sha256=EHjxIt8UozzU1M_dt9HhtAkPXLZ4-nFSjG3rQegVtSo,806
         
     | 
| 
       6 
6 
     | 
    
         
             
            ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       7 
7 
     | 
    
         
             
            ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
         
     | 
| 
       8 
8 
     | 
    
         
             
            ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
         
     | 
| 
         @@ -250,7 +250,7 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi 
     | 
|
| 
       250 
250 
     | 
    
         
             
            ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
         
     | 
| 
       251 
251 
     | 
    
         
             
            ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
         
     | 
| 
       252 
252 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
         
     | 
| 
       253 
     | 
    
         
            -
            ai_edge_torch/odml_torch/lowerings/_basic.py,sha256 
     | 
| 
      
 253 
     | 
    
         
            +
            ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=HOTYfQWin8tqi1yakIyardxhRViZ6rhLV6ZomMSS7zA,17554
         
     | 
| 
       254 
254 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
         
     | 
| 
       255 
255 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
         
     | 
| 
       256 
256 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
         
     | 
| 
         @@ -270,8 +270,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG 
     | 
|
| 
       270 
270 
     | 
    
         
             
            ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
         
     | 
| 
       271 
271 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
         
     | 
| 
       272 
272 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
         
     | 
| 
       273 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0.dev20251020.dist-info/…
     | 
| 
       274 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0.dev20251020.dist-info/…
     | 
| 
       275 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0.dev20251020.dist-info/…
     | 
| 
       276 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0.dev20251020.dist-info/…
     | 
| 
       277 
     | 
    
         
            -
            ai_edge_torch_nightly-0.7.0.dev20251020.dist-info/…
     | 
| 
      
 273 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
         
     | 
| 
      
 274 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/METADATA,sha256=0QfuVTBKI9hx8RFCAn0FCq5HdeZed6x4rI8A3iROHzA,2074
         
     | 
| 
      
 275 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
         
     | 
| 
      
 276 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
         
     | 
| 
      
 277 
     | 
    
         
            +
            ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     |