ai-edge-torch-nightly 0.8.0.dev20251206__py3-none-any.whl → 0.8.0.dev20251225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic; see the package's registry page for more details.
- ai_edge_torch/fx_infra/__init__.py +1 -0
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +18 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +94 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.8.0.dev20251206.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.8.0.dev20251206.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.8.0.dev20251206.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.8.0.dev20251206.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/licenses/LICENSE +0 -0
- {ai_edge_torch_nightly-0.8.0.dev20251206.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/top_level.txt +0 -0
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
|
|
16
16
|
import operator
|
|
17
|
+
from typing import Any, Callable
|
|
17
18
|
import torch
|
|
18
19
|
|
|
19
20
|
|
|
@@ -59,6 +60,15 @@ def _require_decomp(
|
|
|
59
60
|
return False
|
|
60
61
|
|
|
61
62
|
|
|
63
|
+
_FORCE_DECOMP_ATTR = "_ai_edge_torch_force_decomp"


def annotate_force_decomp(decomp: Callable[..., Any]):
  """Tags a decomposition so safe_run_decompositions always applies it.

  The tag is an attribute named by ``_FORCE_DECOMP_ATTR`` set on ``decomp``
  itself. safe_run_decompositions only tests for the attribute's presence, so
  its value is irrelevant; the attribute name is stored as the value.

  Args:
    decomp: Decomposition callable to tag; it is mutated in place.

  Returns:
    ``decomp`` itself, so this function is usable as a decorator.
  """
  marker = _FORCE_DECOMP_ATTR
  setattr(decomp, marker, marker)
  return decomp
|
|
70
|
+
|
|
71
|
+
|
|
62
72
|
def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
|
|
63
73
|
"""Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
|
|
64
74
|
|
|
@@ -79,6 +89,14 @@ def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
|
|
|
79
89
|
# back to one aten.view.
|
|
80
90
|
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
|
81
91
|
|
|
92
|
+
# Torch may skip some decompositions even if target is in decomp_table.
|
|
93
|
+
# The following ensures the target is always run through the decompositions
|
|
94
|
+
# shallowly if it has _FORCE_DECOMP_ATTR.
|
|
95
|
+
if decomp_table and node.target in decomp_table:
|
|
96
|
+
decomp = decomp_table[node.target]
|
|
97
|
+
if hasattr(decomp, _FORCE_DECOMP_ATTR):
|
|
98
|
+
node.target = decomp
|
|
99
|
+
|
|
82
100
|
exported_program = exported_program.run_decompositions(decomp_table)
|
|
83
101
|
|
|
84
102
|
if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
|
|
@@ -14,13 +14,72 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Torch export decompositions to run before lowering."""
|
|
16
16
|
|
|
17
|
+
import functools
|
|
17
18
|
from ai_edge_torch import fx_infra
|
|
18
19
|
import torch
|
|
19
20
|
|
|
20
21
|
|
|
22
|
+
# Fork from pytorch/torch/_decomp/decompositions.py
|
|
23
|
+
def upsample_compute_output_size(input_size, output_size, scale_factors):
|
|
24
|
+
spatial_dimensions = len(input_size) - 2
|
|
25
|
+
if output_size is not None:
|
|
26
|
+
torch._check(
|
|
27
|
+
scale_factors is None,
|
|
28
|
+
lambda: "Must specify exactly one of output_size and scale_factors",
|
|
29
|
+
)
|
|
30
|
+
torch._check(len(output_size) == spatial_dimensions, lambda: "")
|
|
31
|
+
return output_size
|
|
32
|
+
if scale_factors is not None:
|
|
33
|
+
# NB: this isn't necessary lol
|
|
34
|
+
torch._check(
|
|
35
|
+
output_size is None,
|
|
36
|
+
lambda: "Must specify exactly one of output_size and scale_factors",
|
|
37
|
+
)
|
|
38
|
+
torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
|
|
39
|
+
output_size = []
|
|
40
|
+
for i, s in enumerate(scale_factors):
|
|
41
|
+
if int(s) == s:
|
|
42
|
+
output_size.append(input_size[i + 2] * int(s))
|
|
43
|
+
else:
|
|
44
|
+
output_size.append(torch.sym_int(input_size[i + 2] * s))
|
|
45
|
+
return output_size
|
|
46
|
+
torch._check(
|
|
47
|
+
False, lambda: "Must specify exactly one of output_size and scale_factors"
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Fork from pytorch/torch/_decomp/decompositions.py
|
|
52
|
+
def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
  """Builds per-spatial-dim source indices for nearest-neighbor upsampling.

  For spatial dim ``d``, output position ``o`` reads source position
  ``int((o + offset) * ratio)``. Here ``offset`` is 0.5 when ``exact`` and
  0.0 otherwise. ``ratio`` is ``isize / (isize * scales[d])`` when a scale is
  given, else ``isize / osize``. Each index tensor receives trailing
  singleton dims so the set of tensors broadcasts together when used for
  advanced indexing.

  Args:
    input: Tensor whose trailing ``len(output_size)`` dims are spatial.
    output_size: Spatial output sizes.
    scales: Per-spatial-dim scale factors, indexed in step with
      ``output_size``; a None entry means "derive from the sizes".
    exact: Whether to apply the 0.5 offset.

  Returns:
    A tuple of int64 index tensors, one per spatial dim.
  """
  rank = len(output_size)
  shift = 0.5 if exact else 0.0
  result = []

  for dim in range(rank):
    out_len = output_size[dim]
    in_len = input.shape[dim - rank]
    factor = scales[dim]
    if factor is None:
      ratio = in_len / out_len
    else:
      ratio = in_len / (in_len * factor)

    positions = torch.arange(out_len, dtype=torch.float32, device=input.device)
    src = ((positions + shift) * ratio).to(torch.int64)

    # Append one singleton dim per later spatial dim so the indices broadcast.
    trailing = rank - 1 - dim
    while trailing > 0:
      src = src.unsqueeze(-1)
      trailing -= 1
    result.append(src)

  return tuple(result)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Fork from pytorch/torch/_decomp/decompositions.py
|
|
75
|
+
def _upsample_nearest2d_common(input, h_indices, w_indices):
  """Gathers ``input`` at the given row/column indices of dims 2 and 3.

  The first two dims are taken whole (None index). The gathered tensor is
  returned as a contiguous copy.
  """
  gathered = torch.ops.aten.index(input, (None, None, h_indices, w_indices))
  return gathered.contiguous()
|
|
79
|
+
|
|
80
|
+
|
|
21
81
|
fx_infra.decomp.update_pre_lower_decomp(
|
|
22
82
|
torch._decomp.get_decompositions([
|
|
23
|
-
torch.ops.aten.upsample_nearest2d,
|
|
24
83
|
torch.ops.aten._native_batch_norm_legit.no_stats,
|
|
25
84
|
torch.ops.aten._native_batch_norm_legit_functional,
|
|
26
85
|
torch.ops.aten._adaptive_avg_pool2d,
|
|
@@ -35,11 +94,44 @@ fx_infra.decomp.update_pre_lower_decomp(
|
|
|
35
94
|
torch.ops.aten.replication_pad2d,
|
|
36
95
|
torch.ops.aten.replication_pad3d,
|
|
37
96
|
torch.ops.aten.upsample_bilinear2d.vec,
|
|
38
|
-
torch.ops.aten.upsample_nearest2d.vec,
|
|
39
97
|
torch.ops.aten.addmm,
|
|
40
98
|
])
|
|
41
99
|
)
|
|
42
100
|
|
|
101
|
+
|
|
102
|
+
@functools.partial(
    fx_infra.decomp.add_pre_lower_decomp,
    torch.ops.aten.upsample_nearest2d.default,
)
@fx_infra.annotate_force_decomp
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
  """Decomposes aten.upsample_nearest2d.default into an index gather.

  The function is registered as a pre-lowering decomposition. It is also
  tagged with ``annotate_force_decomp`` so that safe_run_decompositions
  applies it even where torch would otherwise skip the decomposition.
  """
  rows, cols = _compute_upsample_nearest_indices(
      input, output_size, (scales_h, scales_w)
  )
  return _upsample_nearest2d_common(input, rows, cols)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def get_scale_value(scales, idx):
  """Returns ``scales[idx]``, or None when ``scales`` itself is None."""
  return None if scales is None else scales[idx]
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@functools.partial(
    fx_infra.decomp.add_pre_lower_decomp,
    torch.ops.aten.upsample_nearest2d.vec,
)
@fx_infra.annotate_force_decomp
def upsample_nearest2d_vec(input, output_size, scale_factors):
  """Rewrites aten.upsample_nearest2d.vec in terms of the .default overload.

  The spatial output size is resolved from whichever of ``output_size`` /
  ``scale_factors`` was given. The per-axis scales (None when
  ``scale_factors`` is None) are then forwarded to
  ``aten.upsample_nearest2d.default``.
  """
  resolved_size = upsample_compute_output_size(
      input.size(), output_size, scale_factors
  )
  scale_h, scale_w = (get_scale_value(scale_factors, axis) for axis in (0, 1))
  return torch.ops.aten.upsample_nearest2d.default(
      input, resolved_size, scale_h, scale_w
  )
|
|
133
|
+
|
|
134
|
+
|
|
43
135
|
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
|
|
44
136
|
|
|
45
137
|
# Torch's default einsum impl/decompositions is less efficient and
|
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.8.0.dev20251206
|
|
3
|
+
Version: 0.8.0.dev20251225
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
|
5
|
-
ai_edge_torch/version.py,sha256=
|
|
5
|
+
ai_edge_torch/version.py,sha256=EqYE0SbfgYjfk194m19-ExhokdnUqLLGwlHCgT7w_rM,806
|
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=JqGZZGbpTmYiT-ta07IQbJ9-gFm-3Vip2aSzW9ulIng,6117
|
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
|
@@ -41,9 +41,9 @@ ai_edge_torch/examples/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzN
|
|
|
41
41
|
ai_edge_torch/examples/selfie_segmentation/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
|
42
42
|
ai_edge_torch/examples/selfie_segmentation/model.py,sha256=5otCH1MzNgSP0fikYq53hgiO1F0ZN1SCVzOIo7cVAcA,17136
|
|
43
43
|
ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
44
|
-
ai_edge_torch/fx_infra/__init__.py,sha256=
|
|
44
|
+
ai_edge_torch/fx_infra/__init__.py,sha256=bseeaX7oDyyvl_oAIT2MfDYZWITmNpn660AlBOUcyUc,1418
|
|
45
45
|
ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
|
|
46
|
-
ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=
|
|
46
|
+
ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=3rGVQj7OgrIBd--olhnGpSyF2kvpjXMfADss6VPsvxQ,4167
|
|
47
47
|
ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
|
|
48
48
|
ai_edge_torch/fx_infra/graph_utils.py,sha256=nqGe-xIJ77RamSUh0UYyI2XHOsZqFDWax-vpRAtVR_E,2796
|
|
49
49
|
ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
|
|
@@ -259,7 +259,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
|
|
|
259
259
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=HOTYfQWin8tqi1yakIyardxhRViZ6rhLV6ZomMSS7zA,17554
|
|
260
260
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
|
261
261
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
|
262
|
-
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=
|
|
262
|
+
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=LmSj5RsZBi00EE7KfF3dI2U0e60LMHA6mDKc-TC2U0U,5486
|
|
263
263
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=YmyM-5HJeeYaIhmKTOnCjfX3_A1PPh1gPGUi1d8EBs8,26454
|
|
264
264
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
|
265
265
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
|
@@ -276,8 +276,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
|
276
276
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
|
277
277
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
278
278
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
|
279
|
-
ai_edge_torch_nightly-0.8.0.
|
|
280
|
-
ai_edge_torch_nightly-0.8.0.
|
|
281
|
-
ai_edge_torch_nightly-0.8.0.
|
|
282
|
-
ai_edge_torch_nightly-0.8.0.
|
|
283
|
-
ai_edge_torch_nightly-0.8.0.
|
|
279
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
280
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/METADATA,sha256=aFJe9FPOOmsDU7-axoCav2W0-n4fRxf0lApj08AZb0s,2399
|
|
281
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
|
282
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
283
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|