ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.2.0.dev20240717__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the package's registry page for more details.

@@ -213,6 +213,35 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
213
213
  node.target = avg_pool2d
214
214
 
215
215
 
216
@_register_composite_builder(torch.ops.aten.embedding.default)
def _aten_embedding(gm: GraphModule, node: Node):
    """Rewrite an ``aten.embedding`` node to emit an ``odml.embedding_lookup`` composite.

    The node's target is replaced by a wrapper that casts the indices to
    ``torch.int``, flattens them to 1-D, and wraps the embedding call in an
    ``odml.embedding_lookup`` StableHLO composite. The result is then reshaped
    back to ``(*indices.shape, embedding_dim)``.

    Args:
      gm: The graph module containing ``node`` (unused here, kept for the
        composite-builder registration signature).
      node: The ``aten.embedding.default`` node whose target is rewritten in
        place.
    """
    op = node.target
    args_mapper = TorchOpArgumentsMapper(op)

    def embedding(*args, **kwargs):
        nonlocal op, args_mapper
        full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
        _, embedding_dim = full_kwargs["weight"].size()
        idx = full_kwargs["indices"]
        idx = idx.type(torch.int)

        # Embedding lookups accept indices of any rank, not only (B, T).
        # Remember the caller's shape so the output can be restored to
        # (*index_shape, embedding_dim) after the flat lookup.
        index_shape = tuple(idx.size())

        # Flatten to a single lookup axis for the composite. numel() also
        # covers 0-d indices (one element) and matches B * T for 2-D input.
        idx = torch.reshape(idx, (idx.numel(),))

        builder = StableHLOCompositeBuilder("odml.embedding_lookup")
        full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
            idx,
            full_kwargs["weight"],
        )
        output = op(**full_kwargs)
        output = builder.mark_outputs(output)

        output = torch.reshape(output, (*index_shape, embedding_dim))
        return output

    node.target = embedding
243
+
244
+
216
245
  class BuildAtenCompositePass(PassBase):
217
246
 
218
247
  def call(self, graph_module: GraphModule):
@@ -187,6 +187,15 @@ class TestConvertComposites(unittest.TestCase):
187
187
 
188
188
  self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
189
189
 
190
def test_convert_embedding_lookup(self):
    """Tests conversion of an Embedding module.

    The indices cover every row of the 10-row table. An all-zeros input would
    only read row 0 and could not detect a wrong gather in the converted model.
    """
    args = (torch.arange(10, dtype=torch.long).unsqueeze(0),)  # shape (1, 10)
    torch_module = torch.nn.Embedding(10, 10)
    edge_model = ai_edge_torch.convert(torch_module, args)

    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
198
+
190
199
 
191
200
  if __name__ == '__main__':
192
201
  unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240714
3
+ Version: 0.2.0.dev20240717
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -7,7 +7,7 @@ ai_edge_torch/convert/converter.py,sha256=hSrW6A-kix9cjdD6CuLL7rseWrLKoV6GRy-iUS
7
7
  ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
8
8
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
9
9
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
10
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=2yqUwJJ2R233_X9FNMOP9oYRTTzH34TR_BIUj-wfnKw,7080
10
+ ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=ouV1CD_t5-MpDgr-7_zUG6vPrRYDT3-YWq81oZqCi9M,7924
11
11
  ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=6m_vcycd9f3OQgQLx2hhQjsKfOqdxE5EkjzqrxqyAQM,4168
12
12
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
13
13
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
@@ -23,7 +23,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
23
23
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
24
24
  ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
25
25
  ai_edge_torch/convert/test/test_convert.py,sha256=h0vOffr8saDQRkiXljNWDZ17EBjnS4xAtxd8DxETleY,9081
26
- ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
26
+ ai_edge_torch/convert/test/test_convert_composites.py,sha256=8UkdPtGkjgSVLCzB_rpM2FmwYuMyt6WE48umX_kr_Sg,7601
27
27
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
28
28
  ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
29
29
  ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
@@ -114,8 +114,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
114
114
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
115
115
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
116
116
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
117
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA,sha256=dqsy1sknQkrtEaIcULIkYjDs64qnYuYZA8e5smN_3JU,1745
119
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD,,
117
+ ai_edge_torch_nightly-0.2.0.dev20240717.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
+ ai_edge_torch_nightly-0.2.0.dev20240717.dist-info/METADATA,sha256=K97HftSRap5QsdWC9otDhPNk0Ueo7HhjebEr0UMkrPM,1745
119
+ ai_edge_torch_nightly-0.2.0.dev20240717.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
+ ai_edge_torch_nightly-0.2.0.dev20240717.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
+ ai_edge_torch_nightly-0.2.0.dev20240717.dist-info/RECORD,,