ai-edge-torch-nightly 0.7.0.dev20251021__py3-none-any.whl → 0.7.0.dev20251022__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -325,6 +325,85 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
325
325
  return stablehlo.concatenate(non_empty_tensors, dim)
326
326
 
327
327
 
328
def _unfold_start_indices(batch_shape, dim, step):
  """Builds the gather start indices used by the aten.unfold lowering.

  Args:
    batch_shape: Static shape of the gather batch, i.e. the input shape with
      dimension `dim` replaced by the number of windows.
    dim: The (non-negative) unfolded dimension.
    step: The stride between consecutive window starts.

  Returns:
    An i64 tensor of shape `batch_shape + [len(batch_shape)]` whose entry at
    [b_0, ..., b_{r-1}, :] is [p_0, ..., p_{r-1}], with p_j = b_j for
    j != dim and p_dim = b_dim * step.
  """
  i64 = ir.IntegerType.get_signless(64)
  rank = len(batch_shape)
  index_shape = list(batch_shape) + [1]
  index_type = ir.RankedTensorType.get(index_shape, i64)

  coords = []
  for axis in range(rank):
    # An iota along `axis` over `batch_shape + [1]` is exactly the axis-th
    # coordinate, already broadcast over the other batch dims and carrying
    # the trailing index-vector dim, so no reshape/broadcast is needed.
    coord = stablehlo.IotaOp(
        index_type,
        iota_dimension=ir.IntegerAttr.get(i64, axis),
    ).result
    if axis == dim:
      # Window w along `dim` starts at element w * step.
      coord = stablehlo.multiply(coord, utils.splat(step, i64, index_shape))
    coords.append(coord)

  # Stack the per-axis coordinates along the trailing index-vector dim.
  return stablehlo.concatenate(coords, ir.IntegerAttr.get(i64, rank))


def _unfold_gather(x: ir.Value, dim: int, size: int, step: int):
  """Lowers unfold for an input of rank >= 1 as a single stablehlo.gather.

  Args:
    x: The input tensor value (rank >= 1, static shape).
    dim: The dimension to unfold; negative values count from the end.
    size: The length of each window.
    step: The stride between consecutive window starts.

  Returns:
    A tensor of shape `x_shape` with `dim` replaced by the number of windows,
    followed by a trailing dimension of length `size`.

  Raises:
    IndexError: If `dim` is out of range for `x`.
    NotImplementedError: If `x` has a dynamic dimension.
    ValueError: If `size` or `step` violates torch.Tensor.unfold's checks.
  """
  x_shape = list(x.type.shape)
  rank = len(x_shape)
  if not -rank <= dim < rank:
    raise IndexError(
        f"Dimension out of range (expected to be in range of [{-rank},"
        f" {rank - 1}], but got {dim})"
    )
  if dim < 0:
    dim += rank
  # MLIR encodes dynamic dimensions with a negative sentinel; the iota-based
  # index construction below needs concrete sizes.
  if any(d < 0 for d in x_shape):
    raise NotImplementedError(
        f"aten.unfold lowering requires a static input shape, got {x.type}."
    )
  if step <= 0:
    raise ValueError(f"step is {step} but must be > 0")
  if size < 0:
    raise ValueError(f"size is {size} but must be >= 0")
  if size > x_shape[dim]:
    raise ValueError(
        f"maximum size for tensor at dimension {dim} is {x_shape[dim]} but"
        f" size is {size}"
    )

  num_windows = (x_shape[dim] - size) // step + 1
  batch_shape = x_shape[:dim] + [num_windows] + x_shape[dim + 1 :]
  start_indices = _unfold_start_indices(batch_shape, dim, step)

  # Each gathered slice is one window: length `size` along `dim`, 1 elsewhere.
  slice_sizes = [1] * rank
  slice_sizes[dim] = size

  dnums = stablehlo.GatherDimensionNumbers.get(
      # The window contents become the trailing output dim (index `rank`)...
      offset_dims=[rank],
      # ...and all size-1 slice dims are dropped, so the leading output dims
      # are exactly `batch_shape`.
      collapsed_slice_dims=[i for i in range(rank) if i != dim],
      operand_batching_dims=[],
      start_indices_batching_dims=[],
      start_index_map=list(range(rank)),
      index_vector_dim=rank,
  )

  return stablehlo.gather(
      x,
      start_indices,
      dnums,
      ir.DenseI64ArrayAttr.get(slice_sizes),
      indices_are_sorted=ir.BoolAttr.get(False),
  )


# Schema:
#   - aten::unfold(Tensor self, int dim, int size, int step) -> Tensor
# Torch Reference:
#   - https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
@lower(torch.ops.aten.unfold.default)
def _aten_unfold(lctx, x: ir.Value, dim: int, size: int, step: int):
  """Lowers aten.unfold to a stablehlo.gather of sliding windows.

  The result has the shape of `x` with dimension `dim` replaced by the number
  of windows, plus a trailing dimension of length `size`. Window `w` along
  `dim` starts at element `w * step`.

  Args:
    lctx: The lowering context (unused).
    x: The input tensor value; every dimension must be static.
    dim: The dimension to unfold; negative values count from the end.
    size: The length of each window.
    step: The stride between consecutive window starts.

  Returns:
    The unfolded tensor value.

  Raises:
    IndexError: If `dim` is out of range for `x`.
    NotImplementedError: If `x` has a dynamic dimension.
    ValueError: If `size` or `step` violates torch.Tensor.unfold's checks.
  """
  if len(x.type.shape) != 0:
    return _unfold_gather(x, dim, size, step)

  # torch treats a 0-d tensor as a single element. Only dim 0 or -1 is
  # accepted, and the result is a 1-D tensor of length `size` with no window
  # dimension. Lower it as a 1-element vector, then drop the window dim.
  if dim not in (-1, 0):
    raise IndexError(
        f"Dimension out of range (expected to be in range of [-1, 0], but got"
        f" {dim})"
    )
  element_type = x.type.element_type
  vector = stablehlo.reshape(ir.RankedTensorType.get([1], element_type), x)
  windows = _unfold_gather(vector, 0, size, step)
  return stablehlo.reshape(
      ir.RankedTensorType.get([size], element_type), windows
  )
405
+
406
+
328
407
  # Schema:
329
408
  # - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
330
409
  # start=None, SymInt? end=None, SymInt step=1) -> Tensor
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.7.0.dev20251021"
18
+ __version__ = "0.7.0.dev20251022"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.7.0.dev20251021
3
+ Version: 0.7.0.dev20251022
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
- ai_edge_torch/version.py,sha256=OqJtg7_I5JrjrCcK_c026Q_w0JSHTPq1q3hHLlxAzt8,806
5
+ ai_edge_torch/version.py,sha256=EHjxIt8UozzU1M_dt9HhtAkPXLZ4-nFSjG3rQegVtSo,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -250,7 +250,7 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
250
250
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
251
251
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
252
252
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
253
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=kUbicoJcYZOgYRhydXYIsLyB2lW_Y-39skvjvrqhevo,15031
253
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=HOTYfQWin8tqi1yakIyardxhRViZ6rhLV6ZomMSS7zA,17554
254
254
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
255
255
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
256
256
  ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
@@ -270,8 +270,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
270
270
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
271
271
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
272
272
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
273
- ai_edge_torch_nightly-0.7.0.dev20251021.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
274
- ai_edge_torch_nightly-0.7.0.dev20251021.dist-info/METADATA,sha256=wB295znY-1RX_QXKMB8e-ujDBHQg_u39kqx8Pk8qleg,2074
275
- ai_edge_torch_nightly-0.7.0.dev20251021.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
276
- ai_edge_torch_nightly-0.7.0.dev20251021.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
277
- ai_edge_torch_nightly-0.7.0.dev20251021.dist-info/RECORD,,
273
+ ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
274
+ ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/METADATA,sha256=0QfuVTBKI9hx8RFCAn0FCq5HdeZed6x4rI8A3iROHzA,2074
275
+ ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
276
+ ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
277
+ ai_edge_torch_nightly-0.7.0.dev20251022.dist-info/RECORD,,