ai-edge-torch-nightly 0.3.0.dev20241120__py3-none-any.whl → 0.3.0.dev20241121__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -15,13 +15,17 @@
15
15
  import math
16
16
  from typing import Optional, Union
17
17
 
18
+ from ai_edge_torch.odml_torch import export_utils
19
+ from ai_edge_torch.odml_torch.lowerings import context
20
+ from ai_edge_torch.odml_torch.lowerings import registry
18
21
  from ai_edge_torch.odml_torch.lowerings import utils
19
22
  from jax._src.lib.mlir import ir
20
23
  from jax._src.lib.mlir.dialects import hlo as stablehlo
21
24
  import numpy as np
22
25
  import torch
23
26
 
24
- from .registry import lower
27
+ LoweringContext = context.LoweringContext
28
+ lower = registry.lower
25
29
 
26
30
 
27
31
  # add(Tensor self, Tensor other) -> Tensor
@@ -211,6 +215,31 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
211
215
  return stablehlo.floor(x)
212
216
 
213
217
 
218
+ # Schema:
219
+ # - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
220
+ # Torch Reference:
221
+ # - https://pytorch.org/docs/main/generated/torch.cat.html
222
@lower(torch.ops.aten.cat.default)
def _aten_cat(lctx: LoweringContext, tensors, dim=0):
  """Lowers aten.cat to stablehlo.concatenate.

  Args:
    lctx: The lowering context. The "tensor_meta" entry of its FX node
      supplies the output dtype and shape.
    tensors: Non-empty sequence of ir.Value operands to concatenate.
    dim: Concatenation dimension; negative values count from the end of the
      output rank.

  Returns:
    The concatenated ir.Value. If every input has zero elements, a
    zero-filled splat with the output dtype and shape is returned instead.
  """
  assert tensors
  out_meta = lctx.node.meta["tensor_meta"]
  out_elty = export_utils.torch_dtype_to_ir_element_type(
      lctx.ir_context, out_meta.dtype
  )

  # Inputs with zero elements contribute nothing to the result, so they are
  # dropped before building the concatenate op.
  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
  if not non_empty_tensors:
    return utils.splat(0, out_elty, out_meta.shape)

  # torch.cat promotes mixed input dtypes to a common result dtype (empty
  # inputs included), whereas stablehlo.concatenate requires every operand to
  # share the result's element type. Convert any mismatched operand to the
  # output element type recorded in tensor_meta.
  non_empty_tensors = [
      t
      if t.type.element_type == out_elty
      else stablehlo.convert(
          ir.RankedTensorType.get(t.type.shape, out_elty), t
      )
      for t in non_empty_tensors
  ]

  # Normalize a negative dim against the output rank.
  if dim < 0:
    dim = dim + len(out_meta.shape)
  dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)

  return stablehlo.concatenate(non_empty_tensors, dim)
241
+
242
+
214
243
  # Schema:
215
244
  # - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
216
245
  # start=None, SymInt? end=None, SymInt step=1) -> Tensor
@@ -105,7 +105,6 @@ lower_by_torch_xla2(torch.ops.aten.bitwise_not)
105
105
  lower_by_torch_xla2(torch.ops.aten.bitwise_or)
106
106
  lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
107
107
  lower_by_torch_xla2(torch.ops.aten.bmm)
108
- lower_by_torch_xla2(torch.ops.aten.cat)
109
108
  lower_by_torch_xla2(torch.ops.aten.ceil)
110
109
  lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
111
110
  lower_by_torch_xla2(torch.ops.aten.clamp.default)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241120"
16
+ __version__ = "0.3.0.dev20241121"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241120
3
+ Version: 0.3.0.dev20241121
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=52sF7t2CBQE8RcB2Hcmo-f6_BLyCW9NzWZ-wTKM9ho4,706
6
+ ai_edge_torch/version.py,sha256=6eLmEn5xqmozokHVWP7j-jjFiQlv2a1aZxDucMzXDh8,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -178,10 +178,10 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
178
178
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
179
179
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
180
180
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
181
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=z_hPJX9n97d6obcsS9OHXpKqbmw6QqACXgnq5ML6Rhs,9014
181
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=mxNh20Z4ZeQMu0AAdXnNMXdm2PdAh3RmQPzq2SBpxQs,9954
182
182
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
183
183
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
184
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=1JeX3j7Rt3KE7Z2eYRrhtcYgO3EKnRyZFKAUWXw-bsU,10812
184
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=MY6FFSJKYtD1M1l2q3hDKf3P4NpODqQ4NyWudYe1tTE,10772
185
185
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
186
186
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
187
187
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
@@ -194,8 +194,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
194
194
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
195
195
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
196
196
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
197
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
198
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/METADATA,sha256=1Nv_QeerPRw888sOTf4jHx5Ihu-PJD9rL8GOpRHSTa4,1897
199
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
200
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
201
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/RECORD,,
197
+ ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
198
+ ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/METADATA,sha256=AJUg6jkWACMXVy7gopyMvlD0aJfw1BVnkZKbGS9cXX0,1897
199
+ ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
200
+ ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
201
+ ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/RECORD,,