ai-edge-torch-nightly 0.7.0.dev20250929__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -27
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +183 -28
- ai_edge_torch/generative/utilities/export_config.py +2 -0
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +57 -51
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,728 @@
|
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Torch-TFL op to MLIR lowerings."""
|
|
16
|
+
|
|
17
|
+
from typing import Sequence
|
|
18
|
+
|
|
19
|
+
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
|
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import context
|
|
21
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
|
22
|
+
from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
|
|
23
|
+
from jax._src.lib.mlir import ir
|
|
24
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
|
25
|
+
import numpy as np
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Module-local aliases so the lowerings below can be written as
# `@lower(...)` and annotated with `LoweringContext`.
LoweringContext = context.LoweringContext
lower = registry.lower
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _ir_operation(
    name: str,
    results: Sequence[ir.Type],
    operands: Sequence[ir.Value] | None = None,
    attributes: dict[str, ir.Attribute] | None = None,
):
  """Emits a TFL op wrapped in a StableHLO `custom_call` carrier.

  The TFL op name becomes the custom call's `call_target_name`. The op's
  attributes are packed into an `ir.DictAttr` whose string form is stored in
  `backend_config`.

  Args:
    name: TFL op name (e.g. "tfl.add"), used as the call target name.
    results: Result types of the emitted custom call.
    operands: Input values. `None` or empty means the call has no inputs.
    attributes: Op attributes keyed by name. `None` or empty means none.

  Returns:
    Whatever `stablehlo.custom_call` returns for the given result types.
  """
  inputs = operands or []
  # The attributes travel as the textual form of a DictAttr inside
  # backend_config rather than as real op attributes.
  # NOTE(review): presumably a later pass parses this back into TFL op
  # attributes — confirm.
  config_text = str(ir.DictAttr.get(attributes or {}))
  return stablehlo.custom_call(
      result=results,
      inputs=inputs,
      call_target_name=ir.StringAttr.get(name),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(config_text),
  )
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@lower(torch.ops.tfl.batch_matmul.default)
def _tfl_batch_matmul_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    y: ir.Value,
    adj_x: bool = False,
    adj_y: bool = False,
) -> ir.Value:
  """Emits the `tfl.batch_matmul` carrier op for `x` and `y`.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: First operand.
    y: Second operand.
    adj_x: Forwarded as the boolean `adj_x` attribute.
    adj_y: Forwarded as the boolean `adj_y` attribute.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  op_attributes = {
      "adj_x": ir.BoolAttr.get(adj_x),
      "adj_y": ir.BoolAttr.get(adj_y),
      # Always emitted as False; this lowering never requests it.
      "asymmetric_quantize_inputs": ir.BoolAttr.get(False),
  }
  return _ir_operation(
      "tfl.batch_matmul",
      results=result_types,
      operands=[x, y],
      attributes=op_attributes,
  )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@lower(torch.ops.tfl.add.default)
def _tfl_add_lowering(
    lctx: LoweringContext,
    lhs: ir.Value | int | float,
    rhs: ir.Value | int | float,
    fused_activation_function: str = "NONE",
) -> ir.Value:
  """Emits the `tfl.add` carrier op. Either operand may be a Python scalar.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    lhs: First operand, as an IR value or a Python scalar.
    rhs: Second operand, as an IR value or a Python scalar.
    fused_activation_function: Forwarded as the string attribute of the same
      name.

  Returns:
    The result of the emitted carrier op.
  """
  # Scalars are materialized as IR values; IR values pass through.
  operand_values = [
      lowering_utils.convert_to_ir_value(side) for side in (lhs, rhs)
  ]
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  activation_attr = ir.StringAttr.get(fused_activation_function)
  return _ir_operation(
      "tfl.add",
      results=result_types,
      operands=operand_values,
      attributes={"fused_activation_function": activation_attr},
  )
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@lower(torch.ops.tfl.sub.default)
def _tfl_sub_lowering(
    lctx: LoweringContext,
    lhs: ir.Value | int | float,
    rhs: ir.Value | int | float,
    fused_activation_function: str = "NONE",
) -> ir.Value:
  """Emits the `tfl.sub` carrier op. Either operand may be a Python scalar.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    lhs: First operand, as an IR value or a Python scalar.
    rhs: Second operand, as an IR value or a Python scalar.
    fused_activation_function: Forwarded as the string attribute of the same
      name.

  Returns:
    The result of the emitted carrier op.
  """
  # Convert both sides, as the `tfl.add` and `tfl.pow` lowerings do. IR values
  # pass through `convert_to_ir_value` unchanged. Previously a scalar `lhs`
  # reached `_ir_operation` as a raw Python number.
  lhs = lowering_utils.convert_to_ir_value(lhs)
  rhs = lowering_utils.convert_to_ir_value(rhs)
  return _ir_operation(
      "tfl.sub",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[lhs, rhs],
      attributes={
          "fused_activation_function": ir.StringAttr.get(
              fused_activation_function
          ),
      },
  )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@lower(torch.ops.tfl.mul.default)
def _tfl_mul_lowering(
    lctx: LoweringContext,
    lhs: ir.Value | int | float,
    rhs: ir.Value | int | float,
    fused_activation_function: str = "NONE",
) -> ir.Value:
  """Emits the `tfl.mul` carrier op. Either operand may be a Python scalar.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    lhs: First operand, as an IR value or a Python scalar.
    rhs: Second operand, as an IR value or a Python scalar.
    fused_activation_function: Forwarded as the string attribute of the same
      name.

  Returns:
    The result of the emitted carrier op.
  """
  # Convert both sides, as the `tfl.add` and `tfl.pow` lowerings do. IR values
  # pass through `convert_to_ir_value` unchanged. Previously a scalar `lhs`
  # reached `_ir_operation` as a raw Python number.
  lhs = lowering_utils.convert_to_ir_value(lhs)
  rhs = lowering_utils.convert_to_ir_value(rhs)
  return _ir_operation(
      "tfl.mul",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[lhs, rhs],
      attributes={
          "fused_activation_function": ir.StringAttr.get(
              fused_activation_function
          ),
      },
  )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@lower(torch.ops.tfl.div.default)
def _tfl_div_lowering(
    lctx: LoweringContext,
    lhs: ir.Value | int | float,
    rhs: ir.Value | int | float,
    fused_activation_function: str = "NONE",
) -> ir.Value:
  """Emits the `tfl.div` carrier op. Either operand may be a Python scalar.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    lhs: First operand, as an IR value or a Python scalar.
    rhs: Second operand, as an IR value or a Python scalar.
    fused_activation_function: Forwarded as the string attribute of the same
      name.

  Returns:
    The result of the emitted carrier op.
  """
  # Convert both sides, as the `tfl.add` and `tfl.pow` lowerings do. IR values
  # pass through `convert_to_ir_value` unchanged. Previously a scalar `lhs`
  # reached `_ir_operation` as a raw Python number.
  lhs = lowering_utils.convert_to_ir_value(lhs)
  rhs = lowering_utils.convert_to_ir_value(rhs)
  return _ir_operation(
      "tfl.div",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[lhs, rhs],
      attributes={
          "fused_activation_function": ir.StringAttr.get(
              fused_activation_function
          ),
      },
  )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@lower(torch.ops.tfl.pow.default)
def _tfl_pow_lowering(
    lctx: LoweringContext,
    lhs: ir.Value,
    rhs: ir.Value | int | float,
) -> ir.Value:
  """Emits the `tfl.pow` carrier op, materializing scalar operands as IR.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    lhs: First operand.
    rhs: Second operand, as an IR value or a Python scalar.

  Returns:
    The result of the emitted carrier op.
  """
  to_ir = lowering_utils.convert_to_ir_value
  base, exponent = to_ir(lhs), to_ir(rhs)
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.pow", results=result_types, operands=[base, exponent])
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@lower(torch.ops.tfl.logical_and.default)
def _tfl_logical_and_lowering(
    lctx: LoweringContext,
    lhs: ir.Value,
    rhs: ir.Value,
) -> ir.Value:
  """Emits the `tfl.logical_and` carrier op for `lhs` and `rhs` unchanged."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation(
      "tfl.logical_and", results=result_types, operands=[lhs, rhs]
  )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@lower(torch.ops.tfl.mean.default)
def _tfl_mean_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    dims: int | ir.Value | Sequence[int | ir.Value],
    keepdim: bool = False,
) -> ir.Value:
  """Emits the `tfl.mean` carrier op with the reduction axes as an operand.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    dims: A single axis (int or IR value) or a sequence of axes.
    keepdim: Forwarded as the boolean `keep_dims` attribute.

  Returns:
    The result of the emitted carrier op.
  """
  # A single axis goes through the scalar converter; a sequence goes through
  # the shape converter.
  is_single_axis = isinstance(dims, (int, ir.Value))
  to_ir = (
      lowering_utils.convert_to_ir_value
      if is_single_axis
      else lowering_utils.convert_shape_to_ir_value
  )
  axes_value = to_ir(dims)
  return _ir_operation(
      "tfl.mean",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[x, axes_value],
      attributes={"keep_dims": ir.BoolAttr.get(keepdim)},
  )
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@lower(torch.ops.tfl.greater.default)
def _tfl_greater_lowering(
    lctx: LoweringContext,
    lhs: ir.Value,
    rhs: ir.Value,
) -> ir.Value:
  """Emits the `tfl.greater` carrier op for `lhs` and `rhs` unchanged."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.greater", results=result_types, operands=[lhs, rhs])
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@lower(torch.ops.tfl.less.default)
def _tfl_less_lowering(
    lctx: LoweringContext,
    lhs: ir.Value,
    rhs: ir.Value,
) -> ir.Value:
  """Emits the `tfl.less` carrier op for `lhs` and `rhs` unchanged."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.less", results=result_types, operands=[lhs, rhs])
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@lower(torch.ops.tfl.maximum.default)
def _tfl_maximum_lowering(
    lctx: LoweringContext,
    lhs: ir.Value,
    rhs: ir.Value,
) -> ir.Value:
  """Emits the `tfl.maximum` carrier op for `lhs` and `rhs` unchanged."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.maximum", results=result_types, operands=[lhs, rhs])
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@lower(torch.ops.tfl.minimum.default)
def _tfl_minimum_lowering(
    lctx: LoweringContext,
    lhs: ir.Value,
    rhs: ir.Value,
) -> ir.Value:
  """Emits the `tfl.minimum` carrier op for `lhs` and `rhs` unchanged."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.minimum", results=result_types, operands=[lhs, rhs])
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@lower(torch.ops.tfl.sin.default)
def _tfl_sin_lowering(
    lctx: LoweringContext,
    x: ir.Value,
) -> ir.Value:
  """Emits the unary `tfl.sin` carrier op on `x`."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.sin", results=result_types, operands=[x])
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
@lower(torch.ops.tfl.cos.default)
def _tfl_cos_lowering(
    lctx: LoweringContext,
    x: ir.Value,
) -> ir.Value:
  """Emits the unary `tfl.cos` carrier op on `x`."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.cos", results=result_types, operands=[x])
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@lower(torch.ops.tfl.rsqrt.default)
def _tfl_rsqrt_lowering(
    lctx: LoweringContext,
    x: ir.Value,
) -> ir.Value:
  """Emits the unary `tfl.rsqrt` carrier op on `x`."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.rsqrt", results=result_types, operands=[x])
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@lower(torch.ops.tfl.neg.default)
def _tfl_neg_lowering(
    lctx: LoweringContext,
    x: ir.Value,
) -> ir.Value:
  """Emits the unary `tfl.neg` carrier op on `x`."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation("tfl.neg", results=result_types, operands=[x])
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
@lower(torch.ops.tfl.gelu.default)
def _tfl_gelu_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    approximate: bool = False,
) -> ir.Value:
  """Emits the `tfl.gelu` carrier op on `x`.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    approximate: Forwarded as the boolean `approximate` attribute.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  approximate_attr = ir.BoolAttr.get(approximate)
  return _ir_operation(
      "tfl.gelu",
      results=result_types,
      operands=[x],
      attributes={"approximate": approximate_attr},
  )
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@lower(torch.ops.tfl.transpose.default)
def _tfl_transpose_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    perm: Sequence[int],
) -> ir.Value:
  """Emits the `tfl.transpose` carrier op.

  The permutation is passed as an int32 constant operand, not an attribute.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    perm: Axis permutation as Python ints.

  Returns:
    The result of the emitted carrier op.
  """
  perm_array = np.array(perm, dtype=np.int32)
  perm_operand = lowering_utils.numpy_array_constant(perm_array)
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation(
      "tfl.transpose", results=result_types, operands=[x, perm_operand]
  )
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@lower(torch.ops.tfl.concatenation.default)
def _tfl_concatenation_lowering(
    lctx: LoweringContext,
    tensors: Sequence[ir.Value],
    axis: int,
    fused_activation_function: str = "NONE",
) -> ir.Value:
  """Emits the `tfl.concatenation` carrier op over `tensors`.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    tensors: Values to concatenate, passed as the op's operands.
    axis: Forwarded as a 32-bit signless integer `axis` attribute.
    fused_activation_function: Forwarded as the string attribute of the same
      name.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  i32 = ir.IntegerType.get_signless(32)
  op_attributes = {
      "axis": ir.IntegerAttr.get(i32, axis),
      "fused_activation_function": ir.StringAttr.get(
          fused_activation_function
      ),
  }
  return _ir_operation(
      "tfl.concatenation",
      results=result_types,
      operands=tensors,
      attributes=op_attributes,
  )
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
@lower(torch.ops.tfl.fill.default)
def _tfl_fill_lowering(
    lctx: LoweringContext,
    dims: Sequence[int | ir.Value],
    fill_value: int | float | ir.Value,
) -> ir.Value:
  """Emits the `tfl.fill` carrier op, normalizing the fill value first.

  The TFLite Fill kernel expects its value operand to be a 0-D tensor. A
  one-element 1-D value is therefore reshaped to 0-D, and a value whose element
  type differs from the output's is converted (keeping its shape) before the op
  is emitted.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    dims: Output shape, as ints and/or IR values.
    fill_value: Value to broadcast, as a Python scalar or IR value.

  Returns:
    The result of the emitted carrier op.

  Raises:
    ValueError: If the node has no result type or its first result type is
      not a ranked tensor.
    TypeError: If the fill value does not lower to a ranked tensor.
  """
  shape_operand = lowering_utils.convert_shape_to_ir_value(dims)
  value = lowering_utils.convert_to_ir_value(fill_value)

  # Collapse a tensor of shape [1] to a 0-D scalar for the Fill kernel.
  value_type = value.type
  if (
      isinstance(value_type, ir.RankedTensorType)
      and list(value_type.shape) == [1]
  ):
    scalar_type = ir.RankedTensorType.get([], value_type.element_type)
    value = stablehlo.reshape(scalar_type, value)

  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  if not (result_types and isinstance(result_types[0], ir.RankedTensorType)):
    raise ValueError(
        "tfl.fill: Unable to determine result tensor type or result is not a"
        " ranked tensor."
    )
  output_element_type = result_types[0].element_type

  value_type = value.type
  if not isinstance(value_type, ir.RankedTensorType):
    raise TypeError(
        "tfl.fill: fill_value_ir_value expected to be RankedTensorType, got"
        f" {value_type}"
    )

  # Match the output element type while keeping the value's current shape.
  if value_type.element_type != output_element_type:
    value = stablehlo.convert(
        ir.RankedTensorType.get(value_type.shape, output_element_type), value
    )

  return _ir_operation(
      "tfl.fill",
      results=result_types,
      operands=[shape_operand, value],
  )
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
@lower(torch.ops.tfl.reshape.default)
def _tfl_reshape_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    shape: Sequence[int | ir.Value],
) -> ir.Value:
  """Emits the `tfl.reshape` carrier op with the target shape as an operand.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    shape: Target shape, as ints and/or IR values.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  shape_operand = lowering_utils.convert_shape_to_ir_value(shape)
  return _ir_operation(
      "tfl.reshape", results=result_types, operands=[x, shape_operand]
  )
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
@lower(torch.ops.tfl.range.default)
def _tfl_range_lowering(
    lctx: LoweringContext,
    start: int | float | ir.Value,
    limit: int | float | ir.Value,
    delta: int | float | ir.Value = 1,
) -> ir.Value:
  """Emits the `tfl.range` carrier op, unifying operand element types.

  All operands and the output of `tfl.range` must share one element type.
  Python scalars are materialized as constants of the node's output torch
  dtype. Operands with a different element type are converted. If the op's
  element type differs from the node's declared result element type, the
  result is converted back.

  Args:
    lctx: Lowering context. Its FX node supplies the output dtype (from
      `tensor_meta` or `val` metadata) and the declared result type.
    start: First operand, as a Python scalar or IR value.
    limit: Second operand, as a Python scalar or IR value.
    delta: Third operand, as a Python scalar or IR value. Defaults to 1.

  Returns:
    A value of the node's declared (first) result type.

  Raises:
    ValueError: If the node has neither `tensor_meta` nor `val` metadata, or
      its first result type is missing or not a ranked tensor.
    TypeError: If an operand does not lower to a ranked tensor.
  """
  tensor_meta = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
  # Compare with None explicitly instead of relying on truthiness: the
  # metadata may be tensor-like. Previously a missing entry surfaced as an
  # opaque AttributeError on `None.dtype`.
  if tensor_meta is None:
    raise ValueError(
        "tfl.range: node has neither 'tensor_meta' nor 'val' metadata; unable"
        " to determine the output dtype."
    )
  output_torch_dtype = tensor_meta.dtype

  original_mlir_output_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  if not original_mlir_output_types or not isinstance(
      original_mlir_output_types[0], ir.RankedTensorType
  ):
    raise ValueError(
        "tfl.range output type is not a RankedTensorType as expected."
    )

  original_mlir_output_type = original_mlir_output_types[0]
  original_output_shape = original_mlir_output_type.shape
  original_output_element_type = original_mlir_output_type.element_type
  tflite_op_internal_element_type = (
      lowering_utils.torch_dtype_to_ir_element_type(output_torch_dtype)
  )

  # Cast every operand to the element type the op will produce.
  operands = []
  for operand in [start, limit, delta]:
    if not isinstance(operand, ir.Value):
      # Materialize Python scalars as 0-D constants of the output dtype.
      numpy_scalar_0d = (
          torch.tensor(operand, dtype=output_torch_dtype).detach().numpy()
      )
      operand = lowering_utils.numpy_array_constant(numpy_scalar_0d)

    operand_type = operand.type
    if not isinstance(operand_type, ir.RankedTensorType):
      raise TypeError(
          "tfl.range operand expected to be RankedTensorType, got"
          f" {operand_type}"
      )

    if operand_type.element_type != tflite_op_internal_element_type:
      cast_to_type = ir.RankedTensorType.get(
          operand_type.shape, tflite_op_internal_element_type
      )
      operand = stablehlo.convert(cast_to_type, operand)
    operands.append(operand)

  # Result type of the carrier op itself: the declared shape with the
  # op's internal element type.
  tfl_op_kernel_output_type = ir.RankedTensorType.get(
      original_output_shape, tflite_op_internal_element_type
  )
  tfl_range_op_val = _ir_operation(
      "tfl.range",
      results=[tfl_op_kernel_output_type],
      operands=operands,
  )

  # The lowering must return the node's declared result type; convert if the
  # element types disagree.
  if tflite_op_internal_element_type != original_output_element_type:
    return stablehlo.convert(original_mlir_output_type, tfl_range_op_val)
  return tfl_range_op_val
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
@lower(torch.ops.tfl.split_v.default)
def _tfl_split_v_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    size_splits: Sequence[int | ir.Value],
    dim: int | ir.Value,
) -> ir.Value:
  """Emits the `tfl.split_v` carrier op.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    size_splits: Split sizes, converted with the shape converter. Their
      count is also emitted as the 32-bit `num_splits` attribute.
    dim: Split axis, emitted as an int32 constant operand.

  Returns:
    Whatever `_ir_operation` returns for the node's result types.
  """
  split_sizes_operand = lowering_utils.convert_shape_to_ir_value(size_splits)
  # NOTE(review): `dim` is annotated `int | ir.Value`, but `np.array` would
  # likely fail to wrap an `ir.Value` as int32. Presumably only Python ints
  # reach here — confirm.
  axis_operand = lowering_utils.numpy_array_constant(
      np.array(dim, dtype=np.int32)
  )
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  num_splits_attr = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(32), len(size_splits)
  )
  return _ir_operation(
      "tfl.split_v",
      results=result_types,
      operands=[x, split_sizes_operand, axis_operand],
      attributes={"num_splits": num_splits_attr},
  )
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
@lower(torch.ops.tfl.slice.default)
def _tfl_slice_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    begin: Sequence[int | ir.Value],
    size: Sequence[int | ir.Value],
) -> ir.Value:
  """Emits the `tfl.slice` carrier op with `begin` and `size` as operands.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    begin: Per-axis start indices, as ints and/or IR values.
    size: Per-axis slice sizes, as ints and/or IR values.

  Returns:
    The result of the emitted carrier op.
  """
  to_shape = lowering_utils.convert_shape_to_ir_value
  begin_operand, size_operand = to_shape(begin), to_shape(size)
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation(
      "tfl.slice",
      results=result_types,
      operands=[x, begin_operand, size_operand],
  )
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
@lower(torch.ops.tfl.expand_dims.default)
def _tfl_expand_dims_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    dim: int | ir.Value,
) -> ir.Value:
  """Emits the `tfl.expand_dims` carrier op with `dim` as an int32 constant.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    dim: Axis operand, emitted as an int32 constant.

  Returns:
    The result of the emitted carrier op.
  """
  # NOTE(review): `dim` is annotated `int | ir.Value`, but `np.array` would
  # likely fail to wrap an `ir.Value` as int32. Presumably only Python ints
  # reach here — confirm.
  axis_operand = lowering_utils.numpy_array_constant(
      np.array(dim, dtype=np.int32)
  )
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation(
      "tfl.expand_dims", results=result_types, operands=[x, axis_operand]
  )
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@lower(torch.ops.tfl.broadcast_to.default)
def _tfl_broadcast_to_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    shape: Sequence[int | ir.Value],
) -> ir.Value:
  """Emits the `tfl.broadcast_to` carrier op with `shape` as an operand.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    shape: Target shape, as ints and/or IR values.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  target_shape = lowering_utils.convert_shape_to_ir_value(shape)
  return _ir_operation(
      "tfl.broadcast_to", results=result_types, operands=[x, target_shape]
  )
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
@lower(torch.ops.tfl.squeeze.default)
def _tfl_squeeze_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    squeeze_dims: Sequence[int | ir.Value],
) -> ir.Value:
  """Emits the `tfl.squeeze` carrier op with `squeeze_dims` as an attribute.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    squeeze_dims: Axes, each passed through `int()` and stored as a 64-bit
      signless integer in an array attribute.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  i64 = ir.IntegerType.get_signless(64)
  # NOTE(review): `int(d)` would likely fail for an `ir.Value` entry despite
  # the annotation. Presumably only Python ints reach here — confirm.
  dim_attrs = [ir.IntegerAttr.get(i64, int(d)) for d in squeeze_dims]
  return _ir_operation(
      "tfl.squeeze",
      results=result_types,
      operands=[x],
      attributes={"squeeze_dims": ir.ArrayAttr.get(dim_attrs)},
  )
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
@lower(torch.ops.tfl.strided_slice.default)
def _tfl_strided_slice_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    begin: Sequence[int | ir.Value],
    end: Sequence[int | ir.Value],
    strides: Sequence[int | ir.Value],
    begin_mask: int = 0,
    end_mask: int = 0,
    ellipsis_mask: int = 0,
    new_axis_mask: int = 0,
    shrink_axis_mask: int = 0,
    offset: bool = False,
) -> ir.Value:
  """Emits the `tfl.strided_slice` carrier op.

  `begin`, `end` and `strides` become operands via the shape converter. The
  five masks become 32-bit signless integer attributes, and `offset` becomes a
  boolean attribute.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    begin: Per-axis begin indices.
    end: Per-axis end indices.
    strides: Per-axis strides.
    begin_mask: Forwarded as the `begin_mask` attribute.
    end_mask: Forwarded as the `end_mask` attribute.
    ellipsis_mask: Forwarded as the `ellipsis_mask` attribute.
    new_axis_mask: Forwarded as the `new_axis_mask` attribute.
    shrink_axis_mask: Forwarded as the `shrink_axis_mask` attribute.
    offset: Forwarded as the boolean `offset` attribute.

  Returns:
    The result of the emitted carrier op.
  """
  to_shape = lowering_utils.convert_shape_to_ir_value
  index_operands = [to_shape(seq) for seq in (begin, end, strides)]
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)

  i32 = ir.IntegerType.get_signless(32)
  masks = (
      ("begin_mask", begin_mask),
      ("end_mask", end_mask),
      ("ellipsis_mask", ellipsis_mask),
      ("new_axis_mask", new_axis_mask),
      ("shrink_axis_mask", shrink_axis_mask),
  )
  op_attributes = {name: ir.IntegerAttr.get(i32, bits) for name, bits in masks}
  op_attributes["offset"] = ir.BoolAttr.get(offset)

  return _ir_operation(
      "tfl.strided_slice",
      results=result_types,
      operands=[x, *index_operands],
      attributes=op_attributes,
  )
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
@lower(torch.ops.tfl.select_v2.default)
def _tfl_select_v2_lowering(
    lctx: LoweringContext,
    condition: ir.Value,
    x: ir.Value,
    y: ir.Value,
) -> ir.Value:
  """Emits the `tfl.select_v2` carrier op with operands (condition, x, y)."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation(
      "tfl.select_v2", results=result_types, operands=[condition, x, y]
  )
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
@lower(torch.ops.tfl.embedding_lookup.default)
def _tfl_embedding_lookup_lowering(
    lctx: LoweringContext,
    indices: ir.Value,
    weight: ir.Value,
) -> ir.Value:
  """Emits the `tfl.embedding_lookup` carrier op with (indices, weight)."""
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  return _ir_operation(
      "tfl.embedding_lookup", results=result_types, operands=[indices, weight]
  )
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
@lower(torch.ops.tfl.gather.default)
def _tfl_gather_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    indices: ir.Value,
    axis: int,
    batch_dims: int = 0,
) -> ir.Value:
  """Emits the `tfl.gather` carrier op.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    indices: Index value.
    axis: Forwarded as a 32-bit signless integer `axis` attribute.
    batch_dims: Forwarded as a 32-bit signless integer `batch_dims`
      attribute.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  i32 = ir.IntegerType.get_signless(32)
  return _ir_operation(
      "tfl.gather",
      results=result_types,
      operands=[x, indices],
      attributes={
          "axis": ir.IntegerAttr.get(i32, axis),
          "batch_dims": ir.IntegerAttr.get(i32, batch_dims),
      },
  )
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
@lower(torch.ops.tfl.softmax.default)
def _tfl_softmax_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    beta: float = 1.0,
) -> ir.Value:
  """Emits the `tfl.softmax` carrier op on `x`.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    beta: Forwarded as an f32 `beta` attribute.

  Returns:
    The result of the emitted carrier op.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  beta_attr = ir.FloatAttr.get(ir.F32Type.get(), beta)
  return _ir_operation(
      "tfl.softmax",
      results=result_types,
      operands=[x],
      attributes={"beta": beta_attr},
  )
|
|
690
|
+
|
|
691
|
+
|
|
692
|
+
@lower(torch.ops.tfl.topk_v2.default)
def _tfl_topk_v2_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    k: int,
) -> tuple[ir.Value, ir.Value]:
  """Emits the `tfl.topk_v2` carrier op with `k` as an int32 constant.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    x: Input value.
    k: Second operand, emitted as an int32 constant.

  Returns:
    Whatever `_ir_operation` returns for the node's result types.
  """
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  k_operand = lowering_utils.numpy_array_constant(np.array(k, dtype=np.int32))
  return _ir_operation(
      "tfl.topk_v2",
      results=result_types,
      operands=[x, k_operand],
      attributes={},
  )
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
@lower(torch.ops.tfl.multinomial.default)
def _tfl_multinomial_lowering(
    lctx: LoweringContext,
    logits: ir.Value,
    num_samples: int,
    replacement: bool = False,
) -> ir.Value:
  """Emits the `tfl.multinomial` carrier op.

  Args:
    lctx: Lowering context. Its FX node supplies the result types.
    logits: First operand.
    num_samples: Second operand, emitted as an int32 constant.
    replacement: Must be False. True is rejected.

  Returns:
    The result of the emitted carrier op.

  Raises:
    ValueError: If `replacement` is True.
  """
  if replacement:
    raise ValueError("tfl.multinomial does not support with_replacement=True.")
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  samples_operand = lowering_utils.numpy_array_constant(
      np.array(num_samples, dtype=np.int32)
  )
  return _ir_operation(
      "tfl.multinomial",
      results=result_types,
      operands=[logits, samples_operand],
      attributes={},
  )
|