ai-edge-torch-nightly 0.7.0.dev20251007-py3-none-any.whl → 0.8.0.dev20251206-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +97 -22
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +40 -34
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
Diff hunks are shown below for four of the changed files.

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py (new file)

```diff
@@ -0,0 +1,371 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Torch-TFL op definitions and fake implementations."""
+
+import re
+from typing import Any, Sequence
+
+from ai_edge_torch.odml_torch.experimental.torch_tfl import torch_library_utils
+import numpy as np
+import torch
+
+
+custom_op_with_fake = torch_library_utils.custom_op_with_fake
+
+
+@custom_op_with_fake("tfl::batch_matmul")
+def tfl_batch_matmul(
+    x: torch.Tensor, y: torch.Tensor, adj_x: bool = False, adj_y: bool = False
+) -> torch.Tensor:
+  if x.ndim < 2 or y.ndim < 2:
+    raise ValueError("Input tensors must have at least 2 dimensions.")
+  if adj_x:
+    x = torch.transpose(x, -1, -2)
+  if adj_y:
+    y = torch.transpose(y, -1, -2)
+  return torch.matmul(x, y)
+
+
+@custom_op_with_fake("tfl::add", schema="(Tensor x, Any y) -> Tensor")
+def tfl_add(x: torch.Tensor, y: Any) -> torch.Tensor:
+  return torch.add(x, y)
+
+
+@custom_op_with_fake("tfl::sub", schema="(Tensor x, Any y) -> Tensor")
+def tfl_sub(x: torch.Tensor, y: Any) -> torch.Tensor:
+  return torch.sub(x, y)
+
+
+@custom_op_with_fake("tfl::mul", schema="(Tensor x, Any y) -> Tensor")
+def tfl_mul(x: torch.Tensor, y: Any) -> torch.Tensor:
+  return torch.mul(x, y)
+
+
+@custom_op_with_fake("tfl::div", schema="(Tensor x, Any y) -> Tensor")
+def tfl_div(x: torch.Tensor, y: Any) -> torch.Tensor:
+  return torch.div(x, y)
+
+
+@custom_op_with_fake("tfl::pow", schema="(Any x, Any y) -> Tensor")
+def tfl_pow(x: Any, y: Any) -> torch.Tensor:
+  return torch.pow(x, y)
+
+
+@custom_op_with_fake("tfl::logical_and")
+def tfl_logical_and(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+  return torch.logical_and(x, y)
+
+
+@custom_op_with_fake(
+    "tfl::mean", schema="(Tensor x, Any dims, bool keepdim) -> Tensor"
+)
+def tfl_mean(x: torch.Tensor, dims: Any, keepdim: bool = False) -> torch.Tensor:
+  return torch.mean(x, dims, keepdim)
+
+
+@custom_op_with_fake("tfl::greater")
+def tfl_greater(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+  return torch.gt(x, y)
+
+
+@custom_op_with_fake("tfl::less")
+def tfl_less(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+  return torch.lt(x, y)
+
+
+@custom_op_with_fake("tfl::maximum")
+def tfl_maximum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+  return torch.maximum(x, y)
+
+
+@custom_op_with_fake("tfl::minimum")
+def tfl_minimum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+  return torch.minimum(x, y)
+
+
+@custom_op_with_fake("tfl::sin")
+def tfl_sin(x: torch.Tensor) -> torch.Tensor:
+  return torch.sin(x)
+
+
+@custom_op_with_fake("tfl::cos")
+def tfl_cos(x: torch.Tensor) -> torch.Tensor:
+  return torch.cos(x)
+
+
+@custom_op_with_fake("tfl::rsqrt")
+def tfl_rsqrt(x: torch.Tensor) -> torch.Tensor:
+  return torch.rsqrt(x)
+
+
+@custom_op_with_fake("tfl::neg")
+def tfl_neg(x: torch.Tensor) -> torch.Tensor:
+  return torch.neg(x)
+
+
+@custom_op_with_fake("tfl::gelu")
+def tfl_gelu(x: torch.Tensor, approximate: bool = False) -> torch.Tensor:
+  gelu_approximate = "tanh" if approximate else "none"
+  return torch.nn.functional.gelu(x, approximate=gelu_approximate)
+
+
+@custom_op_with_fake("tfl::transpose")
+def tfl_transpose(input: torch.Tensor, perm: Sequence[int]) -> torch.Tensor:
+  assert len(perm) == input.ndim
+
+  return torch.permute(input, perm).clone()
+
+
+@custom_op_with_fake("tfl::concatenation")
+def tfl_concatenation(
+    tensors: Sequence[torch.Tensor], dim: int
+) -> torch.Tensor:
+  return torch.cat(tensors, dim=dim)
+
+
+@custom_op_with_fake("tfl::fill", schema="(SymInt[] x, Any y) -> Tensor")
+def tfl_fill(dims: Sequence[torch.SymInt], fill_value: Any) -> torch.Tensor:
+  return torch.full(dims, fill_value)
+
+
+def _normalize_shape(
+    tensor_input: torch.Tensor, shape: Sequence[int]
+) -> Sequence[int]:
+  """Normalize the size for the -1 dimension in the "shape".
+
+  Args:
+    tensor_input: The input tensor.
+    shape: The desired shape, which may contain a -1 to indicate an inferred
+      dimension.
+
+  Returns:
+    The inferred shape.
+
+  Raises:
+    ValueError: If the shape is invalid or cannot be inferred.
+  """
+  inferred_shape = list(shape)
+  if -1 in inferred_shape:
+    numel = tensor_input.numel()
+    product = 1
+    neg_one_idx = -1
+    for i, dim in enumerate(inferred_shape):
+      if dim == -1:
+        if neg_one_idx != -1:
+          raise ValueError("Only one dimension can be inferred (-1)")
+        neg_one_idx = i
+      elif dim >= 0:
+        product *= dim
+      else:
+        raise ValueError(
+            "Shape dimensions must be non-negative or -1 for inference"
+        )
+
+    if neg_one_idx != -1:
+      if product == 0:
+        if numel != 0:
+          raise ValueError(
+              "Cannot infer dimension for non-zero input size when other"
+              " dimensions multiply to zero"
+          )
+        inferred_shape[neg_one_idx] = 0
+      else:
+        if numel % product != 0:
+          raise ValueError(
+              f"Input size {numel} not divisible by product of known dimensions"
+              f" {product}"
+          )
+        inferred_shape[neg_one_idx] = numel // product
+
+  # Ensure the inferred shape still matches the total number of elements
+  if np.prod(inferred_shape) != tensor_input.numel():
+    raise ValueError(
+        f"Calculated shape {inferred_shape} does not match input numel"
+        f" {tensor_input.numel()}"
+    )
+
+  return inferred_shape
+
+
+@torch.library.custom_op("tfl::reshape", mutates_args=())
+def tfl_reshape(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+  inferred_shape = _normalize_shape(input, shape)
+  return input.view(inferred_shape).clone()
+
+
+# Use explicit fake implementation for tfl.reshape because dynamo cannot
+# derive the output's symbolic shape from the impl above.
+@torch.library.register_fake("tfl::reshape")
+def tfl_reshape_fake(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+  inferred_shape = _normalize_shape(input, shape)
+  return torch.empty(inferred_shape, dtype=input.dtype)
+
+
+@custom_op_with_fake(
+    "tfl::range", schema="(Scalar start, Scalar limit, Scalar delta) -> Tensor"
+)
+def tfl_range(
+    start: int | float, limit: int | float, delta: int | float
+) -> torch.Tensor:
+  return torch.arange(start, limit, delta)
+
+
+@custom_op_with_fake(
+    "tfl::split_v", schema="(Tensor x, SymInt[] y, int z) -> Tensor[]"
+)
+def tfl_split_v(
+    input: torch.Tensor, size_splits: Sequence[torch.SymInt], split_dim: int
+) -> Sequence[torch.Tensor]:
+  # Clone the output tensors to avoid aliasing issues.
+  return [t.clone() for t in torch.split(input, size_splits, dim=split_dim)]
+
+
+@custom_op_with_fake("tfl::expand_dims")
+def tfl_expand_dims(x: torch.Tensor, dim: int) -> torch.Tensor:
+  return torch.unsqueeze(x, dim).clone()
+
+
+@custom_op_with_fake("tfl::broadcast_to")
+def tfl_broadcast_to(x: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+  return x.expand(shape).clone()
+
+
+@custom_op_with_fake("tfl::squeeze")
+def tfl_squeeze(x: torch.Tensor, squeeze_dims: Sequence[int]) -> torch.Tensor:
+  return torch.squeeze(x, squeeze_dims).clone()
+
+
+@custom_op_with_fake("tfl::strided_slice")
+def tfl_strided_slice(
+    input: torch.Tensor,
+    begin: Sequence[int],
+    end: Sequence[int],
+    strides: Sequence[int],
+) -> torch.Tensor:
+  assert (
+      len(begin) == len(end) == len(strides) == input.ndim
+  ), "Dimension mismatch"
+
+  slices = []
+
+  for i in range(input.ndim):
+    b = begin[i]
+    e = end[i]
+    s = strides[i]
+    slices.append(slice(b, e, s))
+
+  result = input[tuple(slices)].clone()
+
+  return result
+
+
+@custom_op_with_fake("tfl::select_v2")
+def tfl_select_v2(
+    condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor
+) -> torch.Tensor:
+  return torch.where(condition, x, y)
+
+
+@custom_op_with_fake("tfl::embedding_lookup")
+def tfl_embedding_lookup(
+    indices: torch.Tensor, weight: torch.Tensor
+) -> torch.Tensor:
+  return torch.nn.functional.embedding(indices, weight)
+
+
+@custom_op_with_fake("tfl::gather")
+def tfl_gather(
+    input: torch.Tensor, indices: torch.Tensor, axis: int
+) -> torch.Tensor:
+  return torch.index_select(input, axis, indices)
+
+
+@custom_op_with_fake("tfl::softmax")
+def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
+  return torch.nn.functional.softmax(x, dim=-1)
+
+
+@custom_op_with_fake("tfl::topk_v2")
+def tfl_topk_v2(x: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
+  out, indices = torch.topk(x, k, dim=-1, largest=True, sorted=True)
+  indices = indices.to(torch.int32)
+  return out, indices
+
+
+@custom_op_with_fake("tfl::multinomial")
+def tfl_multinomial(
+    logits: torch.Tensor, num_samples: int, replacement: bool = False
+) -> torch.Tensor:
+  indices = torch.multinomial(
+      torch.nn.functional.softmax(logits, dim=-1),
+      num_samples,
+      replacement=replacement,
+  )
+  return indices
+
+
+@custom_op_with_fake(
+    "tfl::slice", schema="(Tensor x, SymInt[] begin, SymInt[] size) -> Tensor"
+)
+def tfl_slice(
+    input: torch.Tensor,
+    begin: Sequence[torch.SymInt],
+    size: Sequence[torch.SymInt],
+) -> torch.Tensor:
+  assert len(begin) == len(size) == input.ndim
+
+  slices = [slice(i, i + l) for i, l in zip(begin, size)]
+  return input[tuple(slices)].clone()
+
+
+@torch.library.custom_op("tfl::slice.tensor", mutates_args=())
+def tfl_slice_tensor(
+    input: torch.Tensor,
+    begin: torch.Tensor,
+    size: torch.Tensor,
+    *,
+    shape: str = "",
+) -> torch.Tensor:
+  assert begin.ndim == size.ndim == 1
+  assert begin.numel() == size.numel() == input.ndim
+  assert begin.dtype == torch.int32 and size.dtype == torch.int32
+  assert not shape or shape.count(",") == input.ndim - 1
+
+  slices = [slice(i, i + l) for i, l in zip(begin.tolist(), size.tolist())]
+  return input[tuple(slices)].clone()
+
+
+@torch.library.register_fake("tfl::slice.tensor")
+def tfl_slice_tensor_fake(
+    input: torch.Tensor,
+    begin: torch.Tensor,
+    size: torch.Tensor,
+    *,
+    shape: str = "",
+) -> torch.Tensor:
+  ctx = torch.library.get_ctx()
+  shape_str = shape
+  if not shape_str:
+    shape_str = ",".join(["?" for _ in range(input.ndim)])
+
+  shape = []
+  shape_symbols = shape_str.split(",")
+  for sym in shape_symbols:
+    if re.match(r"\d+", sym):
+      shape.append(int(sym))
+    else:
+      nnz = ctx.new_dynamic_size()
+      shape.append(nnz)
+  return input.new_empty(shape, dtype=input.dtype)
```
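Ops registered under the `tfl::` namespace with `torch.library.custom_op` become callable as `torch.ops.tfl.*`. Below is a minimal sketch of exercising two of the ops above eagerly; it assumes that importing the `torch_tfl` package triggers the registrations in `_ops.py` (consistent with the new `torch_tfl/__init__.py` in this release), and the shapes are illustrative:

```python
import torch

# Assumed: importing the package runs the registrations in _ops.py.
from ai_edge_torch.odml_torch.experimental import torch_tfl  # noqa: F401

x = torch.randn(2, 3, 4)
y = torch.randn(2, 5, 4)

# tfl::batch_matmul with adj_y=True transposes y's last two dims before matmul:
# (2, 3, 4) @ (2, 4, 5) -> (2, 3, 5).
out = torch.ops.tfl.batch_matmul(x, y, adj_y=True)
assert out.shape == (2, 3, 5)

# tfl::reshape resolves the -1 dim via _normalize_shape: 24 elements / 3 = 8.
flat = torch.ops.tfl.reshape(x, [3, -1])
assert flat.shape == (3, 8)
```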
ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py (new file)

```diff
@@ -0,0 +1,37 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for defining custom ops in torch library."""
+from typing import Callable, Iterable
+import torch
+
+
+def custom_op_with_fake(
+    name: str,
+    *,
+    mutates_args: str | Iterable[str] = (),
+    schema: str | None = None,
+):
+  """Defines a custom op with a FakeTensor implementation using the same function."""
+
+  def register(fn: Callable[..., object]):
+    op = torch.library.custom_op(
+        name,
+        mutates_args=mutates_args,
+        schema=schema,
+    )(fn)
+    torch.library.register_fake(name)(fn)
+    return op
+
+  return register
```
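For context, `torch.library.custom_op` and `torch.library.register_fake` are the standard PyTorch 2.x custom-op APIs; the helper simply registers one function as both the eager implementation and the FakeTensor (meta) implementation, which works whenever the function is traceable on fake inputs. A hypothetical sketch of using the decorator (the op name `tfl::example_relu` is made up for illustration and is not part of the wheel):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from ai_edge_torch.odml_torch.experimental.torch_tfl import torch_library_utils


@torch_library_utils.custom_op_with_fake("tfl::example_relu")
def tfl_example_relu(x: torch.Tensor) -> torch.Tensor:
  # Eager implementation; reused by the fake registration to infer
  # output shapes/dtypes during tracing.
  return torch.relu(x)


print(torch.ops.tfl.example_relu(torch.randn(3)))  # real values

with FakeTensorMode() as mode:
  out = torch.ops.tfl.example_relu(mode.from_tensor(torch.randn(3)))
print(out.shape)  # torch.Size([3]) -- shape propagated via the fake impl
```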
ai_edge_torch/odml_torch/export.py

```diff
@@ -21,6 +21,7 @@ import operator
 from typing import Any, Callable, Optional
 
 from ai_edge_torch import fx_infra
+from ai_edge_torch.odml_torch.experimental import torch_tfl
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
 from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -349,7 +350,8 @@ def exported_program_to_mlir(
     exported_program: The exported program to lower.
     ir_context: The MLIR context to use. If not provided, a new context will be
       created.
-    _pre_lower_pass: A function to run on exported program before lowering
+    _pre_lower_pass: A function to run on exported program before lowering,
+      after all run_decompositions calls.
 
   Returns:
     The lowered MLIR module, metadata, and weight tensors bundle from exported
@@ -358,18 +360,33 @@ def exported_program_to_mlir(
   exported_program = fx_infra.safe_run_decompositions(
       exported_program,
       fx_infra.decomp.pre_lower_decomp(),
+      can_skip=False,
   )
-  if _convert_i64_to_i32(exported_program):
-    # Run decompositions for retracing and cananicalization, if modified.
-    exported_program = fx_infra.safe_run_decompositions(exported_program, {})
 
-  # Passes
-  #
-
+  # Passes to run after pre_lower_decomp - requires ops to be decomposed into
+  # lower level ops.
+  _convert_i64_to_i32(exported_program)
 
+  # Last decomposition and canonicalization before lowering.
+  exported_program = fx_infra.safe_run_decompositions(
+      exported_program,
+      fx_infra.decomp.pre_lower_decomp()
+      | {
+          op: torch_tfl.decomps[op]
+          for op in [
+              torch.ops.aten.multinomial.default,
+          ]
+      },
+  )
+
+  # Passes below modify the exported program to a state not executable by torch.
+  # Do not call run_decompositions after applying the passes.
   if _pre_lower_pass:
     _pre_lower_pass(exported_program)
 
+  _convert_q_dq_per_channel_args_to_list(exported_program)
+
+  # Begin of lowering.
   if not ir_context:
     ir_context = export_utils.create_ir_context()
 
```
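The restructured pipeline above runs `pre_lower_decomp()` unconditionally, applies the i64-to-i32 pass, then runs a final decomposition whose table is the dict union (`|`) of `pre_lower_decomp()` and selected Torch-TFL decompositions, so `aten.multinomial` is rewritten to the `tfl::multinomial` custom op before MLIR lowering. A standalone sketch of the same table-merging pattern on a plain exported program (the model is illustrative, and whether the rewrite fires depends on the entry registered in `torch_tfl.decomps`):

```python
import torch
from ai_edge_torch.odml_torch.experimental import torch_tfl


class Sampler(torch.nn.Module):
  def forward(self, probs: torch.Tensor) -> torch.Tensor:
    return torch.multinomial(probs, num_samples=1)


ep = torch.export.export(Sampler(), (torch.rand(2, 8),))

# Merge a (here empty) base table with the Torch-TFL entry, mirroring the
# `pre_lower_decomp() | {...}` union in exported_program_to_mlir.
table = {} | {
    torch.ops.aten.multinomial.default: torch_tfl.decomps[
        torch.ops.aten.multinomial.default
    ],
}
ep = ep.run_decompositions(table)
# The printed graph should now reference the tfl multinomial op rather than
# aten.multinomial, assuming the decomposition applies.
print(ep.graph_module.code)
```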
ai_edge_torch/odml_torch/lowerings/_basic.py

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import logging
 import math
 import operator
 from typing import Optional, Union
@@ -51,6 +52,81 @@ def _aten_mul_tensor(lctx, self: ir.Value, other: ir.Value):
   return stablehlo.multiply(self, other)
 
 
+def _hann_window_impl(
+    lctx: LoweringContext,
+    size: int,
+    periodic: bool,
+    dtype: Optional[torch.dtype],
+) -> ir.Value:
+  if dtype is None:
+    ir_dtype = ir.F32Type.get()
+  else:
+    ir_dtype = utils.torch_dtype_to_ir_element_type(dtype)
+
+  if not isinstance(ir_dtype, ir.FloatType):
+    raise ValueError("hann_window only supports float dtypes.")
+
+  if size == 0:
+    return stablehlo.ConstantOp(
+        ir.RankedTensorType.get((0,), ir_dtype),
+        ir.DenseElementsAttr.get_empty(ir.RankedTensorType.get((0,), ir_dtype)),
+    ).result
+  if size == 1:
+    return utils.splat(1.0, ir_dtype, [1])
+
+  denom = size if periodic else size - 1
+
+  i64 = ir.IntegerType.get_signless(64)
+  iota_type = ir.RankedTensorType.get((size,), i64)
+  n_i64 = stablehlo.IotaOp(
+      iota_type, iota_dimension=ir.IntegerAttr.get(i64, 0)
+  ).result
+
+  n_type = ir.RankedTensorType.get((size,), ir_dtype)
+  n = stablehlo.convert(n_type, n_i64)
+
+  pi_val = math.pi
+  scale = 2.0 * pi_val / denom
+
+  scale_splat = utils.splat(scale, ir_dtype, [size])
+  arg_cos = stablehlo.multiply(n, scale_splat)
+  cos_val = stablehlo.cosine(arg_cos)
+
+  half_splat = utils.splat(0.5, ir_dtype, [size])
+  scaled_cos = stablehlo.multiply(half_splat, cos_val)
+  return stablehlo.subtract(half_splat, scaled_cos)
+
+
+# hann_window(int size, *, ScalarType? dtype=None) -> Tensor
+@lower(torch.ops.aten.hann_window.default)
+def _aten_hann_window_default(
+    lctx: LoweringContext,
+    size: int,
+    **kwargs,
+) -> ir.Value:
+  dtype = kwargs.pop("dtype", None)
+  layout = kwargs.pop("layout", torch.strided)
+  if layout != torch.strided:
+    logging.warning("hann_window only supports torch.strided layout.")
+  return _hann_window_impl(lctx, size, True, dtype)
+
+
+# hann_window.periodic(int size, bool periodic, *, ScalarType? dtype=None) ->
+# Tensor
+@lower(torch.ops.aten.hann_window.periodic)
+def _aten_hann_window_periodic(
+    lctx: LoweringContext,
+    size: int,
+    periodic: bool,
+    **kwargs,
+) -> ir.Value:
+  dtype = kwargs.pop("dtype", None)
+  layout = kwargs.pop("layout", torch.strided)
+  if layout != torch.strided:
+    logging.warning("hann_window only supports torch.strided layout.")
+  return _hann_window_impl(lctx, size, periodic, dtype)
+
+
 # cat(Tensor[] tensors, int dim=0) -> Tensor
 # @lower(torch.ops.aten.cat)
 def _aten_cat(lctx, tensors: list[ir.Value], dim: int = 1):
```
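The lowering emits the closed form w[n] = 0.5 - 0.5 * cos(2 * pi * n / denom), with denom = size for the periodic variant and size - 1 otherwise, which is exactly how torch.hann_window is defined. A quick numerical check of that formula in plain PyTorch (independent of the MLIR path; the test sizes are arbitrary):

```python
import math
import torch


def hann_reference(size: int, periodic: bool) -> torch.Tensor:
  # Mirrors _hann_window_impl: 0.5 - 0.5 * cos(2*pi*n / denom).
  if size == 1:
    return torch.ones(1)
  denom = size if periodic else size - 1
  n = torch.arange(size, dtype=torch.float32)
  return 0.5 - 0.5 * torch.cos(2.0 * math.pi * n / denom)


for size in (2, 5, 16):
  assert torch.allclose(
      hann_reference(size, True), torch.hann_window(size, periodic=True)
  )
  assert torch.allclose(
      hann_reference(size, False), torch.hann_window(size, periodic=False)
  )
```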
The remaining `_basic.py` hunk adds an `aten.unfold` lowering built on `stablehlo.gather`:

```diff
@@ -249,6 +325,85 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
   return stablehlo.concatenate(non_empty_tensors, dim)
 
 
+# Schema:
+#  - aten::unfold(Tensor self, int dim, int size, int step) -> Tensor
+# Torch Reference:
+#  - https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
+@lower(torch.ops.aten.unfold.default)
+def _aten_unfold(lctx, x: ir.Value, dim: int, size: int, step: int):
+  x_shape = x.type.shape
+  rank = len(x_shape)
+  if dim < 0:
+    dim += rank
+
+  num_windows = (x_shape[dim] - size) // step + 1
+  batch_shape = list(x_shape[:dim]) + [num_windows] + list(x_shape[dim + 1 :])
+
+  # Create start_indices for gather.
+  # The shape of start_indices will be batch_shape + [rank].
+  # start_indices[b_0,...,b_{rank-1}] will be [p_0,...,p_{rank-1}] where
+  # p_j = b_j for j != dim and p_dim = b_dim * step.
+  indices_parts = []
+  i64 = ir.IntegerType.get_signless(64)
+  for i in range(rank):
+    bshape = [1] * rank
+    bshape[i] = batch_shape[i]
+    dim_len = batch_shape[i]
+
+    iota = stablehlo.IotaOp(
+        ir.RankedTensorType.get([dim_len], i64),
+        iota_dimension=ir.IntegerAttr.get(i64, 0),
+    ).result
+    if i == dim:
+      iota = stablehlo.multiply(iota, utils.splat(step, i64, [dim_len]))
+
+    iota_reshaped = stablehlo.reshape(
+        ir.RankedTensorType.get(bshape, i64), iota
+    )
+    indices_parts.append(
+        stablehlo.broadcast_in_dim(
+            ir.RankedTensorType.get(batch_shape, i64),
+            iota_reshaped,
+            ir.DenseI64ArrayAttr.get(list(range(rank))),
+        )
+    )
+
+  # For each dimension i, indices_parts[i] contains the i-th coordinate
+  # of start_indices. We unsqueeze each part to shape batch_shape + [1]
+  # and concatenate along the new dimension to produce start_indices of
+  # shape batch_shape + [rank].
+  unsqueezed_parts = [
+      stablehlo.reshape(ir.RankedTensorType.get(batch_shape + [1], i64), part)
+      for part in indices_parts
+  ]
+  start_indices = stablehlo.concatenate(
+      unsqueezed_parts, ir.IntegerAttr.get(i64, rank)
+  )
+
+  slice_sizes_list = [1] * rank
+  slice_sizes_list[dim] = size
+  slice_sizes = ir.DenseI64ArrayAttr.get(slice_sizes_list)
+
+  collapsed_slice_dims_list = [i for i in range(rank) if i != dim]
+
+  dnums = stablehlo.GatherDimensionNumbers.get(
+      offset_dims=[rank],
+      collapsed_slice_dims=collapsed_slice_dims_list,
+      operand_batching_dims=[],
+      start_indices_batching_dims=[],
+      start_index_map=list(range(rank)),
+      index_vector_dim=rank,
+  )
+
+  return stablehlo.gather(
+      x,
+      start_indices,
+      dnums,
+      slice_sizes,
+      indices_are_sorted=ir.BoolAttr.get(False),
+  )
+
+
 # Schema:
 # - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
 #   start=None, SymInt? end=None, SymInt step=1) -> Tensor
```
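As background for the gather construction: `Tensor.unfold(dim, size, step)` yields `(x.shape[dim] - size) // step + 1` windows, where window `b` starts at offset `b * step` along `dim`; the lowering materializes exactly those start indices (the `iota * step` term) and gathers `size`-length slices, with `offset_dims=[rank]` placing the window dimension last, just as unfold does. A small sketch checking that index arithmetic against the built-in:

```python
import torch

x = torch.arange(10.0)

# unfold(dim=0, size=4, step=2) -> windows [0..3], [2..5], [4..7], [6..9].
windows = x.unfold(0, 4, 2)
print(windows.shape)  # torch.Size([4, 4]): (10 - 4) // 2 + 1 = 4 windows of size 4

# Equivalent index computation mirroring the lowering: the start of window b
# is b * step, and the gathered slice covers [start, start + size) along dim.
manual = torch.stack([x[b * 2 : b * 2 + 4] for b in range(4)])
assert torch.equal(windows, manual)
```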