ai-edge-torch-nightly 0.7.0.dev20250929__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the changes between those versions.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
Files changed (57)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
  3. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -20
  4. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +1 -20
  5. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -20
  6. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -20
  7. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -27
  8. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +1 -20
  9. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -20
  10. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -20
  11. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -20
  12. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -20
  13. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -20
  14. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -20
  15. ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +1 -20
  16. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +1 -30
  17. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +1 -30
  18. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
  19. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
  20. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -20
  21. ai_edge_torch/generative/layers/attention.py +25 -2
  22. ai_edge_torch/generative/layers/attention_test.py +13 -1
  23. ai_edge_torch/generative/layers/attention_utils.py +62 -1
  24. ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
  25. ai_edge_torch/generative/layers/builder.py +4 -2
  26. ai_edge_torch/generative/layers/model_config.py +5 -0
  27. ai_edge_torch/generative/layers/normalization.py +8 -2
  28. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
  29. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
  30. ai_edge_torch/generative/quantize/example.py +1 -1
  31. ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
  32. ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
  33. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
  34. ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
  35. ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
  36. ai_edge_torch/generative/test/test_kv_cache.py +18 -6
  37. ai_edge_torch/generative/test/test_quantize.py +17 -26
  38. ai_edge_torch/generative/utilities/converter.py +183 -28
  39. ai_edge_torch/generative/utilities/export_config.py +2 -0
  40. ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
  41. ai_edge_torch/generative/utilities/loader.py +2 -1
  42. ai_edge_torch/lowertools/translate_recipe.py +8 -3
  43. ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
  44. ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
  45. ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
  46. ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
  47. ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
  48. ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
  49. ai_edge_torch/odml_torch/export.py +24 -7
  50. ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
  51. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
  52. ai_edge_torch/version.py +1 -1
  53. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
  54. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +57 -51
  55. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
  56. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
  57. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,371 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Torch-TFL op definitions and fake implementations."""
+
+ import re
+ from typing import Any, Sequence
+
+ from ai_edge_torch.odml_torch.experimental.torch_tfl import torch_library_utils
+ import numpy as np
+ import torch
+
+
+ custom_op_with_fake = torch_library_utils.custom_op_with_fake
+
+
+ @custom_op_with_fake("tfl::batch_matmul")
+ def tfl_batch_matmul(
+     x: torch.Tensor, y: torch.Tensor, adj_x: bool = False, adj_y: bool = False
+ ) -> torch.Tensor:
+   if x.ndim < 2 or y.ndim < 2:
+     raise ValueError("Input tensors must have at least 2 dimensions.")
+   if adj_x:
+     x = torch.transpose(x, -1, -2)
+   if adj_y:
+     y = torch.transpose(y, -1, -2)
+   return torch.matmul(x, y)
+
+
+ @custom_op_with_fake("tfl::add", schema="(Tensor x, Any y) -> Tensor")
+ def tfl_add(x: torch.Tensor, y: Any) -> torch.Tensor:
+   return torch.add(x, y)
+
+
+ @custom_op_with_fake("tfl::sub", schema="(Tensor x, Any y) -> Tensor")
+ def tfl_sub(x: torch.Tensor, y: Any) -> torch.Tensor:
+   return torch.sub(x, y)
+
+
+ @custom_op_with_fake("tfl::mul", schema="(Tensor x, Any y) -> Tensor")
+ def tfl_mul(x: torch.Tensor, y: Any) -> torch.Tensor:
+   return torch.mul(x, y)
+
+
+ @custom_op_with_fake("tfl::div", schema="(Tensor x, Any y) -> Tensor")
+ def tfl_div(x: torch.Tensor, y: Any) -> torch.Tensor:
+   return torch.div(x, y)
+
+
+ @custom_op_with_fake("tfl::pow", schema="(Any x, Any y) -> Tensor")
+ def tfl_pow(x: Any, y: Any) -> torch.Tensor:
+   return torch.pow(x, y)
+
+
+ @custom_op_with_fake("tfl::logical_and")
+ def tfl_logical_and(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+   return torch.logical_and(x, y)
+
+
+ @custom_op_with_fake(
+     "tfl::mean", schema="(Tensor x, Any dims, bool keepdim) -> Tensor"
+ )
+ def tfl_mean(x: torch.Tensor, dims: Any, keepdim: bool = False) -> torch.Tensor:
+   return torch.mean(x, dims, keepdim)
+
+
+ @custom_op_with_fake("tfl::greater")
+ def tfl_greater(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+   return torch.gt(x, y)
+
+
+ @custom_op_with_fake("tfl::less")
+ def tfl_less(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+   return torch.lt(x, y)
+
+
+ @custom_op_with_fake("tfl::maximum")
+ def tfl_maximum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+   return torch.maximum(x, y)
+
+
+ @custom_op_with_fake("tfl::minimum")
+ def tfl_minimum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+   return torch.minimum(x, y)
+
+
+ @custom_op_with_fake("tfl::sin")
+ def tfl_sin(x: torch.Tensor) -> torch.Tensor:
+   return torch.sin(x)
+
+
+ @custom_op_with_fake("tfl::cos")
+ def tfl_cos(x: torch.Tensor) -> torch.Tensor:
+   return torch.cos(x)
+
+
+ @custom_op_with_fake("tfl::rsqrt")
+ def tfl_rsqrt(x: torch.Tensor) -> torch.Tensor:
+   return torch.rsqrt(x)
+
+
+ @custom_op_with_fake("tfl::neg")
+ def tfl_neg(x: torch.Tensor) -> torch.Tensor:
+   return torch.neg(x)
+
+
+ @custom_op_with_fake("tfl::gelu")
+ def tfl_gelu(x: torch.Tensor, approximate: bool = False) -> torch.Tensor:
+   gelu_approximate = "tanh" if approximate else "none"
+   return torch.nn.functional.gelu(x, approximate=gelu_approximate)
+
+
+ @custom_op_with_fake("tfl::transpose")
+ def tfl_transpose(input: torch.Tensor, perm: Sequence[int]) -> torch.Tensor:
+   assert len(perm) == input.ndim
+
+   return torch.permute(input, perm).clone()
+
+
+ @custom_op_with_fake("tfl::concatenation")
+ def tfl_concatenation(
+     tensors: Sequence[torch.Tensor], dim: int
+ ) -> torch.Tensor:
+   return torch.cat(tensors, dim=dim)
+
+
+ @custom_op_with_fake("tfl::fill", schema="(SymInt[] x, Any y) -> Tensor")
+ def tfl_fill(dims: Sequence[torch.SymInt], fill_value: Any) -> torch.Tensor:
+   return torch.full(dims, fill_value)
+
+
+ def _normalize_shape(
+     tensor_input: torch.Tensor, shape: Sequence[int]
+ ) -> Sequence[int]:
+   """Normalize the size for the -1 dimension in the "shape".
+
+   Args:
+     tensor_input: The input tensor.
+     shape: The desired shape, which may contain a -1 to indicate an inferred
+       dimension.
+
+   Returns:
+     The inferred shape.
+
+   Raises:
+     ValueError: If the shape is invalid or cannot be inferred.
+   """
+   inferred_shape = list(shape)
+   if -1 in inferred_shape:
+     numel = tensor_input.numel()
+     product = 1
+     neg_one_idx = -1
+     for i, dim in enumerate(inferred_shape):
+       if dim == -1:
+         if neg_one_idx != -1:
+           raise ValueError("Only one dimension can be inferred (-1)")
+         neg_one_idx = i
+       elif dim >= 0:
+         product *= dim
+       else:
+         raise ValueError(
+             "Shape dimensions must be non-negative or -1 for inference"
+         )
+
+     if neg_one_idx != -1:
+       if product == 0:
+         if numel != 0:
+           raise ValueError(
+               "Cannot infer dimension for non-zero input size when other"
+               " dimensions multiply to zero"
+           )
+         inferred_shape[neg_one_idx] = 0
+       else:
+         if numel % product != 0:
+           raise ValueError(
+               f"Input size {numel} not divisible by product of known dimensions"
+               f" {product}"
+           )
+         inferred_shape[neg_one_idx] = numel // product
+
+   # Ensure the inferred shape still matches the total number of elements
+   if np.prod(inferred_shape) != tensor_input.numel():
+     raise ValueError(
+         f"Calculated shape {inferred_shape} does not match input numel"
+         f" {tensor_input.numel()}"
+     )
+
+   return inferred_shape
+
+
+ @torch.library.custom_op("tfl::reshape", mutates_args=())
+ def tfl_reshape(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+   inferred_shape = _normalize_shape(input, shape)
+   return input.view(inferred_shape).clone()
+
+
+ # Use explicit fake implementation for tfl.reshape because dynamo cannot
+ # derive the output's symbolic shape from the impl above.
+ @torch.library.register_fake("tfl::reshape")
+ def tfl_reshape_fake(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+   inferred_shape = _normalize_shape(input, shape)
+   return torch.empty(inferred_shape, dtype=input.dtype)
+
+
+ @custom_op_with_fake(
+     "tfl::range", schema="(Scalar start, Scalar limit, Scalar delta) -> Tensor"
+ )
+ def tfl_range(
+     start: int | float, limit: int | float, delta: int | float
+ ) -> torch.Tensor:
+   return torch.arange(start, limit, delta)
+
+
+ @custom_op_with_fake(
+     "tfl::split_v", schema="(Tensor x, SymInt[] y, int z) -> Tensor[]"
+ )
+ def tfl_split_v(
+     input: torch.Tensor, size_splits: Sequence[torch.SymInt], split_dim: int
+ ) -> Sequence[torch.Tensor]:
+   # Clone the output tensors to avoid aliasing issues.
+   return [t.clone() for t in torch.split(input, size_splits, dim=split_dim)]
+
+
+ @custom_op_with_fake("tfl::expand_dims")
+ def tfl_expand_dims(x: torch.Tensor, dim: int) -> torch.Tensor:
+   return torch.unsqueeze(x, dim).clone()
+
+
+ @custom_op_with_fake("tfl::broadcast_to")
+ def tfl_broadcast_to(x: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+   return x.expand(shape).clone()
+
+
+ @custom_op_with_fake("tfl::squeeze")
+ def tfl_squeeze(x: torch.Tensor, squeeze_dims: Sequence[int]) -> torch.Tensor:
+   return torch.squeeze(x, squeeze_dims).clone()
+
+
+ @custom_op_with_fake("tfl::strided_slice")
+ def tfl_strided_slice(
+     input: torch.Tensor,
+     begin: Sequence[int],
+     end: Sequence[int],
+     strides: Sequence[int],
+ ) -> torch.Tensor:
+   assert (
+       len(begin) == len(end) == len(strides) == input.ndim
+   ), "Dimension mismatch"
+
+   slices = []
+
+   for i in range(input.ndim):
+     b = begin[i]
+     e = end[i]
+     s = strides[i]
+     slices.append(slice(b, e, s))
+
+   result = input[tuple(slices)].clone()
+
+   return result
+
+
+ @custom_op_with_fake("tfl::select_v2")
+ def tfl_select_v2(
+     condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor
+ ) -> torch.Tensor:
+   return torch.where(condition, x, y)
+
+
+ @custom_op_with_fake("tfl::embedding_lookup")
+ def tfl_embedding_lookup(
+     indices: torch.Tensor, weight: torch.Tensor
+ ) -> torch.Tensor:
+   return torch.nn.functional.embedding(indices, weight)
+
+
+ @custom_op_with_fake("tfl::gather")
+ def tfl_gather(
+     input: torch.Tensor, indices: torch.Tensor, axis: int
+ ) -> torch.Tensor:
+   return torch.index_select(input, axis, indices)
+
+
+ @custom_op_with_fake("tfl::softmax")
+ def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
+   return torch.nn.functional.softmax(x, dim=-1)
+
+
+ @custom_op_with_fake("tfl::topk_v2")
+ def tfl_topk_v2(x: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
+   out, indices = torch.topk(x, k, dim=-1, largest=True, sorted=True)
+   indices = indices.to(torch.int32)
+   return out, indices
+
+
+ @custom_op_with_fake("tfl::multinomial")
+ def tfl_multinomial(
+     logits: torch.Tensor, num_samples: int, replacement: bool = False
+ ) -> torch.Tensor:
+   indices = torch.multinomial(
+       torch.nn.functional.softmax(logits, dim=-1),
+       num_samples,
+       replacement=replacement,
+   )
+   return indices
+
+
+ @custom_op_with_fake(
+     "tfl::slice", schema="(Tensor x, SymInt[] begin, SymInt[] size) -> Tensor"
+ )
+ def tfl_slice(
+     input: torch.Tensor,
+     begin: Sequence[torch.SymInt],
+     size: Sequence[torch.SymInt],
+ ) -> torch.Tensor:
+   assert len(begin) == len(size) == input.ndim
+
+   slices = [slice(i, i + l) for i, l in zip(begin, size)]
+   return input[tuple(slices)].clone()
+
+
+ @torch.library.custom_op("tfl::slice.tensor", mutates_args=())
+ def tfl_slice_tensor(
+     input: torch.Tensor,
+     begin: torch.Tensor,
+     size: torch.Tensor,
+     *,
+     shape: str = "",
+ ) -> torch.Tensor:
+   assert begin.ndim == size.ndim == 1
+   assert begin.numel() == size.numel() == input.ndim
+   assert begin.dtype == torch.int32 and size.dtype == torch.int32
+   assert not shape or shape.count(",") == input.ndim - 1
+
+   slices = [slice(i, i + l) for i, l in zip(begin.tolist(), size.tolist())]
+   return input[tuple(slices)].clone()
+
+
+ @torch.library.register_fake("tfl::slice.tensor")
+ def tfl_slice_tensor_fake(
+     input: torch.Tensor,
+     begin: torch.Tensor,
+     size: torch.Tensor,
+     *,
+     shape: str = "",
+ ) -> torch.Tensor:
+   ctx = torch.library.get_ctx()
+   shape_str = shape
+   if not shape_str:
+     shape_str = ",".join(["?" for _ in range(input.ndim)])
+
+   shape = []
+   shape_symbols = shape_str.split(",")
+   for sym in shape_symbols:
+     if re.match(r"\d+", sym):
+       shape.append(int(sym))
+     else:
+       nnz = ctx.new_dynamic_size()
+       shape.append(nnz)
+   return input.new_empty(shape, dtype=input.dtype)
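
The ops above register under the tfl namespace of torch.ops, so they run eagerly via their reference torch implementations and trace through torch.export thanks to the fake registrations. A minimal usage sketch, not part of the package, assuming that importing the _ops module is enough to trigger registration:

import torch
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops  # noqa: F401, registers tfl::* ops

x = torch.randn(2, 3, 4)
y = torch.randn(2, 4, 5)

# Eager call dispatches to the reference implementation (torch.matmul here).
out = torch.ops.tfl.batch_matmul(x, y)
assert out.shape == (2, 3, 5)


class TinyModel(torch.nn.Module):

  def forward(self, a, b):
    return torch.ops.tfl.batch_matmul(a, b)


# The fake implementation lets the op appear as a node in an exported graph.
ep = torch.export.export(TinyModel(), (x, y))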
@@ -0,0 +1,37 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Utility functions for defining custom ops in torch library."""
+ from typing import Callable, Iterable
+ import torch
+
+
+ def custom_op_with_fake(
+     name: str,
+     *,
+     mutates_args: str | Iterable[str] = (),
+     schema: str | None = None,
+ ):
+   """Defines a custom op with a FakeTensor implementation using the same function."""
+
+   def register(fn: Callable[..., object]):
+     op = torch.library.custom_op(
+         name,
+         mutates_args=mutates_args,
+         schema=schema,
+     )(fn)
+     torch.library.register_fake(name)(fn)
+     return op
+
+   return register
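
A minimal sketch, not from the package, of how custom_op_with_fake is meant to be used: the decorated function serves as both the eager implementation and the FakeTensor implementation, so no separate register_fake call is needed. The tfl::scale op below is hypothetical, purely for illustration:

import torch
from ai_edge_torch.odml_torch.experimental.torch_tfl import torch_library_utils


# Hypothetical op, not defined by the package; shown only to illustrate the helper.
@torch_library_utils.custom_op_with_fake("tfl::scale")
def tfl_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
  # The same body runs eagerly and under FakeTensor during tracing/export.
  return x * factor


out = torch.ops.tfl.scale(torch.ones(3), 3.0)  # tensor([3., 3., 3.])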
@@ -21,6 +21,7 @@ import operator
  from typing import Any, Callable, Optional
 
  from ai_edge_torch import fx_infra
+ from ai_edge_torch.odml_torch.experimental import torch_tfl
  from jax._src.lib.mlir import ir
  from jax._src.lib.mlir.dialects import func
  from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -349,7 +350,8 @@ def exported_program_to_mlir(
      exported_program: The exported program to lower.
      ir_context: The MLIR context to use. If not provided, a new context will be
        created.
-     _pre_lower_pass: A function to run on exported program before lowering.
+     _pre_lower_pass: A function to run on exported program before lowering,
+       after all run_decompositions calls.
 
    Returns:
      The lowered MLIR module, metadata, and weight tensors bundle from exported
@@ -358,18 +360,33 @@ def exported_program_to_mlir(
    exported_program = fx_infra.safe_run_decompositions(
        exported_program,
        fx_infra.decomp.pre_lower_decomp(),
+       can_skip=False,
    )
-   if _convert_i64_to_i32(exported_program):
-     # Run decompositions for retracing and cananicalization, if modified.
-     exported_program = fx_infra.safe_run_decompositions(exported_program, {})
 
-   # Passes below mutate the exported program to a state not executable by torch.
-   # Do not call run_decompositions after applying the passes.
-   _convert_q_dq_per_channel_args_to_list(exported_program)
+   # Passes to run after pre_lower_decomp - requires ops to be decomposed into
+   # lower level ops.
+   _convert_i64_to_i32(exported_program)
 
+   # Last decomposition and canonicalization before lowering.
+   exported_program = fx_infra.safe_run_decompositions(
+       exported_program,
+       fx_infra.decomp.pre_lower_decomp()
+       | {
+           op: torch_tfl.decomps[op]
+           for op in [
+               torch.ops.aten.multinomial.default,
+           ]
+       },
+   )
+
+   # Passes below modify the exported program to a state not executable by torch.
+   # Do not call run_decompositions after applying the passes.
    if _pre_lower_pass:
      _pre_lower_pass(exported_program)
 
+   _convert_q_dq_per_channel_args_to_list(exported_program)
+
+   # Begin of lowering.
    if not ir_context:
      ir_context = export_utils.create_ir_context()
 
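
A rough sketch, under stated assumptions, of how this reordered pipeline is exercised: exported_program_to_mlir takes a torch.export ExportedProgram, and a model that uses aten.multinomial now picks up the torch_tfl decomposition in the second safe_run_decompositions round. The model and import alias below are illustrative only:

import torch
from ai_edge_torch.odml_torch import export as odml_export


class Sampler(torch.nn.Module):

  def forward(self, logits):
    # aten.multinomial is decomposed via torch_tfl.decomps before lowering.
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)


ep = torch.export.export(Sampler(), (torch.randn(1, 16),))
lowered = odml_export.exported_program_to_mlir(ep)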
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ import logging
  import math
  import operator
  from typing import Optional, Union
@@ -51,6 +52,81 @@ def _aten_mul_tensor(lctx, self: ir.Value, other: ir.Value):
    return stablehlo.multiply(self, other)
 
 
+ def _hann_window_impl(
+     lctx: LoweringContext,
+     size: int,
+     periodic: bool,
+     dtype: Optional[torch.dtype],
+ ) -> ir.Value:
+   if dtype is None:
+     ir_dtype = ir.F32Type.get()
+   else:
+     ir_dtype = utils.torch_dtype_to_ir_element_type(dtype)
+
+   if not isinstance(ir_dtype, ir.FloatType):
+     raise ValueError("hann_window only supports float dtypes.")
+
+   if size == 0:
+     return stablehlo.ConstantOp(
+         ir.RankedTensorType.get((0,), ir_dtype),
+         ir.DenseElementsAttr.get_empty(ir.RankedTensorType.get((0,), ir_dtype)),
+     ).result
+   if size == 1:
+     return utils.splat(1.0, ir_dtype, [1])
+
+   denom = size if periodic else size - 1
+
+   i64 = ir.IntegerType.get_signless(64)
+   iota_type = ir.RankedTensorType.get((size,), i64)
+   n_i64 = stablehlo.IotaOp(
+       iota_type, iota_dimension=ir.IntegerAttr.get(i64, 0)
+   ).result
+
+   n_type = ir.RankedTensorType.get((size,), ir_dtype)
+   n = stablehlo.convert(n_type, n_i64)
+
+   pi_val = math.pi
+   scale = 2.0 * pi_val / denom
+
+   scale_splat = utils.splat(scale, ir_dtype, [size])
+   arg_cos = stablehlo.multiply(n, scale_splat)
+   cos_val = stablehlo.cosine(arg_cos)
+
+   half_splat = utils.splat(0.5, ir_dtype, [size])
+   scaled_cos = stablehlo.multiply(half_splat, cos_val)
+   return stablehlo.subtract(half_splat, scaled_cos)
+
+
+ # hann_window(int size, *, ScalarType? dtype=None) -> Tensor
+ @lower(torch.ops.aten.hann_window.default)
+ def _aten_hann_window_default(
+     lctx: LoweringContext,
+     size: int,
+     **kwargs,
+ ) -> ir.Value:
+   dtype = kwargs.pop("dtype", None)
+   layout = kwargs.pop("layout", torch.strided)
+   if layout != torch.strided:
+     logging.warning("hann_window only supports torch.strided layout.")
+   return _hann_window_impl(lctx, size, True, dtype)
+
+
+ # hann_window.periodic(int size, bool periodic, *, ScalarType? dtype=None) ->
+ # Tensor
+ @lower(torch.ops.aten.hann_window.periodic)
+ def _aten_hann_window_periodic(
+     lctx: LoweringContext,
+     size: int,
+     periodic: bool,
+     **kwargs,
+ ) -> ir.Value:
+   dtype = kwargs.pop("dtype", None)
+   layout = kwargs.pop("layout", torch.strided)
+   if layout != torch.strided:
+     logging.warning("hann_window only supports torch.strided layout.")
+   return _hann_window_impl(lctx, size, periodic, dtype)
+
+
  # cat(Tensor[] tensors, int dim=0) -> Tensor
  # @lower(torch.ops.aten.cat)
  def _aten_cat(lctx, tensors: list[ir.Value], dim: int = 1):
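
As a sanity check on the formula implemented above (a standalone sketch, not part of the package): the lowering computes w[n] = 0.5 - 0.5 * cos(2 * pi * n / denom), with denom = size for the periodic variant and size - 1 otherwise, which matches what torch.hann_window produces:

import math

import torch

size = 8
for periodic in (True, False):
  denom = size if periodic else size - 1
  expected = torch.tensor(
      [0.5 - 0.5 * math.cos(2.0 * math.pi * n / denom) for n in range(size)]
  )
  # The reference op agrees with the closed-form expression used by the lowering.
  assert torch.allclose(torch.hann_window(size, periodic=periodic), expected, atol=1e-6)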
@@ -249,6 +325,85 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
    return stablehlo.concatenate(non_empty_tensors, dim)
 
 
+ # Schema:
+ # - aten::unfold(Tensor self, int dim, int size, int step) -> Tensor
+ # Torch Reference:
+ # - https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
+ @lower(torch.ops.aten.unfold.default)
+ def _aten_unfold(lctx, x: ir.Value, dim: int, size: int, step: int):
+   x_shape = x.type.shape
+   rank = len(x_shape)
+   if dim < 0:
+     dim += rank
+
+   num_windows = (x_shape[dim] - size) // step + 1
+   batch_shape = list(x_shape[:dim]) + [num_windows] + list(x_shape[dim + 1 :])
+
+   # Create start_indices for gather.
+   # The shape of start_indices will be batch_shape + [rank].
+   # start_indices[b_0,...,b_{rank-1}] will be [p_0,...,p_{rank-1}] where
+   # p_j = b_j for j != dim and p_dim = b_dim * step.
+   indices_parts = []
+   i64 = ir.IntegerType.get_signless(64)
+   for i in range(rank):
+     bshape = [1] * rank
+     bshape[i] = batch_shape[i]
+     dim_len = batch_shape[i]
+
+     iota = stablehlo.IotaOp(
+         ir.RankedTensorType.get([dim_len], i64),
+         iota_dimension=ir.IntegerAttr.get(i64, 0),
+     ).result
+     if i == dim:
+       iota = stablehlo.multiply(iota, utils.splat(step, i64, [dim_len]))
+
+     iota_reshaped = stablehlo.reshape(
+         ir.RankedTensorType.get(bshape, i64), iota
+     )
+     indices_parts.append(
+         stablehlo.broadcast_in_dim(
+             ir.RankedTensorType.get(batch_shape, i64),
+             iota_reshaped,
+             ir.DenseI64ArrayAttr.get(list(range(rank))),
+         )
+     )
+
+   # For each dimension i, indices_parts[i] contains the i-th coordinate
+   # of start_indices. We unsqueeze each part to shape batch_shape + [1]
+   # and concatenate along the new dimension to produce start_indices of
+   # shape batch_shape + [rank].
+   unsqueezed_parts = [
+       stablehlo.reshape(ir.RankedTensorType.get(batch_shape + [1], i64), part)
+       for part in indices_parts
+   ]
+   start_indices = stablehlo.concatenate(
+       unsqueezed_parts, ir.IntegerAttr.get(i64, rank)
+   )
+
+   slice_sizes_list = [1] * rank
+   slice_sizes_list[dim] = size
+   slice_sizes = ir.DenseI64ArrayAttr.get(slice_sizes_list)
+
+   collapsed_slice_dims_list = [i for i in range(rank) if i != dim]
+
+   dnums = stablehlo.GatherDimensionNumbers.get(
+       offset_dims=[rank],
+       collapsed_slice_dims=collapsed_slice_dims_list,
+       operand_batching_dims=[],
+       start_indices_batching_dims=[],
+       start_index_map=list(range(rank)),
+       index_vector_dim=rank,
+   )
+
+   return stablehlo.gather(
+       x,
+       start_indices,
+       dnums,
+       slice_sizes,
+       indices_are_sorted=ir.BoolAttr.get(False),
+   )
+
+
  # Schema:
  # - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
  #   start=None, SymInt? end=None, SymInt step=1) -> Tensor
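
For reference, a standalone sketch (not part of the package) of the torch.Tensor.unfold semantics the gather lowering above reproduces: the unfolded dimension is replaced by num_windows = (L - size) // step + 1 windows, and a trailing dimension of length size holds the elements of each window:

import torch

x = torch.arange(10)
size, step = 4, 2
windows = x.unfold(0, size, step)

# num_windows = (10 - 4) // 2 + 1 = 4 windows of `size` elements each.
assert windows.shape == (4, 4)
assert torch.equal(windows[1], torch.tensor([2, 3, 4, 5]))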