ai-edge-torch-nightly 0.2.0.dev20240802-py3-none-any.whl → 0.2.0.dev20240805-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the package registry's release page for more details.

ai_edge_torch/__init__.py CHANGED
@@ -17,6 +17,7 @@ from .convert.converter import convert
17
17
  from .convert.converter import signature
18
18
  from .convert.to_channel_last_io import to_channel_last_io
19
19
  from .model import Model
20
+ from .version import __version__
20
21
 
21
22
 
22
23
  def load(path: str) -> Model:
ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py CHANGED
@@ -15,6 +15,7 @@
15
15
 
16
16
  import copy
17
17
  import functools
18
+ from functools import reduce
18
19
  from typing import Any, Callable
19
20
 
20
21
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
@@ -228,25 +229,25 @@ def _aten_embedding(gm: GraphModule, node: Node):
228
229
  full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
229
230
  _, embedding_dim = full_kwargs["weight"].size()
230
231
  idx = full_kwargs["indices"]
231
- # TODO(b/356458830): Handle relative positional encoding
232
- if len(idx.size()) == 2:
233
- idx = idx.type(torch.int)
234
- B, T = idx.size()
235
-
236
- idx = torch.reshape(idx, (B * T,))
237
-
238
- builder = StableHLOCompositeBuilder("odml.embedding_lookup")
239
- full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
240
- idx,
241
- full_kwargs["weight"],
242
- )
243
- output = op(**full_kwargs)
244
- output = builder.mark_outputs(output)
245
-
246
- output = torch.reshape(output, (B, T, embedding_dim))
247
- return output
248
- else:
249
- return op(**full_kwargs)
232
+
233
+ # Explicitly cast to INT32. This places the CastOp outside of the HLFB.
234
+ idx = idx.type(torch.int)
235
+ original_idx_shape = idx.size()
236
+
237
+ # Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
238
+ idx = torch.reshape(idx, (idx.numel(),))
239
+
240
+ builder = StableHLOCompositeBuilder("odml.embedding_lookup")
241
+ full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
242
+ idx,
243
+ full_kwargs["weight"],
244
+ )
245
+ output = op(**full_kwargs)
246
+ output = builder.mark_outputs(output)
247
+
248
+ # Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
249
+ output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
250
+ return output
250
251
 
251
252
  node.target = embedding
252
253
 
ai_edge_torch/version.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ __version__ = "0.2.0.dev20240805"
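{ai_edge_torch_nightly-0.2.0.dev20240802.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@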
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240802
3
+ Version: 0.2.0.dev20240805
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -27,10 +27,10 @@ Requires-Dist: numpy
27
27
  Requires-Dist: scipy
28
28
  Requires-Dist: safetensors
29
29
  Requires-Dist: tabulate
30
- Requires-Dist: torch >=2.4.0
31
- Requires-Dist: torch-xla >=2.4.0
32
- Requires-Dist: tf-nightly >=2.18.0.dev20240722
33
- Requires-Dist: ai-edge-quantizer-nightly ==0.0.1.dev20240718
30
+ Requires-Dist: torch>=2.4.0
31
+ Requires-Dist: torch-xla>=2.4.0
32
+ Requires-Dist: tf-nightly>=2.18.0.dev20240722
33
+ Requires-Dist: ai-edge-quantizer-nightly==0.0.1.dev20240718
34
34
 
35
35
  Library that supports converting PyTorch models into a .tflite format, which can
36
36
  then be run with TensorFlow Lite and MediaPipe. This enables applications for
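{ai_edge_torch_nightly-0.2.0.dev20240802.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/RECORD RENAMED
@@ -1,5 +1,6 @@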
1
- ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,1067
1
+ ai_edge_torch/__init__.py,sha256=WTuorXzCALfr89FC4kX_PBtKOQLipN1hcW2tMDSQW9w,1100
2
2
  ai_edge_torch/model.py,sha256=pSyY9O7J1i-SJu7g4mFD853MJBNFE6LSzBgJw7dtWuI,4494
3
+ ai_edge_torch/version.py,sha256=v9FIJo70Ip9rWQjkZBBntgskfWC49tED7nTExP6nEsI,706
3
4
  ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
4
5
  ai_edge_torch/convert/conversion.py,sha256=bkOyaTTZR9lT1VJMxwCSjcplheYv1HNSwt8A9kEo388,4183
5
6
  ai_edge_torch/convert/conversion_utils.py,sha256=GAOFepARe_vxOaetplMBBaexxojSijJzXvkxft88-Lc,13945
@@ -7,7 +8,7 @@ ai_edge_torch/convert/converter.py,sha256=6BoHl_GEIOkTr1oBg-VzZb5tr6Rv9yDwxKczYd
7
8
  ai_edge_torch/convert/to_channel_last_io.py,sha256=b7Q0_6Lam6IV-3TyhabVTMS7j0ppFpKDOIHTNAw2PnI,2814
8
9
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=D4Xe8YmeP2N0yEN_bc7pEJH47KkwGFf4COZOILmDL4w,2809
9
10
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=WVYZuocpygHAzk9u1GNoGowAIOHTlJXyA_NklmYkRms,1672
10
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=DGoKclQuxjZChGJCxKs-07zufcFxrIzKBS7Ymi-lPiQ,8079
11
+ ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=QaZ5JV7RazGbC2Khdai795vlO5jDc3yhgx3HHNmzHDs,8246
11
12
  ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=BWSU9nkD5DzxHI_WGcs9uH6qKWCw0XB2etDEV6PsZkg,4181
12
13
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=eW0Yae2cL2ALYVkhsuk3wX8v41P6bkGaABtRgdPCdxk,1672
13
14
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
@@ -125,8 +126,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=yP93mRbsB03K1_dYCRIKgxRNEP4EJOYF68
125
126
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
126
127
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
127
128
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=049yZFfnlVefQJAXkcn84ETzVneaZIlz8e0X1BW3vvI,4520
128
- ai_edge_torch_nightly-0.2.0.dev20240802.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
- ai_edge_torch_nightly-0.2.0.dev20240802.dist-info/METADATA,sha256=Bng_BviZH6NODVQolxehLCzUIDv6i6cVDB5Ddfj-uhc,1889
130
- ai_edge_torch_nightly-0.2.0.dev20240802.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
- ai_edge_torch_nightly-0.2.0.dev20240802.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
- ai_edge_torch_nightly-0.2.0.dev20240802.dist-info/RECORD,,
129
+ ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
130
+ ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/METADATA,sha256=swe019N7yzZ_OlniDSwL84aOHGBv2YGqBXnRi34JhDg,1885
131
+ ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
132
+ ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
133
+ ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD,,
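{ai_edge_torch_nightly-0.2.0.dev20240802.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@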
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.43.0)
2
+ Generator: bdist_wheel (0.44.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5