ai-edge-torch-nightly 0.3.0.dev20241210__py3-none-any.whl → 0.3.0.dev20241211__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -15,14 +15,14 @@
15
15
 
16
16
 
17
17
  import ast
18
- import io
19
- import sys
20
18
 
21
- from ai_edge_torch.debug import find_culprits
19
+ import ai_edge_torch.debug
22
20
  import torch
23
21
 
24
22
  from absl.testing import absltest as googletest
25
23
 
24
+ find_culprits = ai_edge_torch.debug.find_culprits
25
+
26
26
  _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
27
27
 
28
28
  _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -52,6 +52,11 @@ class BadModel(torch.nn.Module):
52
52
 
53
53
  class TestCulprit(googletest.TestCase):
54
54
 
55
+ def setUp(self):
56
+ super().setUp()
57
+ torch.manual_seed(0)
58
+ torch._dynamo.reset()
59
+
55
60
  def test_find_culprits(self):
56
61
  model = BadModel().eval()
57
62
  args = (torch.rand(10),)
@@ -151,7 +151,7 @@ class TestModelConversion(googletest.TestCase):
151
151
  )
152
152
  def test_openelm(self):
153
153
  config = openelm.get_fake_model_config()
154
- pytorch_model = openelm.OpenElm(config).eval()
154
+ pytorch_model = openelm.OpenELM(config).eval()
155
155
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
156
156
 
157
157
  @googletest.skipIf(
@@ -21,6 +21,6 @@ from . import _quantized_decomposed
21
21
  from . import context
22
22
  from . import registry
23
23
  from . import utils
24
- from .registry import decompositions
24
+ from .decomp import decompositions
25
25
  from .registry import lookup
26
26
  from .registry import lower
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch export decompositions to run before lowering."""
16
+
17
+ import functools
18
+
19
+ import torch
20
+
21
+
22
+ @functools.cache
23
+ def decompositions():
24
+ # Base: Core ATen decompositions
25
+ decompositions = torch._decomp.core_aten_decompositions()
26
+
27
+ decompositions.update(
28
+ torch._decomp.get_decompositions([
29
+ torch.ops.aten.upsample_nearest2d,
30
+ torch.ops.aten._native_batch_norm_legit.no_stats,
31
+ torch.ops.aten._native_batch_norm_legit_functional,
32
+ torch.ops.aten._adaptive_avg_pool2d,
33
+ torch.ops.aten._adaptive_avg_pool3d,
34
+ torch.ops.aten.grid_sampler_2d,
35
+ torch.ops.aten.native_group_norm,
36
+ torch.ops.aten.native_dropout,
37
+ torch.ops.aten.reflection_pad1d,
38
+ torch.ops.aten.reflection_pad2d,
39
+ torch.ops.aten.reflection_pad3d,
40
+ torch.ops.aten.replication_pad1d,
41
+ torch.ops.aten.replication_pad2d,
42
+ torch.ops.aten.replication_pad3d,
43
+ torch.ops.aten.addmm,
44
+ ])
45
+ )
46
+
47
+ torch._decomp.remove_decompositions(
48
+ decompositions,
49
+ [torch.ops.aten.roll],
50
+ )
51
+
52
+ # Override _safe_softmax decompositions with regular softmax.
53
+ # _safe_softmax introduces additional check-select ops to guard extreme
54
+ # input values to softmax, which could make the converted model inefficient
55
+ # on-device.
56
+ if hasattr(torch.ops.aten, "_safe_softmax"):
57
+ decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
58
+
59
+ return decompositions
@@ -26,7 +26,6 @@ class LoweringRegistry:
26
26
 
27
27
  def __init__(self):
28
28
  self.registered_ops = {}
29
- self.decompositions = {}
30
29
 
31
30
  def lookup(self, op_or_name):
32
31
  candidate = self._get_lowering(op_or_name)
@@ -52,33 +51,6 @@ class LoweringRegistry:
52
51
 
53
52
 
54
53
  global_registry = LoweringRegistry()
55
- global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
56
- global_registry.decompositions.update(
57
- torch._decomp.get_decompositions([
58
- torch.ops.aten.upsample_nearest2d,
59
- torch.ops.aten._native_batch_norm_legit.no_stats,
60
- torch.ops.aten._native_batch_norm_legit_functional,
61
- torch.ops.aten._adaptive_avg_pool2d,
62
- torch.ops.aten._adaptive_avg_pool3d,
63
- torch.ops.aten.grid_sampler_2d,
64
- torch.ops.aten.native_group_norm,
65
- torch.ops.aten.native_dropout,
66
- torch.ops.aten.reflection_pad1d,
67
- torch.ops.aten.reflection_pad2d,
68
- torch.ops.aten.reflection_pad3d,
69
- torch.ops.aten.replication_pad1d,
70
- torch.ops.aten.replication_pad2d,
71
- torch.ops.aten.replication_pad3d,
72
- torch.ops.aten.addmm,
73
- ])
74
- )
75
-
76
- torch._decomp.remove_decompositions(
77
- global_registry.decompositions,
78
- [
79
- torch.ops.aten.roll,
80
- ],
81
- )
82
54
 
83
55
 
84
56
  def lookup(op):
@@ -91,7 +63,3 @@ def lower(op):
91
63
  return lowering
92
64
 
93
65
  return inner
94
-
95
-
96
- def decompositions():
97
- return global_registry.decompositions
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241210"
16
+ __version__ = "0.3.0.dev20241211"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241210
3
+ Version: 0.3.0.dev20241211
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=AYxcupivW-iYIlCjWXl-QtEvpRsQqFcNK9I6uyGDqaU,706
6
+ ai_edge_torch/version.py,sha256=_uS2Df0H-aUbz-7M-gLxfjDVOJxr03EeNDfbVC_cBrE,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -34,7 +34,7 @@ ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyu
34
34
  ai_edge_torch/debug/culprit.py,sha256=7UYVpVWpiCXbMAyThVtHt_kc_poT7sCTh5UUPvcycgk,14832
35
35
  ai_edge_torch/debug/utils.py,sha256=vOAL4t6Lj47uhKapfEsc_WHmvwew3eKO9hSJyzvPXnU,1625
36
36
  ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
37
- ai_edge_torch/debug/test/test_culprit.py,sha256=SLX4rC-5Dlna8MWHhGRNe72K71AHTFufDrWLlFQn50c,3773
37
+ ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
38
38
  ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
39
39
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
40
40
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -140,7 +140,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e
140
140
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
141
141
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
142
142
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=i3tQ6mEAo9lCctNoqFAnULk94hgKncC4ywn8IvgbUOo,6341
143
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IBuvXvORtHu3khr3mLJzYXyCd-zQLUdURTfH28Oo9e0,11079
143
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=MNxJHxDOx8SXi5dJ--JH7KnpJUgtUjwfvCm7X_rnNzA,11079
144
144
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
145
145
  ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
146
146
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -182,7 +182,7 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
182
182
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
183
183
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
184
184
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
185
- ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=1lMKPoStK3SUA8yYTPZBRhESN33BghGXnfqOOg4oeVk,995
185
+ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
186
186
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
187
187
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
188
188
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
@@ -190,7 +190,8 @@ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5ly
190
190
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
191
191
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
192
192
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
193
- ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
193
+ ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=aR6JPFP2Iq-aR0qPxJEHehmAVTjiGhgQEoycZV_1vPY,2130
194
+ ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
194
195
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
195
196
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
196
197
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -200,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
200
201
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
201
202
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
202
203
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
203
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
204
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/METADATA,sha256=SM6aXiKe6YYFKtS0NbSZwwYIdZES74y0X7wautX45S4,1897
205
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
206
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
207
- ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/RECORD,,
204
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
205
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/METADATA,sha256=Lyub5vadYf6Yu6mGY7l1PFk8Jg2rB36ojIBHm9CxhBM,1897
206
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
207
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
208
+ ai_edge_torch_nightly-0.3.0.dev20241211.dist-info/RECORD,,