ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240731__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

@@ -223,21 +223,25 @@ def _aten_embedding(gm: GraphModule, node: Node):
223
223
  full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
224
224
  _, embedding_dim = full_kwargs["weight"].size()
225
225
  idx = full_kwargs["indices"]
226
- idx = idx.type(torch.int)
227
- B, T = idx.size()
228
-
229
- idx = torch.reshape(idx, (B * T,))
230
-
231
- builder = StableHLOCompositeBuilder("odml.embedding_lookup")
232
- full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
233
- idx,
234
- full_kwargs["weight"],
235
- )
236
- output = op(**full_kwargs)
237
- output = builder.mark_outputs(output)
238
-
239
- output = torch.reshape(output, (B, T, embedding_dim))
240
- return output
226
+ # TODO(b/356458830): Handle relative positional encoding
227
+ if len(idx.size()) == 2:
228
+ idx = idx.type(torch.int)
229
+ B, T = idx.size()
230
+
231
+ idx = torch.reshape(idx, (B * T,))
232
+
233
+ builder = StableHLOCompositeBuilder("odml.embedding_lookup")
234
+ full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
235
+ idx,
236
+ full_kwargs["weight"],
237
+ )
238
+ output = op(**full_kwargs)
239
+ output = builder.mark_outputs(output)
240
+
241
+ output = torch.reshape(output, (B, T, embedding_dim))
242
+ return output
243
+ else:
244
+ return op(**full_kwargs)
241
245
 
242
246
  node.target = embedding
243
247
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240730
3
+ Version: 0.2.0.dev20240731
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -7,7 +7,7 @@ ai_edge_torch/convert/converter.py,sha256=hSrW6A-kix9cjdD6CuLL7rseWrLKoV6GRy-iUS
7
7
  ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
8
8
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
9
9
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
10
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=ouV1CD_t5-MpDgr-7_zUG6vPrRYDT3-YWq81oZqCi9M,7924
10
+ ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=NwJp8GbZWRtQML7hG2q5sS1W92RKiRsgDDdMLy4uIBc,8079
11
11
  ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vcycd9f3OQgQLx2hhQjsKfOqdxE5EkjzqrxqyAQM,4168
12
12
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
13
13
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
125
125
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
126
126
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
127
127
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
128
- ai_edge_torch_nightly-0.2.0.dev20240730.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
- ai_edge_torch_nightly-0.2.0.dev20240730.dist-info/METADATA,sha256=f9fGXKPiQpY75w77AU-rJxkSdq6WJPJ9Jvu71s7IIhk,1889
130
- ai_edge_torch_nightly-0.2.0.dev20240730.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
- ai_edge_torch_nightly-0.2.0.dev20240730.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
- ai_edge_torch_nightly-0.2.0.dev20240730.dist-info/RECORD,,
128
+ ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
+ ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/METADATA,sha256=B2Nf7g2PWOU-bYTAByfDNV_FAKy3ah88O-Plsk-uW_M,1889
130
+ ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
+ ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
+ ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/RECORD,,