ai-edge-torch-nightly 0.8.0.dev20251206__py3-none-any.whl → 0.8.0.dev20260105__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
ai_edge_torch/fx_infra/__init__.py CHANGED
@@ -30,3 +30,4 @@ run_passes = pass_base.run_passes
 
 CanonicalizePass = _canonicalize_pass.CanonicalizePass
 safe_run_decompositions = _safe_run_decompositions.safe_run_decompositions
+annotate_force_decomp = _safe_run_decompositions.annotate_force_decomp
ai_edge_torch/fx_infra/_safe_run_decompositions.py CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 """ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
 import operator
+from typing import Any, Callable
 import torch
 
 
@@ -59,6 +60,15 @@ def _require_decomp(
   return False
 
 
+_FORCE_DECOMP_ATTR = "_ai_edge_torch_force_decomp"
+
+
+def annotate_force_decomp(decomp: Callable[..., Any]):
+  """Annotates a decomp to force it to be run (at least shallowly) in safe_run_decompositions."""
+  setattr(decomp, _FORCE_DECOMP_ATTR, _FORCE_DECOMP_ATTR)
+  return decomp
+
+
 def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
   """Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
 
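The annotation is just a sentinel attribute set on the callable, which safe_run_decompositions later detects with hasattr. A minimal sketch of the intended usage (my_decomp is a hypothetical decomposition, shown only for illustration):

    from ai_edge_torch import fx_infra

    @fx_infra.annotate_force_decomp
    def my_decomp(x):
      # Hypothetical decomposition body; any callable works.
      return x + x

    # The decorator returns the same callable with the marker attached.
    assert hasattr(my_decomp, "_ai_edge_torch_force_decomp")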
@@ -79,6 +89,14 @@ def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
       # back to one aten.view.
       node.target = lambda self, size: torch.reshape(self.contiguous(), size)
 
+    # Torch may skip some decompositions even if target is in decomp_table.
+    # The following ensures the target is always run through the decompositions
+    # shallowly if it has _FORCE_DECOMP_ATTR.
+    if decomp_table and node.target in decomp_table:
+      decomp = decomp_table[node.target]
+      if hasattr(decomp, _FORCE_DECOMP_ATTR):
+        node.target = decomp
+
   exported_program = exported_program.run_decompositions(decomp_table)
 
   if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
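In effect, any node whose target maps to a force-annotated decomp is re-pointed at the Python decomposition before run_decompositions executes, so it is traced through at least once even when torch's own machinery would skip it. A self-contained sketch of the flow under assumed inputs (M and my_relu_decomp are illustrative, not part of the package):

    import torch
    from ai_edge_torch import fx_infra

    class M(torch.nn.Module):
      def forward(self, x):
        return torch.relu(x)

    ep = torch.export.export(M(), (torch.randn(4),))

    def my_relu_decomp(x):
      # Hypothetical replacement body, for illustration only.
      return torch.maximum(x, torch.zeros_like(x))

    fx_infra.annotate_force_decomp(my_relu_decomp)

    # Because my_relu_decomp carries _FORCE_DECOMP_ATTR, matching nodes are
    # re-targeted to it before ExportedProgram.run_decompositions is invoked.
    ep = fx_infra.safe_run_decompositions(
        ep, {torch.ops.aten.relu.default: my_relu_decomp}
    )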
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py CHANGED
@@ -14,13 +14,72 @@
 # ==============================================================================
 """Torch export decompositions to run before lowering."""
 
+import functools
 from ai_edge_torch import fx_infra
 import torch
 
 
+# Fork from pytorch/torch/_decomp/decompositions.py
+def upsample_compute_output_size(input_size, output_size, scale_factors):
+  spatial_dimensions = len(input_size) - 2
+  if output_size is not None:
+    torch._check(
+        scale_factors is None,
+        lambda: "Must specify exactly one of output_size and scale_factors",
+    )
+    torch._check(len(output_size) == spatial_dimensions, lambda: "")
+    return output_size
+  if scale_factors is not None:
+    # NB: this isn't necessary lol
+    torch._check(
+        output_size is None,
+        lambda: "Must specify exactly one of output_size and scale_factors",
+    )
+    torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
+    output_size = []
+    for i, s in enumerate(scale_factors):
+      if int(s) == s:
+        output_size.append(input_size[i + 2] * int(s))
+      else:
+        output_size.append(torch.sym_int(input_size[i + 2] * s))
+    return output_size
+  torch._check(
+      False, lambda: "Must specify exactly one of output_size and scale_factors"
+  )
+
+
+# Fork from pytorch/torch/_decomp/decompositions.py
+def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
+  indices = []
+  num_spatial_dims = len(output_size)
+  offset = 0.5 if exact else 0.0
+
+  for d in range(num_spatial_dims):
+    osize = output_size[d]
+    isize = input.shape[-num_spatial_dims + d]
+    scale = (
+        isize / (isize * scales[d]) if scales[d] is not None else isize / osize
+    )
+
+    output_indices = torch.arange(
+        osize, dtype=torch.float32, device=input.device
+    )
+    input_indices = ((output_indices + offset) * scale).to(torch.int64)
+    for _ in range(num_spatial_dims - 1 - d):
+      input_indices = input_indices.unsqueeze(-1)
+    indices.append(input_indices)
+  return tuple(indices)
+
+
+# Fork from pytorch/torch/_decomp/decompositions.py
+def _upsample_nearest2d_common(input, h_indices, w_indices):
+  result = torch.ops.aten.index(input, (None, None, h_indices, w_indices))
+  result = result.contiguous()
+  return result
+
+
 fx_infra.decomp.update_pre_lower_decomp(
     torch._decomp.get_decompositions([
-        torch.ops.aten.upsample_nearest2d,
         torch.ops.aten._native_batch_norm_legit.no_stats,
         torch.ops.aten._native_batch_norm_legit_functional,
         torch.ops.aten._adaptive_avg_pool2d,
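For intuition, the forked helpers map each output coordinate back to a source index by computing floor(output_index * input_size / output_size). A small standalone check of that index math (illustrative only):

    import torch

    # 1-D spatial dim of size 2 upsampled to 4; no scale factor is given, so
    # scale = isize / osize = 0.5, as in _compute_upsample_nearest_indices.
    isize, osize = 2, 4
    output_indices = torch.arange(osize, dtype=torch.float32)
    input_indices = (output_indices * (isize / osize)).to(torch.int64)
    print(input_indices)  # tensor([0, 0, 1, 1]): each source element repeats twice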
@@ -35,11 +94,44 @@ fx_infra.decomp.update_pre_lower_decomp(
         torch.ops.aten.replication_pad2d,
         torch.ops.aten.replication_pad3d,
         torch.ops.aten.upsample_bilinear2d.vec,
-        torch.ops.aten.upsample_nearest2d.vec,
         torch.ops.aten.addmm,
     ])
 )
 
+
+@functools.partial(
+    fx_infra.decomp.add_pre_lower_decomp,
+    torch.ops.aten.upsample_nearest2d.default,
+)
+@fx_infra.annotate_force_decomp
+def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
+  h_indices, w_indices = _compute_upsample_nearest_indices(
+      input, output_size, (scales_h, scales_w)
+  )
+  return _upsample_nearest2d_common(input, h_indices, w_indices)
+
+
+def get_scale_value(scales, idx):
+  if scales is None:
+    return None
+  return scales[idx]
+
+
+@functools.partial(
+    fx_infra.decomp.add_pre_lower_decomp,
+    torch.ops.aten.upsample_nearest2d.vec,
+)
+@fx_infra.annotate_force_decomp
+def upsample_nearest2d_vec(input, output_size, scale_factors):
+  osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
+  scale_h = get_scale_value(scale_factors, 0)
+  scale_w = get_scale_value(scale_factors, 1)
+
+  return torch.ops.aten.upsample_nearest2d.default(
+      input, osize, scale_h, scale_w
+  )
+
+
 fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
 
 # Torch's default einsum impl/decompositions is less efficient and
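A quick way to sanity-check the forked decomposition against torch's reference nearest-neighbor upsampling (a sketch that assumes the helper functions above are in scope; it is not part of the package's tests):

    import torch

    x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
    h_idx, w_idx = _compute_upsample_nearest_indices(x, (8, 8), (None, None))
    actual = _upsample_nearest2d_common(x, h_idx, w_idx)
    expected = torch.nn.functional.interpolate(x, size=(8, 8), mode="nearest")
    assert torch.equal(actual, expected)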
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
 
 # The next version of ai-edge-torch.
 # The minor version code should be bumped after every release.
-__version__ = "0.8.0.dev20251206"
+__version__ = "0.8.0.dev20260105"
ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/METADATA → ai_edge_torch_nightly-0.8.0.dev20260105.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ai-edge-torch-nightly
-Version: 0.8.0.dev20251206
+Version: 0.8.0.dev20260105
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/RECORD → ai_edge_torch_nightly-0.8.0.dev20260105.dist-info/RECORD
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
-ai_edge_torch/version.py,sha256=igYuweKQrisARFhIhK3eg0wnxeYvQtBYWgM2qt2tV7M,806
+ai_edge_torch/version.py,sha256=PMprc1uQTZmTOISS9DIrSA9E1I9zJZ2iLY3JxxfbY0w,806
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=JqGZZGbpTmYiT-ta07IQbJ9-gFm-3Vip2aSzW9ulIng,6117
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -41,9 +41,9 @@ ai_edge_torch/examples/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzN
 ai_edge_torch/examples/selfie_segmentation/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/examples/selfie_segmentation/model.py,sha256=5otCH1MzNgSP0fikYq53hgiO1F0ZN1SCVzOIo7cVAcA,17136
 ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpxXtuiI,1347
+ai_edge_torch/fx_infra/__init__.py,sha256=bseeaX7oDyyvl_oAIT2MfDYZWITmNpn660AlBOUcyUc,1418
 ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
-ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=V-vhvScNrE3nTenT1LIULbWIU-2FD_OHthFZCPuxtzk,3480
+ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=3rGVQj7OgrIBd--olhnGpSyF2kvpjXMfADss6VPsvxQ,4167
 ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
 ai_edge_torch/fx_infra/graph_utils.py,sha256=nqGe-xIJ77RamSUh0UYyI2XHOsZqFDWax-vpRAtVR_E,2796
 ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
@@ -259,7 +259,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=HOTYfQWin8tqi1yakIyardxhRViZ6rhLV6ZomMSS7zA,17554
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
-ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
+ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=LmSj5RsZBi00EE7KfF3dI2U0e60LMHA6mDKc-TC2U0U,5486
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=YmyM-5HJeeYaIhmKTOnCjfX3_A1PPh1gPGUi1d8EBs8,26454
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
@@ -276,8 +276,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/METADATA,sha256=-BM6OyCHwhI74Y8OYtnKNLt6dYiKGkEr1mSZK5hv9Lg,2399
-ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
-ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/RECORD,,
+ai_edge_torch_nightly-0.8.0.dev20260105.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.8.0.dev20260105.dist-info/METADATA,sha256=GC2dZfzgEolabjHuKdFelKgOCCyeWKGSOur-JTxPT-g,2399
+ai_edge_torch_nightly-0.8.0.dev20260105.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+ai_edge_torch_nightly-0.8.0.dev20260105.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.8.0.dev20260105.dist-info/RECORD,,