ai-edge-torch-nightly 0.7.0.dev20250929__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of ai-edge-torch-nightly may be problematic.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -27
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +183 -28
- ai_edge_torch/generative/utilities/export_config.py +2 -0
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +57 -51
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,438 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Torch ops to Torch-TFL decompositions."""
+from typing import Sequence
+from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
+import torch
+
+decomps = {}
+
+
+def register_decomp(op):
+  global decomps
+  ops = [op]
+  if isinstance(op, torch._ops.OpOverloadPacket):
+    ops = [getattr(op, overload) for overload in op.overloads()]
+
+  def register(decomp_fn):
+    for op in ops:
+      decomps[op] = decomp_fn
+    return decomp_fn
+
+  return register
+
+
+@register_decomp(torch.ops.aten.mm.default)
+def _aten_mm_decomp(x, y):
+  return torch.ops.tfl.batch_matmul(x, y)
+
+
+@register_decomp(torch.ops.aten.bmm.default)
+def _aten_bmm_decomp(x, y):
+  return torch.ops.tfl.batch_matmul(x, y)
+
+
+def _promote_types_for_binary_op(x, y):
+  """Promotes operand types for a binary op."""
+  # TFLite's binary ops require operands to have the same element type.
+  # We promote the types before calling the op.
+  # Handle scalar operand by converting scalar to a tensor.
+  if not isinstance(x, torch.Tensor):
+    x = torch.scalar_tensor(x)
+  elif not isinstance(y, torch.Tensor):
+    y = torch.scalar_tensor(y)
+
+  target_dtype = torch.promote_types(x.dtype, y.dtype)
+  if x.dtype != target_dtype:
+    x = x.to(target_dtype)
+  if y.dtype != target_dtype:
+    y = y.to(target_dtype)
+  return x, y
+
+
+@register_decomp(torch.ops.aten.add.Tensor)
+def _aten_add_tensor_decomp(x, y, alpha=1):
+  if alpha == 1:
+    x, y = _promote_types_for_binary_op(x, y)
+    return torch.ops.tfl.add(x, y)
+
+  # The op is add(x, mul(y, alpha))
+  y, alpha = _promote_types_for_binary_op(y, alpha)
+  mul_out = torch.ops.tfl.mul(y, alpha)
+  x, mul_out = _promote_types_for_binary_op(x, mul_out)
+  return torch.ops.tfl.add(x, mul_out)
+
+
+@register_decomp(torch.ops.aten.sub.Tensor)
+def _aten_sub_tensor_decomp(x, y, alpha=1):
+  if alpha == 1:
+    x, y = _promote_types_for_binary_op(x, y)
+    return torch.ops.tfl.sub(x, y)
+
+  # The op is sub(x, mul(y, alpha))
+  y, alpha = _promote_types_for_binary_op(y, alpha)
+  mul_out = torch.ops.tfl.mul(y, alpha)
+  x, mul_out = _promote_types_for_binary_op(x, mul_out)
+  return torch.ops.tfl.sub(x, mul_out)
+
+
+@register_decomp(torch.ops.aten.mul.Tensor)
+def _aten_mul_tensor_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.mul(x, y)
+
+
+@register_decomp(torch.ops.aten.mul.Scalar)
+def _aten_mul_scalar_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.mul(x, y)
+
+
+@register_decomp(torch.ops.aten.div.Tensor)
+def _aten_div_tensor_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.div(x, y)
+
+
+@register_decomp(torch.ops.aten.pow.Scalar)
+def _aten_pow_scalar_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.pow(x, y)
+
+
+@register_decomp(torch.ops.aten.pow.Tensor_Scalar)
+def _aten_pow_tensor_scalar_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.pow(x, y)
+
+
+@register_decomp(torch.ops.aten.pow.Tensor_Tensor)
+def _aten_pow_tensor_tensor_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.pow(x, y)
+
+
+@register_decomp(torch.ops.aten.bitwise_and.Tensor)
+def _aten_bitwise_and_tensor_decomp(x, y):
+  if not (
+      isinstance(x, torch.Tensor)
+      and x.dtype == torch.bool
+      and isinstance(y, torch.Tensor)
+      and y.dtype == torch.bool
+  ):
+    raise TypeError(
+        "Input tensors for aten.bitwise_and only supports bool for now."
+    )
+  return torch.ops.tfl.logical_and(x, y)
+
+
+@register_decomp(torch.ops.aten.mean.dim)
+def _aten_mean_dim_decomp(x, dim, keepdim=False):
+  return torch.ops.tfl.mean(x, dim, keepdim)
+
+
+@register_decomp(torch.ops.aten.gt.Tensor)
+def _aten_gt_tensor_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.greater(x, y)
+
+
+@register_decomp(torch.ops.aten.lt.Tensor)
+def _aten_lt_tensor_decomp(x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.less(x, y)
+
+
+@register_decomp(torch.ops.aten.maximum.default)
+def _aten_maximum_tensor_decomp(x, y):
+  return torch.ops.tfl.maximum(x, y)
+
+
+@register_decomp(torch.ops.aten.minimum.default)
+def _aten_minimum_tensor_decomp(x, y):
+  return torch.ops.tfl.minimum(x, y)
+
+
+@register_decomp(torch.ops.aten.sin.default)
+def _aten_sin_decomp(x):
+  return torch.ops.tfl.sin(x)
+
+
+@register_decomp(torch.ops.aten.cos.default)
+def _aten_cos_decomp(x):
+  return torch.ops.tfl.cos(x)
+
+
+@register_decomp(torch.ops.aten.rsqrt.default)
+def _aten_rsqrt_decomp(x):
+  return torch.ops.tfl.rsqrt(x)
+
+
+@register_decomp(torch.ops.aten.neg.default)
+def _aten_neg_decomp(x):
+  return torch.ops.tfl.neg(x)
+
+
+@register_decomp(torch.ops.aten.gelu.default)
+def _aten_gelu_decomp(x, approximate="none"):
+  return torch.ops.tfl.gelu(x, approximate != "none")
+
+
+@register_decomp(torch.ops.aten.permute.default)
+def _aten_permute_decomp(x, dims: Sequence[int]):
+  return torch.ops.tfl.transpose(x, dims)
+
+
+def _prepare_tensors_for_concatenation(
+    tensors: Sequence[torch.Tensor], axis: int
+) -> Sequence[torch.Tensor]:
+  """Prepares PyTorch tensors for concatenation by reshaping 1D (0,) tensors if needed."""
+  max_rank = 0
+  # First pass: determine max_rank among all input tensors
+  for t_val_rank_check in tensors:
+    max_rank = max(max_rank, t_val_rank_check.dim())
+
+  ref_tensor_for_shape_inference = None
+  # If max_rank > 1, we might need to reshape. Find a reference tensor.
+  if max_rank > 1:
+    for t_val_ref_check in tensors:
+      if t_val_ref_check.dim() == max_rank:
+        ref_tensor_for_shape_inference = t_val_ref_check
+        break
+
+  processed_operands = []
+  # Perform reshaping of 1D (0,) tensors only if concatenating multi-dimensional
+  # tensors and a valid reference tensor was found.
+  perform_reshaping = (
+      max_rank > 1 and ref_tensor_for_shape_inference is not None
+  )
+
+  if perform_reshaping:
+    ref_shape = list(ref_tensor_for_shape_inference.shape)
+    for t_val in tensors:
+      current_val = t_val
+
+      # Check if this tensor is 1D, shape (0,), and we are in a context
+      # where reshaping to max_rank is needed.
+      if torch.numel(t_val) == 0:
+        new_shape = list(ref_shape)
+        new_shape[axis] = 0
+        current_val = torch.ops.tfl.reshape(t_val, new_shape)
+      processed_operands.append(current_val)
+  else:
+    # No reshaping needed, use tensors as they are.
+    processed_operands = list(tensors)
+  return processed_operands
+
+
+@register_decomp(torch.ops.aten.cat.default)
+def _aten_cat_decomp(tensors, dim=0):
+  processed_tensors = _prepare_tensors_for_concatenation(tensors, dim)
+  return torch.ops.tfl.concatenation(processed_tensors, dim)
+
+
+@register_decomp(torch.ops.aten.full.default)
+def _aten_full_decomp(
+    size,
+    fill_value,
+    dtype=None,
+    layout=None,
+    device=None,
+    pin_memory=None,
+):
+  return torch.ops.tfl.fill(tuple(size), fill_value)
+
+
+@register_decomp(torch.ops.aten.full_like.default)
+def _aten_full_like_decomp(
+    x,
+    fill_value,
+    dtype=None,
+    layout=None,
+    device=None,
+    pin_memory=None,
+    memory_format=None,
+):
+  return torch.ops.tfl.fill(tuple(x.shape), fill_value)
+
+
+@register_decomp(torch.ops.aten.view.default)
+def _aten_view_decomp(x, shape: Sequence[int]):
+  return torch.ops.tfl.reshape(x, shape)
+
+
+@register_decomp(torch.ops.aten.arange.start_step)
+def _aten_arange_start_step_decomp(
+    start, end, step=1, dtype=None, layout=None, device=None, pin_memory=None
+):
+  return torch.ops.tfl.range(start, end, step)
+
+
+@register_decomp(torch.ops.aten.split_with_sizes.default)
+def _aten_split_with_sizes_decomp(x, split_sizes, dim=0):
+  outputs = []
+  offset = 0
+  for size in split_sizes:
+    begin = [0] * x.dim()
+    begin[dim] = offset
+    output_size = list(x.shape)
+    output_size[dim] = size
+    outputs.append(torch.ops.tfl.slice(x, begin, output_size))
+    offset += size
+  return tuple(outputs)
+
+
+@register_decomp(torch.ops.aten.unsqueeze.default)
+def _aten_unsqueeze_decomp(x, dim):
+  return torch.ops.tfl.expand_dims(x, dim)
+
+
+@register_decomp(torch.ops.aten.expand.default)
+def _aten_expand_decomp(x, shape: Sequence[int]):
+  return torch.ops.tfl.broadcast_to(x, shape)
+
+
+@register_decomp(torch.ops.aten.squeeze.dims)
+def _aten_squeeze_dims_decomp(x, squeeze_dims: Sequence[int]):
+  if len(squeeze_dims) > 8:
+    raise ValueError(
+        "torch.ops.tfl.squeeze supports squeezing at most 8 dimensions, but got"
+        f" {len(squeeze_dims)} dimensions."
+    )
+  return torch.ops.tfl.squeeze(x, squeeze_dims)
+
+
+@register_decomp(torch.ops.aten.select.int)
+def _aten_select_int_decomp(x, dim, index):
+  rank = len(x.shape)
+
+  # Initialize begin, end, strides
+  begin = [0] * rank
+  end = list(x.shape)
+  strides = [1] * rank
+
+  # Select the index on the given dim
+  begin[dim] = index
+  end[dim] = index + 1
+
+  # Perform the strided slice
+  sliced = torch.ops.tfl.strided_slice(x, begin, end, strides)
+
+  # Remove the selected dimension
+  return torch.ops.tfl.squeeze(sliced, [dim])
+
+
+@register_decomp(torch.ops.aten.slice.Tensor)
+def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1):
+  rank = x.dim()
+  dim_size = x.shape[dim]
+
+  # Initialize begin, end, strides for tfl.strided_slice
+  begin = [0] * rank
+  end_vec = list(x.shape)
+  strides = [1] * rank
+
+  # The logic below is to match PyTorch's `slice` behavior.
+  # `start` and `end` can be negative, which means they count from the end.
+  # `start=None` defaults to 0.
+  # `end=None` or a large number defaults to `dim_size` after clamping.
+
+  start_val = 0 if start is None else start
+  if start_val < 0:
+    start_val += dim_size
+
+  end_val = dim_size if end is None else end
+  if end_val < 0:
+    end_val += dim_size
+
+  # Clamp start and end to be within the dimension size, following PyTorch's
+  # logic.
+  start_val = max(0, min(start_val, dim_size))
+  end_val = max(start_val, min(end_val, dim_size))
+
+  begin[dim], end_vec[dim], strides[dim] = start_val, end_val, step
+  return torch.ops.tfl.strided_slice(x, begin, end_vec, strides)
+
+
+@register_decomp(torch.ops.aten.where.self)
+def _aten_where_self_decomp(condition, x, y):
+  x, y = _promote_types_for_binary_op(x, y)
+  return torch.ops.tfl.select_v2(condition, x, y)
+
+
+@register_decomp(torch.ops.aten.embedding.default)
+def _aten_embedding_decomp(weight, indices, padding_idx=-1):
+  # The `tfl.gather` op only supports 1D indices, so we need to flatten the
+  # indices and then reshape the output to the correct shape.
+  original_indices_shape = list(indices.shape)
+  flat_indices = torch.ops.tfl.reshape(indices, [-1])
+  # Need to convert indices to int32 for tfl.embedding_lookup.
+  flat_indices = flat_indices.to(torch.int32)
+  output = torch.ops.tfl.embedding_lookup(flat_indices, weight)
+  output_shape = original_indices_shape + [weight.shape[-1]]
+  return torch.ops.tfl.reshape(output, output_shape)
+
+
+@register_decomp(torch.ops.aten._softmax.default)
+def _aten__softmax_decomp(
+    x, dim: int, half_to_float: bool  # pylint: disable=unused-argument
+):
+  if dim == -1 or dim == x.dim() - 1:
+    return torch.ops.tfl.softmax(x)
+  else:
+    dims = list(range(x.dim()))
+    # Transpose the input by swapping the dim with the last dimension.
+    dims[dim], dims[-1] = dims[-1], dims[dim]
+    x_permuted = torch.ops.tfl.transpose(x, dims)
+    # Compute the softmax on the last dimension.
+    softmax_result = torch.ops.tfl.softmax(x_permuted)
+    # Transpose the result back to the original dimensions.
+    return torch.ops.tfl.transpose(softmax_result, dims)
+
+
+@register_decomp(torch.ops.aten.multinomial.default)
+def _aten_multinomial_decomp(x, num_samples, replacement=False, generator=None):
+  is_1d = x.dim() == 1
+  if is_1d:
+    x = torch.ops.aten.unsqueeze.default(x, 0)
+  logits = torch.log(x)
+  indices = torch.ops.tfl.multinomial(logits, num_samples, replacement)
+  if is_1d:
+    indices = torch.ops.aten.squeeze.dims(indices, [0])
+  return indices.to(torch.int64)
+
+
+@register_decomp(torch.ops.aten.topk.default)
+def _aten_topk_decomp(self, k, dim=-1, largest=True, sorted=True):
+  if not largest:
+    raise ValueError("Only largest=True is supported for torch.topk.")
+
+  if dim < 0:
+    dim = self.dim() + dim
+
+  if dim != self.dim() - 1:
+    self = torch.transpose(self, dim, -1)
+
+  # Ignores sorted value: tfl.topk_v2 only supports sorted=True, but it doesn't
+  # affect the correctness of the output.
+  out, indices = torch.ops.tfl.topk_v2(self, k)
+
+  if dim != self.dim() - 1:
+    out = torch.transpose(out, dim, -1)
+    indices = torch.transpose(indices, dim, -1)
+
+  # torch.topk returns int64 indices, but tfl.topk_v2 returns indices in int32.
+  indices = indices.to(torch.int64)
+  return out, indices