ai-edge-torch-nightly 0.7.0.dev20251007__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (40)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
  3. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
  4. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
  5. ai_edge_torch/generative/layers/attention.py +25 -2
  6. ai_edge_torch/generative/layers/attention_test.py +13 -1
  7. ai_edge_torch/generative/layers/attention_utils.py +62 -1
  8. ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
  9. ai_edge_torch/generative/layers/builder.py +4 -2
  10. ai_edge_torch/generative/layers/model_config.py +5 -0
  11. ai_edge_torch/generative/layers/normalization.py +8 -2
  12. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
  13. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
  14. ai_edge_torch/generative/quantize/example.py +1 -1
  15. ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
  16. ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
  17. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
  18. ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
  19. ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
  20. ai_edge_torch/generative/test/test_kv_cache.py +18 -6
  21. ai_edge_torch/generative/test/test_quantize.py +17 -26
  22. ai_edge_torch/generative/utilities/converter.py +97 -22
  23. ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
  24. ai_edge_torch/generative/utilities/loader.py +2 -1
  25. ai_edge_torch/lowertools/translate_recipe.py +8 -3
  26. ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
  27. ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
  28. ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
  29. ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
  30. ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
  31. ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
  32. ai_edge_torch/odml_torch/export.py +24 -7
  33. ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
  34. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
  35. ai_edge_torch/version.py +1 -1
  36. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
  37. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +40 -34
  38. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
  39. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py (new file)
@@ -0,0 +1,728 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Torch-TFL op to MLIR lowerings."""
+
+from typing import Sequence
+
+from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import numpy as np
+import torch
+
+
+lower = registry.lower
+LoweringContext = context.LoweringContext
+
+
+def _ir_operation(
+    name: str,
+    results: Sequence[ir.Type],
+    operands: Sequence[ir.Value] | None = None,
+    attributes: dict[str, ir.Attribute] | None = None,
+):
+  """Helper function to create an IR operation in StableHLO CustomCall carrier."""
+  if not operands:
+    operands = []
+  attributes = ir.DictAttr.get(attributes if attributes else {})
+  return stablehlo.custom_call(
+      result=results,
+      inputs=operands,
+      call_target_name=ir.StringAttr.get(name),
+      has_side_effect=ir.BoolAttr.get(False),
+      backend_config=ir.StringAttr.get(str(attributes)),
+  )
+
+
+@lower(torch.ops.tfl.batch_matmul.default)
+def _tfl_batch_matmul_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    y: ir.Value,
+    adj_x: bool = False,
+    adj_y: bool = False,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.batch_matmul",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, y],
+      attributes={
+          "adj_x": ir.BoolAttr.get(adj_x),
+          "adj_y": ir.BoolAttr.get(adj_y),
+          "asymmetric_quantize_inputs": ir.BoolAttr.get(False),
+      },
+  )
+
+
+@lower(torch.ops.tfl.add.default)
+def _tfl_add_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value | int | float,
+    rhs: ir.Value | int | float,
+    fused_activation_function: str = "NONE",
+) -> ir.Value:
+  lhs = lowering_utils.convert_to_ir_value(lhs)
+  rhs = lowering_utils.convert_to_ir_value(rhs)
+  return _ir_operation(
+      "tfl.add",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+      attributes={
+          "fused_activation_function": ir.StringAttr.get(
+              fused_activation_function
+          ),
+      },
+  )
+
+
+@lower(torch.ops.tfl.sub.default)
+def _tfl_sub_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value | int | float,
+    fused_activation_function: str = "NONE",
+) -> ir.Value:
+  rhs = lowering_utils.convert_to_ir_value(rhs)
+  return _ir_operation(
+      "tfl.sub",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+      attributes={
+          "fused_activation_function": ir.StringAttr.get(
+              fused_activation_function
+          ),
+      },
+  )
+
+
+@lower(torch.ops.tfl.mul.default)
+def _tfl_mul_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value | int | float,
+    fused_activation_function: str = "NONE",
+) -> ir.Value:
+  rhs = lowering_utils.convert_to_ir_value(rhs)
+  return _ir_operation(
+      "tfl.mul",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+      attributes={
+          "fused_activation_function": ir.StringAttr.get(
+              fused_activation_function
+          ),
+      },
+  )
+
+
+@lower(torch.ops.tfl.div.default)
+def _tfl_div_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value | int | float,
+    fused_activation_function: str = "NONE",
+) -> ir.Value:
+  rhs = lowering_utils.convert_to_ir_value(rhs)
+  return _ir_operation(
+      "tfl.div",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+      attributes={
+          "fused_activation_function": ir.StringAttr.get(
+              fused_activation_function
+          ),
+      },
+  )
+
+
+@lower(torch.ops.tfl.pow.default)
+def _tfl_pow_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value | int | float,
+) -> ir.Value:
+  lhs = lowering_utils.convert_to_ir_value(lhs)
+  rhs = lowering_utils.convert_to_ir_value(rhs)
+  return _ir_operation(
+      "tfl.pow",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+  )
+
+
+@lower(torch.ops.tfl.logical_and.default)
+def _tfl_logical_and_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.logical_and",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+  )
+
+
+@lower(torch.ops.tfl.mean.default)
+def _tfl_mean_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    dims: int | ir.Value | Sequence[int | ir.Value],
+    keepdim: bool = False,
+) -> ir.Value:
+  if isinstance(dims, int) or isinstance(dims, ir.Value):
+    dims_ir_value = lowering_utils.convert_to_ir_value(dims)
+  else:
+    dims_ir_value = lowering_utils.convert_shape_to_ir_value(dims)
+  return _ir_operation(
+      "tfl.mean",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, dims_ir_value],
+      attributes={
+          "keep_dims": ir.BoolAttr.get(keepdim),
+      },
+  )
+
+
+@lower(torch.ops.tfl.greater.default)
+def _tfl_greater_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.greater",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+  )
+
+
+@lower(torch.ops.tfl.less.default)
+def _tfl_less_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.less",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+  )
+
+
+@lower(torch.ops.tfl.maximum.default)
+def _tfl_maximum_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.maximum",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+  )
+
+
+@lower(torch.ops.tfl.minimum.default)
+def _tfl_minimum_lowering(
+    lctx: LoweringContext,
+    lhs: ir.Value,
+    rhs: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.minimum",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[lhs, rhs],
+  )
+
+
+@lower(torch.ops.tfl.sin.default)
+def _tfl_sin_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.sin",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+  )
+
+
+@lower(torch.ops.tfl.cos.default)
+def _tfl_cos_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.cos",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+  )
+
+
+@lower(torch.ops.tfl.rsqrt.default)
+def _tfl_rsqrt_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.rsqrt",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+  )
+
+
+@lower(torch.ops.tfl.neg.default)
+def _tfl_neg_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.neg",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+  )
+
+
+@lower(torch.ops.tfl.gelu.default)
+def _tfl_gelu_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    approximate: bool = False,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.gelu",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+      attributes={
+          "approximate": ir.BoolAttr.get(approximate),
+      },
+  )
+
+
+@lower(torch.ops.tfl.transpose.default)
+def _tfl_transpose_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    perm: Sequence[int],
+) -> ir.Value:
+  constant_perm = lowering_utils.numpy_array_constant(
+      np.array(perm, dtype=np.int32)
+  )
+  return _ir_operation(
+      "tfl.transpose",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, constant_perm],
+  )
+
+
+@lower(torch.ops.tfl.concatenation.default)
+def _tfl_concatenation_lowering(
+    lctx: LoweringContext,
+    tensors: Sequence[ir.Value],
+    axis: int,
+    fused_activation_function: str = "NONE",
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.concatenation",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=tensors,
+      attributes={
+          "axis": ir.IntegerAttr.get(ir.IntegerType.get_signless(32), axis),
+          "fused_activation_function": ir.StringAttr.get(
+              fused_activation_function
+          ),
+      },
+  )
+
+
+@lower(torch.ops.tfl.fill.default)
+def _tfl_fill_lowering(
+    lctx: LoweringContext,
+    dims: Sequence[int | ir.Value],
+    fill_value: int | float | ir.Value,
+) -> ir.Value:
+  dims_ir_value = lowering_utils.convert_shape_to_ir_value(dims)
+  fill_value_ir_value = lowering_utils.convert_to_ir_value(fill_value)
+
+  # Ensure fill_value_ir_value is a scalar (0-D tensor) for TFLite Fill op.
+  # The TFLite Fill kernel expects the value to be a 0-D tensor.
+  if isinstance(fill_value_ir_value.type, ir.RankedTensorType):
+    tensor_type = fill_value_ir_value.type
+    # If it's a 1-D tensor with a single element, reshape to 0-D.
+    if list(tensor_type.shape) == [1]:
+      scalar_type = ir.RankedTensorType.get([], tensor_type.element_type)
+      fill_value_ir_value = stablehlo.reshape(scalar_type, fill_value_ir_value)
+
+  # Determine the target element type from the node's output definition.
+  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
+  if not result_types or not isinstance(result_types[0], ir.RankedTensorType):
+    raise ValueError(
+        "tfl.fill: Unable to determine result tensor type or result is not a"
+        " ranked tensor."
+    )
+  target_element_type = result_types[0].element_type
+
+  # Ensure fill_value_ir_value is a RankedTensorType to access its properties.
+  if not isinstance(fill_value_ir_value.type, ir.RankedTensorType):
+    raise TypeError(
+        "tfl.fill: fill_value_ir_value expected to be RankedTensorType, got"
+        f" {fill_value_ir_value.type}"
+    )
+
+  current_fill_tensor_type = fill_value_ir_value.type
+  current_element_type = current_fill_tensor_type.element_type
+
+  # If the element type of the (scalar) fill_value doesn't match the target
+  # output element type, cast fill_value_ir_value to the target_element_type
+  # while maintaining its current shape (which should be scalar).
+  if current_element_type != target_element_type:
+    cast_to_type = ir.RankedTensorType.get(
+        current_fill_tensor_type.shape, target_element_type
+    )
+    fill_value_ir_value = stablehlo.convert(cast_to_type, fill_value_ir_value)
+
+  return _ir_operation(
+      "tfl.fill",
+      results=result_types,
+      operands=[dims_ir_value, fill_value_ir_value],
+  )
+
+
+@lower(torch.ops.tfl.reshape.default)
+def _tfl_reshape_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    shape: Sequence[int | ir.Value],
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.reshape",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, lowering_utils.convert_shape_to_ir_value(shape)],
+  )
+
+
+@lower(torch.ops.tfl.range.default)
+def _tfl_range_lowering(
+    lctx: LoweringContext,
+    start: int | float | ir.Value,
+    limit: int | float | ir.Value,
+    delta: int | float | ir.Value = 1,
+) -> ir.Value:
+  tensor_meta = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
+  output_torch_dtype = tensor_meta.dtype
+
+  original_mlir_output_types = lowering_utils.node_meta_to_ir_types(lctx.node)
+  if not original_mlir_output_types or not isinstance(
+      original_mlir_output_types[0], ir.RankedTensorType
+  ):
+    raise ValueError(
+        "tfl.range output type is not a RankedTensorType as expected."
+    )
+
+  original_mlir_output_type = original_mlir_output_types[0]
+  original_output_shape = original_mlir_output_type.shape
+  original_output_element_type = original_mlir_output_type.element_type
+  tflite_op_internal_element_type = (
+      lowering_utils.torch_dtype_to_ir_element_type(output_torch_dtype)
+  )
+
+  # All operands and the output of tfl.range must have the same element type.
+  # We cast all operands to the expected output element type.
+  operands = []
+  for operand in [start, limit, delta]:
+    if not isinstance(operand, ir.Value):
+      # Convert python scalars to ir.Value.
+      numpy_scalar_0d = (
+          torch.tensor(operand, dtype=output_torch_dtype).detach().numpy()
+      )
+      operand = lowering_utils.numpy_array_constant(numpy_scalar_0d)
+
+    # `operand` is now an ir.Value.
+    # Cast its element type to the output element type if they don't match.
+    operand_type = operand.type
+    if not isinstance(operand_type, ir.RankedTensorType):
+      raise TypeError(
+          "tfl.range operand expected to be RankedTensorType, got"
+          f" {operand_type}"
+      )
+
+    if operand_type.element_type != tflite_op_internal_element_type:
+      cast_to_type = ir.RankedTensorType.get(
+          operand_type.shape, tflite_op_internal_element_type
+      )
+      operand = stablehlo.convert(cast_to_type, operand)
+    operands.append(operand)
+
+  # Define the result type that the tfl.range *kernel* (the custom op) will
+  # produce.
+  tfl_op_kernel_output_type = ir.RankedTensorType.get(
+      original_output_shape, tflite_op_internal_element_type
+  )
+
+  tfl_range_op_val = _ir_operation(
+      "tfl.range",
+      results=[tfl_op_kernel_output_type],
+      operands=operands,
+  )
+
+  # The _tfl_range_lowering function must return a value of the
+  # original_mlir_output_type.
+  # If the tfl.range op's internal element type is different from the
+  # original_output_element_type, we need to convert.
+  if tflite_op_internal_element_type != original_output_element_type:
+    # Convert the tfl.range output to the original expected type.
+    final_output_val = stablehlo.convert(
+        original_mlir_output_type, tfl_range_op_val
+    )
+  else:
+    final_output_val = tfl_range_op_val
+
+  return final_output_val
+
+
+@lower(torch.ops.tfl.split_v.default)
+def _tfl_split_v_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    size_splits: Sequence[int | ir.Value],
+    dim: int | ir.Value,
+) -> ir.Value:
+  size_splits_ir_value = lowering_utils.convert_shape_to_ir_value(size_splits)
+  dim_ir_value = lowering_utils.numpy_array_constant(
+      np.array(dim, dtype=np.int32)
+  )
+  return _ir_operation(
+      "tfl.split_v",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, size_splits_ir_value, dim_ir_value],
+      attributes={
+          "num_splits": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), len(size_splits)
+          ),
+      },
+  )
+
+
+@lower(torch.ops.tfl.slice.default)
+def _tfl_slice_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    begin: Sequence[int | ir.Value],
+    size: Sequence[int | ir.Value],
+) -> ir.Value:
+  begin_ir_value = lowering_utils.convert_shape_to_ir_value(begin)
+  size_ir_value = lowering_utils.convert_shape_to_ir_value(size)
+  return _ir_operation(
+      "tfl.slice",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, begin_ir_value, size_ir_value],
+  )
+
+
+@lower(torch.ops.tfl.expand_dims.default)
+def _tfl_expand_dims_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    dim: int | ir.Value,
+) -> ir.Value:
+  dim_ir_value = lowering_utils.numpy_array_constant(
+      np.array(dim, dtype=np.int32)
+  )
+  return _ir_operation(
+      "tfl.expand_dims",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, dim_ir_value],
+  )
+
+
+@lower(torch.ops.tfl.broadcast_to.default)
+def _tfl_broadcast_to_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    shape: Sequence[int | ir.Value],
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.broadcast_to",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, lowering_utils.convert_shape_to_ir_value(shape)],
+  )
+
+
+@lower(torch.ops.tfl.squeeze.default)
+def _tfl_squeeze_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    squeeze_dims: Sequence[int | ir.Value],
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.squeeze",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+      attributes={
+          "squeeze_dims": ir.ArrayAttr.get([
+              ir.IntegerAttr.get(ir.IntegerType.get_signless(64), int(d))
+              for d in squeeze_dims
+          ]),
+      },
+  )
+
+
+@lower(torch.ops.tfl.strided_slice.default)
+def _tfl_strided_slice_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    begin: Sequence[int | ir.Value],
+    end: Sequence[int | ir.Value],
+    strides: Sequence[int | ir.Value],
+    begin_mask: int = 0,
+    end_mask: int = 0,
+    ellipsis_mask: int = 0,
+    new_axis_mask: int = 0,
+    shrink_axis_mask: int = 0,
+    offset: bool = False,
+) -> ir.Value:
+  begin_ir_value = lowering_utils.convert_shape_to_ir_value(begin)
+  end_ir_value = lowering_utils.convert_shape_to_ir_value(end)
+  strides_ir_value = lowering_utils.convert_shape_to_ir_value(strides)
+  return _ir_operation(
+      "tfl.strided_slice",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, begin_ir_value, end_ir_value, strides_ir_value],
+      attributes={
+          "begin_mask": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), begin_mask
+          ),
+          "end_mask": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), end_mask
+          ),
+          "ellipsis_mask": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), ellipsis_mask
+          ),
+          "new_axis_mask": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), new_axis_mask
+          ),
+          "shrink_axis_mask": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), shrink_axis_mask
+          ),
+          "offset": ir.BoolAttr.get(offset),
+      },
+  )
+
+
+@lower(torch.ops.tfl.select_v2.default)
+def _tfl_select_v2_lowering(
+    lctx: LoweringContext,
+    condition: ir.Value,
+    x: ir.Value,
+    y: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.select_v2",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[condition, x, y],
+  )
+
+
+@lower(torch.ops.tfl.embedding_lookup.default)
+def _tfl_embedding_lookup_lowering(
+    lctx: LoweringContext,
+    indices: ir.Value,
+    weight: ir.Value,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.embedding_lookup",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[indices, weight],
+  )
+
+
+@lower(torch.ops.tfl.gather.default)
+def _tfl_gather_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    indices: ir.Value,
+    axis: int,
+    batch_dims: int = 0,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.gather",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, indices],
+      attributes={
+          "axis": ir.IntegerAttr.get(ir.IntegerType.get_signless(32), axis),
+          "batch_dims": ir.IntegerAttr.get(
+              ir.IntegerType.get_signless(32), batch_dims
+          ),
+      },
+  )
+
+
+@lower(torch.ops.tfl.softmax.default)
+def _tfl_softmax_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    beta: float = 1.0,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.softmax",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+      attributes={
+          "beta": ir.FloatAttr.get(ir.F32Type.get(), beta),
+      },
+  )
+
+
+@lower(torch.ops.tfl.topk_v2.default)
+def _tfl_topk_v2_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    k: int,
+) -> tuple[ir.Value, ir.Value]:
+  return _ir_operation(
+      "tfl.topk_v2",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[
+          x,
+          lowering_utils.numpy_array_constant(np.array(k, dtype=np.int32)),
+      ],
+      attributes={},
+  )
+
+
+@lower(torch.ops.tfl.multinomial.default)
+def _tfl_multinomial_lowering(
+    lctx: LoweringContext,
+    logits: ir.Value,
+    num_samples: int,
+    replacement: bool = False,
+) -> ir.Value:
+  if replacement:
+    raise ValueError("tfl.multinomial does not support with_replacement=True.")
+  return _ir_operation(
+      "tfl.multinomial",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[
+          logits,
+          lowering_utils.numpy_array_constant(
+              np.array(num_samples, dtype=np.int32)
+          ),
+      ],
+      attributes={},
+  )
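
Note: every lowering added in this file follows one registration pattern: a function decorated with @lower(torch.ops.tfl.<op>.default) receives the LoweringContext and the op's operands, then emits a StableHLO custom_call whose call_target_name is the TFLite op name, with result types recovered from the FX node metadata via lowering_utils.node_meta_to_ir_types. As a rough, hypothetical sketch of how a further op could be wired up in the same style (torch.ops.tfl.exp.default is not part of this release; it is assumed here purely for illustration and would also need a matching definition in _ops.py):

@lower(torch.ops.tfl.exp.default)  # hypothetical op, for illustration only
def _tfl_exp_lowering(
    lctx: LoweringContext,
    x: ir.Value,
) -> ir.Value:
  # Wrap the operand in a StableHLO custom_call carrying the "tfl.exp"
  # target name; result types come from the FX node's tensor metadata.
  return _ir_operation(
      "tfl.exp",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[x],
  )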