ai-edge-torch-nightly 0.7.0.dev20251007__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (40) hide show
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
  3. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
  4. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
  5. ai_edge_torch/generative/layers/attention.py +25 -2
  6. ai_edge_torch/generative/layers/attention_test.py +13 -1
  7. ai_edge_torch/generative/layers/attention_utils.py +62 -1
  8. ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
  9. ai_edge_torch/generative/layers/builder.py +4 -2
  10. ai_edge_torch/generative/layers/model_config.py +5 -0
  11. ai_edge_torch/generative/layers/normalization.py +8 -2
  12. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
  13. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
  14. ai_edge_torch/generative/quantize/example.py +1 -1
  15. ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
  16. ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
  17. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
  18. ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
  19. ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
  20. ai_edge_torch/generative/test/test_kv_cache.py +18 -6
  21. ai_edge_torch/generative/test/test_quantize.py +17 -26
  22. ai_edge_torch/generative/utilities/converter.py +97 -22
  23. ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
  24. ai_edge_torch/generative/utilities/loader.py +2 -1
  25. ai_edge_torch/lowertools/translate_recipe.py +8 -3
  26. ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
  27. ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
  28. ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
  29. ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
  30. ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
  31. ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
  32. ai_edge_torch/odml_torch/export.py +24 -7
  33. ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
  34. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
  35. ai_edge_torch/version.py +1 -1
  36. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
  37. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +40 -34
  38. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
  39. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,438 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch ops to Torch-TFL decompositions."""
16
+ from typing import Sequence
17
+ from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
18
+ import torch
19
+
20
# Registry mapping a torch op overload to its Torch-TFL decomposition function.
decomps = {}


def register_decomp(op):
  """Returns a decorator that records a decomposition for `op` in `decomps`.

  Args:
    op: A single op overload, or an `OpOverloadPacket`, in which case the
      decorated function is registered for every overload of the packet.

  Returns:
    A decorator that stores the decorated function in `decomps` and returns it
    unchanged.
  """
  if isinstance(op, torch._ops.OpOverloadPacket):
    targets = [getattr(op, name) for name in op.overloads()]
  else:
    targets = [op]

  def register(decomp_fn):
    # The registry is only mutated, never rebound, so no `global` is needed.
    decomps.update(dict.fromkeys(targets, decomp_fn))
    return decomp_fn

  return register
35
+
36
+
37
@register_decomp(torch.ops.aten.mm.default)
def _aten_mm_decomp(x, y):
  """Rewrites `aten.mm` as a single `tfl.batch_matmul` call."""
  product = torch.ops.tfl.batch_matmul(x, y)
  return product
40
+
41
+
42
@register_decomp(torch.ops.aten.bmm.default)
def _aten_bmm_decomp(x, y):
  """Rewrites `aten.bmm` as a single `tfl.batch_matmul` call."""
  product = torch.ops.tfl.batch_matmul(x, y)
  return product
45
+
46
+
47
+ def _promote_types_for_binary_op(x, y):
48
+ """Promotes operand types for a binary op."""
49
+ # TFLite's binary ops require operands to have the same element type.
50
+ # We promote the types before calling the op.
51
+ # Handle scalar operand by converting scalar to a tensor.
52
+ if not isinstance(x, torch.Tensor):
53
+ x = torch.scalar_tensor(x)
54
+ elif not isinstance(y, torch.Tensor):
55
+ y = torch.scalar_tensor(y)
56
+
57
+ target_dtype = torch.promote_types(x.dtype, y.dtype)
58
+ if x.dtype != target_dtype:
59
+ x = x.to(target_dtype)
60
+ if y.dtype != target_dtype:
61
+ y = y.to(target_dtype)
62
+ return x, y
63
+
64
+
65
@register_decomp(torch.ops.aten.add.Tensor)
def _aten_add_tensor_decomp(x, y, alpha=1):
  """Rewrites `aten.add.Tensor` using `tfl.add` (and `tfl.mul` for `alpha`).

  When `alpha` is 1 the op is a plain `tfl.add`; otherwise it is emitted as
  add(x, mul(y, alpha)). Operand dtypes are promoted before each TFL op.
  """
  if alpha != 1:
    # Scale the second operand first: y <- mul(y, alpha).
    scaled, factor = _promote_types_for_binary_op(y, alpha)
    y = torch.ops.tfl.mul(scaled, factor)

  lhs, rhs = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.add(lhs, rhs)
76
+
77
+
78
@register_decomp(torch.ops.aten.sub.Tensor)
def _aten_sub_tensor_decomp(x, y, alpha=1):
  """Rewrites `aten.sub.Tensor` using `tfl.sub` (and `tfl.mul` for `alpha`).

  When `alpha` is 1 the op is a plain `tfl.sub`; otherwise it is emitted as
  sub(x, mul(y, alpha)). Operand dtypes are promoted before each TFL op.
  """
  if alpha != 1:
    # Scale the subtrahend first: y <- mul(y, alpha).
    scaled, factor = _promote_types_for_binary_op(y, alpha)
    y = torch.ops.tfl.mul(scaled, factor)

  lhs, rhs = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.sub(lhs, rhs)
89
+
90
+
91
@register_decomp(torch.ops.aten.mul.Tensor)
def _aten_mul_tensor_decomp(x, y):
  """Rewrites `aten.mul.Tensor` as `tfl.mul` on dtype-promoted operands."""
  lhs, rhs = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.mul(lhs, rhs)
95
+
96
+
97
@register_decomp(torch.ops.aten.mul.Scalar)
def _aten_mul_scalar_decomp(x, y):
  """Rewrites `aten.mul.Scalar` as `tfl.mul` on dtype-promoted operands."""
  lhs, rhs = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.mul(lhs, rhs)
101
+
102
+
103
@register_decomp(torch.ops.aten.div.Tensor)
def _aten_div_tensor_decomp(x, y):
  """Rewrites `aten.div.Tensor` as `tfl.div` on dtype-promoted operands."""
  numerator, denominator = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.div(numerator, denominator)
107
+
108
+
109
@register_decomp(torch.ops.aten.pow.Scalar)
def _aten_pow_scalar_decomp(x, y):
  """Rewrites `aten.pow.Scalar` as `tfl.pow` on dtype-promoted operands."""
  base, exponent = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.pow(base, exponent)
113
+
114
+
115
@register_decomp(torch.ops.aten.pow.Tensor_Scalar)
def _aten_pow_tensor_scalar_decomp(x, y):
  """Rewrites `aten.pow.Tensor_Scalar` as `tfl.pow` on promoted operands."""
  base, exponent = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.pow(base, exponent)
119
+
120
+
121
@register_decomp(torch.ops.aten.pow.Tensor_Tensor)
def _aten_pow_tensor_tensor_decomp(x, y):
  """Rewrites `aten.pow.Tensor_Tensor` as `tfl.pow` on promoted operands."""
  base, exponent = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.pow(base, exponent)
125
+
126
+
127
@register_decomp(torch.ops.aten.bitwise_and.Tensor)
def _aten_bitwise_and_tensor_decomp(x, y):
  """Rewrites `aten.bitwise_and.Tensor` on bool tensors as `tfl.logical_and`.

  Raises:
    TypeError: If either operand is not a `torch.Tensor` of dtype
      `torch.bool`.
  """
  x_is_bool = isinstance(x, torch.Tensor) and x.dtype == torch.bool
  y_is_bool = isinstance(y, torch.Tensor) and y.dtype == torch.bool
  if x_is_bool and y_is_bool:
    return torch.ops.tfl.logical_and(x, y)

  raise TypeError(
      "Input tensors for aten.bitwise_and only supports bool for now."
  )
139
+
140
+
141
@register_decomp(torch.ops.aten.mean.dim)
def _aten_mean_dim_decomp(x, dim, keepdim=False):
  """Rewrites `aten.mean.dim` as `tfl.mean`, forwarding `dim` and `keepdim`."""
  reduced = torch.ops.tfl.mean(x, dim, keepdim)
  return reduced
144
+
145
+
146
@register_decomp(torch.ops.aten.gt.Tensor)
def _aten_gt_tensor_decomp(x, y):
  """Rewrites `aten.gt.Tensor` as `tfl.greater` on dtype-promoted operands."""
  lhs, rhs = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.greater(lhs, rhs)
150
+
151
+
152
@register_decomp(torch.ops.aten.lt.Tensor)
def _aten_lt_tensor_decomp(x, y):
  """Rewrites `aten.lt.Tensor` as `tfl.less` on dtype-promoted operands."""
  lhs, rhs = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.less(lhs, rhs)
156
+
157
+
158
@register_decomp(torch.ops.aten.maximum.default)
def _aten_maximum_tensor_decomp(x, y):
  """Rewrites `aten.maximum` as `tfl.maximum` on dtype-promoted operands.

  The operands are promoted to a common dtype first, consistent with the other
  binary-op decompositions in this file; this is a no-op when both operands
  already share a dtype.
  """
  x, y = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.maximum(x, y)
161
+
162
+
163
@register_decomp(torch.ops.aten.minimum.default)
def _aten_minimum_tensor_decomp(x, y):
  """Rewrites `aten.minimum` as `tfl.minimum` on dtype-promoted operands.

  The operands are promoted to a common dtype first, consistent with the other
  binary-op decompositions in this file; this is a no-op when both operands
  already share a dtype.
  """
  x, y = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.minimum(x, y)
166
+
167
+
168
@register_decomp(torch.ops.aten.sin.default)
def _aten_sin_decomp(x):
  """Rewrites `aten.sin` as `tfl.sin`."""
  result = torch.ops.tfl.sin(x)
  return result
171
+
172
+
173
@register_decomp(torch.ops.aten.cos.default)
def _aten_cos_decomp(x):
  """Rewrites `aten.cos` as `tfl.cos`."""
  result = torch.ops.tfl.cos(x)
  return result
176
+
177
+
178
@register_decomp(torch.ops.aten.rsqrt.default)
def _aten_rsqrt_decomp(x):
  """Rewrites `aten.rsqrt` as `tfl.rsqrt`."""
  result = torch.ops.tfl.rsqrt(x)
  return result
181
+
182
+
183
@register_decomp(torch.ops.aten.neg.default)
def _aten_neg_decomp(x):
  """Rewrites `aten.neg` as `tfl.neg`."""
  result = torch.ops.tfl.neg(x)
  return result
186
+
187
+
188
@register_decomp(torch.ops.aten.gelu.default)
def _aten_gelu_decomp(x, approximate="none"):
  """Rewrites `aten.gelu` as `tfl.gelu`.

  The string `approximate` mode is reduced to a boolean flag that is True for
  any value other than "none".
  """
  use_approximation = approximate != "none"
  return torch.ops.tfl.gelu(x, use_approximation)
191
+
192
+
193
@register_decomp(torch.ops.aten.permute.default)
def _aten_permute_decomp(x, dims: Sequence[int]):
  """Rewrites `aten.permute` as `tfl.transpose` with the same `dims`."""
  permuted = torch.ops.tfl.transpose(x, dims)
  return permuted
196
+
197
+
198
+ def _prepare_tensors_for_concatenation(
199
+ tensors: Sequence[torch.Tensor], axis: int
200
+ ) -> Sequence[torch.Tensor]:
201
+ """Prepares PyTorch tensors for concatenation by reshaping 1D (0,) tensors if needed."""
202
+ max_rank = 0
203
+ # First pass: determine max_rank among all input tensors
204
+ for t_val_rank_check in tensors:
205
+ max_rank = max(max_rank, t_val_rank_check.dim())
206
+
207
+ ref_tensor_for_shape_inference = None
208
+ # If max_rank > 1, we might need to reshape. Find a reference tensor.
209
+ if max_rank > 1:
210
+ for t_val_ref_check in tensors:
211
+ if t_val_ref_check.dim() == max_rank:
212
+ ref_tensor_for_shape_inference = t_val_ref_check
213
+ break
214
+
215
+ processed_operands = []
216
+ # Perform reshaping of 1D (0,) tensors only if concatenating multi-dimensional
217
+ # tensors and a valid reference tensor was found.
218
+ perform_reshaping = (
219
+ max_rank > 1 and ref_tensor_for_shape_inference is not None
220
+ )
221
+
222
+ if perform_reshaping:
223
+ ref_shape = list(ref_tensor_for_shape_inference.shape)
224
+ for t_val in tensors:
225
+ current_val = t_val
226
+
227
+ # Check if this tensor is 1D, shape (0,), and we are in a context
228
+ # where reshaping to max_rank is needed.
229
+ if torch.numel(t_val) == 0:
230
+ new_shape = list(ref_shape)
231
+ new_shape[axis] = 0
232
+ current_val = torch.ops.tfl.reshape(t_val, new_shape)
233
+ processed_operands.append(current_val)
234
+ else:
235
+ # No reshaping needed, use tensors as they are.
236
+ processed_operands = list(tensors)
237
+ return processed_operands
238
+
239
+
240
@register_decomp(torch.ops.aten.cat.default)
def _aten_cat_decomp(tensors, dim=0):
  """Rewrites `aten.cat` as `tfl.concatenation` along `dim`.

  The inputs are first passed through `_prepare_tensors_for_concatenation`.
  """
  operands = _prepare_tensors_for_concatenation(tensors, dim)
  return torch.ops.tfl.concatenation(operands, dim)
244
+
245
+
246
@register_decomp(torch.ops.aten.full.default)
def _aten_full_decomp(
    size,
    fill_value,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
):
  """Rewrites `aten.full` as `tfl.fill` over a tuple-converted `size`.

  NOTE(review): `dtype`, `layout`, `device` and `pin_memory` are accepted but
  not forwarded to `tfl.fill`; confirm the result dtype is as intended when an
  explicit `dtype` is requested.
  """
  dims = tuple(size)
  return torch.ops.tfl.fill(dims, fill_value)
256
+
257
+
258
@register_decomp(torch.ops.aten.full_like.default)
def _aten_full_like_decomp(
    x,
    fill_value,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
  """Rewrites `aten.full_like` as `tfl.fill` using the shape of `x`.

  NOTE(review): only `x.shape` is used; `x.dtype` and the `dtype`, `layout`,
  `device`, `pin_memory` and `memory_format` arguments are not forwarded to
  `tfl.fill`. Confirm the result dtype is as intended.
  """
  dims = tuple(x.shape)
  return torch.ops.tfl.fill(dims, fill_value)
269
+
270
+
271
@register_decomp(torch.ops.aten.view.default)
def _aten_view_decomp(x, shape: Sequence[int]):
  """Rewrites `aten.view` as `tfl.reshape` to the requested `shape`."""
  reshaped = torch.ops.tfl.reshape(x, shape)
  return reshaped
274
+
275
+
276
@register_decomp(torch.ops.aten.arange.start_step)
def _aten_arange_start_step_decomp(
    start, end, step=1, dtype=None, layout=None, device=None, pin_memory=None
):
  """Rewrites `aten.arange.start_step` as `tfl.range(start, end, step)`.

  NOTE(review): `dtype`, `layout`, `device` and `pin_memory` are accepted but
  not forwarded to `tfl.range`; confirm the result dtype is as intended when
  an explicit `dtype` is requested.
  """
  sequence = torch.ops.tfl.range(start, end, step)
  return sequence
281
+
282
+
283
@register_decomp(torch.ops.aten.split_with_sizes.default)
def _aten_split_with_sizes_decomp(x, split_sizes, dim=0):
  """Rewrites `aten.split_with_sizes` as one `tfl.slice` per requested chunk.

  Each chunk starts where the previous one ended along `dim` and spans the
  full extent of `x` in every other dimension.

  Returns:
    A tuple with one slice per entry of `split_sizes`.
  """
  rank = x.dim()
  chunks = []
  cursor = 0
  for length in split_sizes:
    chunk_begin = [0] * rank
    chunk_begin[dim] = cursor
    chunk_shape = list(x.shape)
    chunk_shape[dim] = length
    chunks.append(torch.ops.tfl.slice(x, chunk_begin, chunk_shape))
    cursor += length
  return tuple(chunks)
295
+
296
+
297
@register_decomp(torch.ops.aten.unsqueeze.default)
def _aten_unsqueeze_decomp(x, dim):
  """Rewrites `aten.unsqueeze` as `tfl.expand_dims` at `dim`."""
  expanded = torch.ops.tfl.expand_dims(x, dim)
  return expanded
300
+
301
+
302
@register_decomp(torch.ops.aten.expand.default)
def _aten_expand_decomp(x, shape: Sequence[int]):
  """Rewrites `aten.expand` as `tfl.broadcast_to` with the target `shape`."""
  broadcasted = torch.ops.tfl.broadcast_to(x, shape)
  return broadcasted
305
+
306
+
307
@register_decomp(torch.ops.aten.squeeze.dims)
def _aten_squeeze_dims_decomp(x, squeeze_dims: Sequence[int]):
  """Rewrites `aten.squeeze.dims` as `tfl.squeeze`.

  Raises:
    ValueError: If more than 8 dimensions are requested to be squeezed.
  """
  num_dims = len(squeeze_dims)
  if num_dims <= 8:
    return torch.ops.tfl.squeeze(x, squeeze_dims)

  raise ValueError(
      "torch.ops.tfl.squeeze supports squeezing at most 8 dimensions, but got"
      f" {num_dims} dimensions."
  )
315
+
316
+
317
@register_decomp(torch.ops.aten.select.int)
def _aten_select_int_decomp(x, dim, index):
  """Rewrites `aten.select.int` as `tfl.strided_slice` plus `tfl.squeeze`.

  A one-element window is sliced out of `x` at `index` along `dim`, and that
  dimension is then squeezed away.

  Args:
    x: The input tensor.
    dim: The dimension to select along.
    index: The position to select; negative values count from the end of
      `dim`.

  Returns:
    The selected slice with `dim` removed.
  """
  rank = len(x.shape)

  # Normalize a negative index so that it counts from the end of `dim`.
  # Without this, index == -1 yields the window [-1, 0), which is empty.
  if index < 0:
    index += x.shape[dim]

  # Start from a full-range, unit-stride slice over every dimension.
  begin = [0] * rank
  end = list(x.shape)
  strides = [1] * rank

  # Narrow the selected dimension to the single requested element.
  begin[dim] = index
  end[dim] = index + 1

  sliced = torch.ops.tfl.strided_slice(x, begin, end, strides)

  # Remove the selected dimension.
  return torch.ops.tfl.squeeze(sliced, [dim])
335
+
336
+
337
@register_decomp(torch.ops.aten.slice.Tensor)
def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1):
  """Rewrites `aten.slice.Tensor` as `tfl.strided_slice`.

  `start` and `end` follow PyTorch's `slice` behavior: `None` means the start
  or the end of the dimension respectively, negative values count from the
  end, and both are clamped to the size of the dimension.
  """
  rank = x.dim()
  size = x.shape[dim]

  # Resolve the defaults, then wrap negative bounds around the dimension.
  lower = 0 if start is None else start
  if lower < 0:
    lower += size
  upper = size if end is None else end
  if upper < 0:
    upper += size

  # Clamp to [0, size] and make sure the window is never reversed.
  lower = max(0, min(lower, size))
  upper = max(lower, min(upper, size))

  # Every other dimension keeps its full range with unit stride.
  begin = [0] * rank
  end_vec = list(x.shape)
  strides = [1] * rank
  begin[dim] = lower
  end_vec[dim] = upper
  strides[dim] = step
  return torch.ops.tfl.strided_slice(x, begin, end_vec, strides)
367
+
368
+
369
@register_decomp(torch.ops.aten.where.self)
def _aten_where_self_decomp(condition, x, y):
  """Rewrites `aten.where.self` as `tfl.select_v2` on promoted branches.

  Only the two value operands are dtype-promoted; `condition` is passed
  through unchanged.
  """
  on_true, on_false = _promote_types_for_binary_op(x, y)
  return torch.ops.tfl.select_v2(condition, on_true, on_false)
373
+
374
+
375
@register_decomp(torch.ops.aten.embedding.default)
def _aten_embedding_decomp(
    weight,
    indices,
    padding_idx=-1,
    scale_grad_by_freq=False,
    sparse=False,
):
  """Rewrites `aten.embedding` as `tfl.embedding_lookup` on flat indices.

  Args:
    weight: The embedding table.
    indices: Tensor of row indices into `weight`, of any shape.
    padding_idx: Accepted for signature compatibility; not used here.
    scale_grad_by_freq: Accepted for signature compatibility; not used here.
    sparse: Accepted for signature compatibility; not used here.

  Returns:
    A tensor of shape `indices.shape + [weight.shape[-1]]`.
  """
  del padding_idx, scale_grad_by_freq, sparse  # Unused by the lookup.

  # The indices are flattened to 1D before the lookup, and the output is then
  # reshaped back to the original indices shape plus the embedding dimension.
  original_indices_shape = list(indices.shape)
  flat_indices = torch.ops.tfl.reshape(indices, [-1])
  # Need to convert indices to int32 for tfl.embedding_lookup.
  flat_indices = flat_indices.to(torch.int32)
  output = torch.ops.tfl.embedding_lookup(flat_indices, weight)
  output_shape = original_indices_shape + [weight.shape[-1]]
  return torch.ops.tfl.reshape(output, output_shape)
386
+
387
+
388
@register_decomp(torch.ops.aten._softmax.default)
def _aten__softmax_decomp(
    x, dim: int, half_to_float: bool  # pylint: disable=unused-argument
):
  """Rewrites `aten._softmax` as `tfl.softmax`, transposing when needed.

  `tfl.softmax` is applied directly when `dim` is the last dimension.
  Otherwise `dim` is swapped with the last dimension before the softmax and
  swapped back afterwards.
  """
  rank = x.dim()
  if dim in (-1, rank - 1):
    return torch.ops.tfl.softmax(x)

  # Swapping two axes is its own inverse, so the same permutation is used to
  # move `dim` to the end and to restore the original layout.
  perm = list(range(rank))
  perm[dim], perm[-1] = perm[-1], perm[dim]
  swapped = torch.ops.tfl.transpose(x, perm)
  normalized = torch.ops.tfl.softmax(swapped)
  return torch.ops.tfl.transpose(normalized, perm)
403
+
404
+
405
@register_decomp(torch.ops.aten.multinomial.default)
def _aten_multinomial_decomp(x, num_samples, replacement=False, generator=None):
  """Rewrites `aten.multinomial` as `tfl.multinomial` on `log(x)`.

  A 1D input gets a leading dimension added before sampling and removed again
  afterwards. The sampled indices are returned as int64. `generator` is not
  used.
  """
  had_no_batch_dim = x.dim() == 1
  probs = torch.ops.aten.unsqueeze.default(x, 0) if had_no_batch_dim else x

  samples = torch.ops.tfl.multinomial(
      torch.log(probs), num_samples, replacement
  )
  if had_no_batch_dim:
    samples = torch.ops.aten.squeeze.dims(samples, [0])
  return samples.to(torch.int64)
415
+
416
+
417
@register_decomp(torch.ops.aten.topk.default)
def _aten_topk_decomp(self, k, dim=-1, largest=True, sorted=True):
  """Rewrites `aten.topk` as `tfl.topk_v2` applied on the last dimension.

  When `dim` is not the last dimension, it is swapped with the last one before
  the call and both outputs are swapped back afterwards.

  Returns:
    An `(values, indices)` tuple, with `indices` cast to int64.

  Raises:
    ValueError: If `largest` is False.
  """
  if not largest:
    raise ValueError("Only largest=True is supported for torch.topk.")

  rank = self.dim()
  if dim < 0:
    dim = rank + dim
  # The rank is unchanged by the transpose, so one flag covers both swaps.
  needs_swap = dim != rank - 1

  if needs_swap:
    self = torch.transpose(self, dim, -1)

  # Ignores sorted value: tfl.topk_v2 only supports sorted=True, but it doesn't
  # affect the correctness of the output.
  values, positions = torch.ops.tfl.topk_v2(self, k)

  if needs_swap:
    values = torch.transpose(values, dim, -1)
    positions = torch.transpose(positions, dim, -1)

  # torch.topk returns int64 indices, but tfl.topk_v2 returns indices in int32.
  return values, positions.to(torch.int64)