ai-edge-torch-nightly 0.3.0.dev20241210__py3-none-any.whl → 0.3.0.dev20241211__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,14 +15,14 @@
15
15
 
16
16
 
17
17
  import ast
18
- import io
19
- import sys
20
18
 
21
- from ai_edge_torch.debug import find_culprits
19
+ import ai_edge_torch.debug
22
20
  import torch
23
21
 
24
22
  from absl.testing import absltest as googletest
25
23
 
24
+ find_culprits = ai_edge_torch.debug.find_culprits
25
+
26
26
  _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
27
27
 
28
28
  _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -52,6 +52,11 @@ class BadModel(torch.nn.Module):
52
52
 
53
53
  class TestCulprit(googletest.TestCase):
54
54
 
55
+ def setUp(self):
56
+ super().setUp()
57
+ torch.manual_seed(0)
58
+ torch._dynamo.reset()
59
+
55
60
  def test_find_culprits(self):
56
61
  model = BadModel().eval()
57
62
  args = (torch.rand(10),)
@@ -151,7 +151,7 @@ class TestModelConversion(googletest.TestCase):
151
151
  )
152
152
  def test_openelm(self):
153
153
  config = openelm.get_fake_model_config()
154
- pytorch_model = openelm.OpenElm(config).eval()
154
+ pytorch_model = openelm.OpenELM(config).eval()
155
155
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
156
156
 
157
157
  @googletest.skipIf(
@@ -21,6 +21,6 @@ from . import _quantized_decomposed
21
21
  from . import context
22
22
  from . import registry
23
23
  from . import utils
24
- from .registry import decompositions
24
+ from .decomp import decompositions
25
25
  from .registry import lookup
26
26
  from .registry import lower
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch export decompositions to run before lowering."""
16
+
17
+ import functools
18
+
19
+ import torch
20
+
21
+
22
+ @functools.cache
23
+ def decompositions():
24
+ # Base: Core ATen decompositions
25
+ decompositions = torch._decomp.core_aten_decompositions()
26
+
27
+ decompositions.update(
28
+ torch._decomp.get_decompositions([
29
+ torch.ops.aten.upsample_nearest2d,
30
+ torch.ops.aten._native_batch_norm_legit.no_stats,
31
+ torch.ops.aten._native_batch_norm_legit_functional,
32
+ torch.ops.aten._adaptive_avg_pool2d,
33
+ torch.ops.aten._adaptive_avg_pool3d,
34
+ torch.ops.aten.grid_sampler_2d,
35
+ torch.ops.aten.native_group_norm,
36
+ torch.ops.aten.native_dropout,
37
+ torch.ops.aten.reflection_pad1d,
38
+ torch.ops.aten.reflection_pad2d,
39
+ torch.ops.aten.reflection_pad3d,
40
+ torch.ops.aten.replication_pad1d,
41
+ torch.ops.aten.replication_pad2d,
42
+ torch.ops.aten.replication_pad3d,
43
+ torch.ops.aten.addmm,
44
+ ])
45
+ )
46
+
47
+ torch._decomp.remove_decompositions(
48
+ decompositions,
49
+ [torch.ops.aten.roll],
50
+ )
51
+
52
+ # Override _safe_softmax decompositions with regular softmax.
53
+ # _safe_softmax introduces additional check-select ops to guard extreme
54
+ # input values to softmax, which could make the converted model inefficient
55
+ # on-device.
56
+ if hasattr(torch.ops.aten, "_safe_softmax"):
57
+ decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
58
+
59
+ return decompositions
@@ -26,7 +26,6 @@ class LoweringRegistry:
26
26
 
27
27
  def __init__(self):
28
28
  self.registered_ops = {}
29
- self.decompositions = {}
30
29
 
31
30
  def lookup(self, op_or_name):
32
31
  candidate = self._get_lowering(op_or_name)
@@ -52,33 +51,6 @@ class LoweringRegistry:
52
51
 
53
52
 
54
53
  global_registry = LoweringRegistry()
55
- global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
56
- global_registry.decompositions.update(
57
- torch._decomp.get_decompositions([
58
- torch.ops.aten.upsample_nearest2d,
59
- torch.ops.aten._native_batch_norm_legit.no_stats,
60
- torch.ops.aten._native_batch_norm_legit_functional,
61
- torch.ops.aten._adaptive_avg_pool2d,
62
- torch.ops.aten._adaptive_avg_pool3d,
63
- torch.ops.aten.grid_sampler_2d,
64
- torch.ops.aten.native_group_norm,
65
- torch.ops.aten.native_dropout,
66
- torch.ops.aten.reflection_pad1d,
67
- torch.ops.aten.reflection_pad2d,
68
- torch.ops.aten.reflection_pad3d,
69
- torch.ops.aten.replication_pad1d,
70
- torch.ops.aten.replication_pad2d,
71
- torch.ops.aten.replication_pad3d,
72
- torch.ops.aten.addmm,
73
- ])
74
- )
75
-
76
- torch._decomp.remove_decompositions(
77
- global_registry.decompositions,
78
- [
79
- torch.ops.aten.roll,
80
- ],
81
- )
82
54
 
83
55
 
84
56
  def lookup(op):
@@ -91,7 +63,3 @@ def lower(op):
91
63
  return lowering
92
64
 
93
65
  return inner
94
-
95
-
96
- def decompositions():
97
- return global_registry.decompositions
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241210"
16
+ __version__ = "0.3.0.dev20241211"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241210
3
+ Version: 0.3.0.dev20241211
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=AYxcupivW-iYIlCjWXl-QtEvpRsQqFcNK9I6uyGDqaU,706
6
+ ai_edge_torch/version.py,sha256=_uS2Df0H-aUbz-7M-gLxfjDVOJxr03EeNDfbVC_cBrE,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -34,7 +34,7 @@ ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyu
34
34
  ai_edge_torch/debug/culprit.py,sha256=7UYVpVWpiCXbMAyThVtHt_kc_poT7sCTh5UUPvcycgk,14832
35
35
  ai_edge_torch/debug/utils.py,sha256=vOAL4t6Lj47uhKapfEsc_WHmvwew3eKO9hSJyzvPXnU,1625
36
36
  ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
37
- ai_edge_torch/debug/test/test_culprit.py,sha256=SLX4rC-5Dlna8MWHhGRNe72K71AHTFufDrWLlFQn50c,3773
37
+ ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
38
38
  ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
39
39
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
40
40
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -140,7 +140,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e
140
140
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
141
141
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
142
142
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=i3tQ6mEAo9lCctNoqFAnULk94hgKncC4ywn8IvgbUOo,6341
143
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IBuvXvORtHu3khr3mLJzYXyCd-zQLUdURTfH28Oo9e0,11079
143
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=MNxJHxDOx8SXi5dJ--JH7KnpJUgtUjwfvCm7X_rnNzA,11079
144
144
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
145
145
  ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
146
146
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -182,7 +182,7 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
182
182
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
183
183
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
184
184
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
185
- ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=1lMKPoStK3SUA8yYTPZBRhESN33BghGXnfqOOg4oeVk,995
185
+ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
186
186
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
187
187
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
188
188
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
@@ -190,7 +190,8 @@ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5ly
190
190
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
191
191
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
192
192
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
193
- ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
193
+ ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=aR6JPFP2Iq-aR0qPxJEHehmAVTjiGhgQEoycZV_1vPY,2130
194
+ ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
194
195
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
195
196
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
196
197
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -200,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
200
201
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
201
202
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
202
203
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
203
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
204
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/METADATA,sha256=SM6aXiKe6YYFKtS0NbSZwwYIdZES74y0X7wautX45S4,1897
205
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
206
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
207
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/RECORD,,
204
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
205
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/METADATA,sha256=Lyub5vadYf6Yu6mGY7l1PFk8Jg2rB36ojIBHm9CxhBM,1897
206
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
207
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
208
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/RECORD,,