emx-onnx-cgen 0.3.0-py3-none-any.whl → 0.3.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +50 -23
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +30 -387
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +36 -18
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +1 -1
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +1 -1
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +1 -1
- emx_onnx_cgen/lowering/gather_nd.py +1 -1
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +1 -1
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +1 -1
- emx_onnx_cgen/lowering/identity.py +1 -1
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +1 -1
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +1 -1
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +1 -1
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +1 -1
- emx_onnx_cgen/lowering/one_hot.py +1 -1
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +1 -1
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +1 -1
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +1 -1
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +1 -1
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +1 -1
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +1 -1
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +25 -7
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +1 -1
- emx_onnx_cgen/lowering/unsqueeze.py +1 -1
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/runtime/evaluator.py +325 -1
- emx_onnx_cgen/verification.py +9 -39
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
- emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +11 -0
- shared/ulp.py +17 -0
- emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/codegen/c_emitter.py

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from enum import Enum
 import itertools
 import math
 from math import prod
@@ -20,6 +19,85 @@ from ..ops import (
     binary_op_symbol,
     unary_op_symbol,
 )
+from ..ir.op_base import (
+    BroadcastingOpBase,
+    ConvLikeOpBase,
+    ElementwiseOpBase,
+    GemmLikeOpBase,
+    MatMulLikeOpBase,
+    ReduceOpBase,
+    RenderableOpBase,
+    OpBase,
+    EmitContext,
+)
+from ..ir.op_context import OpContext
+from ..ir.ops import (
+    AdagradOp,
+    ArgReduceOp,
+    AttentionOp,
+    AveragePoolOp,
+    BatchNormOp,
+    BinaryOp,
+    CastOp,
+    ClipOp,
+    ConcatOp,
+    ConstantOfShapeOp,
+    ConvOp,
+    ConvTransposeOp,
+    CumSumOp,
+    DepthToSpaceOp,
+    EinsumKind,
+    EinsumOp,
+    ExpandOp,
+    EyeLikeOp,
+    GatherElementsOp,
+    GatherNDOp,
+    GatherOp,
+    GemmOp,
+    GridSampleOp,
+    GroupNormalizationOp,
+    HardmaxOp,
+    IdentityOp,
+    InstanceNormalizationOp,
+    LayerNormalizationOp,
+    LogSoftmaxOp,
+    LpNormalizationOp,
+    LpPoolOp,
+    LrnOp,
+    LstmOp,
+    MatMulOp,
+    MaxPoolOp,
+    MeanVarianceNormalizationOp,
+    MultiInputBinaryOp,
+    NegativeLogLikelihoodLossOp,
+    NonMaxSuppressionOp,
+    NonZeroOp,
+    OneHotOp,
+    PadOp,
+    QuantizeLinearOp,
+    QLinearMatMulOp,
+    RangeOp,
+    ReduceOp,
+    ReshapeOp,
+    ResizeOp,
+    RMSNormalizationOp,
+    RotaryEmbeddingOp,
+    ScatterNDOp,
+    ShapeOp,
+    SizeOp,
+    SliceOp,
+    SoftmaxCrossEntropyLossOp,
+    SoftmaxOp,
+    SpaceToDepthOp,
+    SplitOp,
+    TensorScatterOp,
+    TileOp,
+    TopKOp,
+    TransposeOp,
+    TriluOp,
+    UnaryOp,
+    WhereOp,
+)
 from shared.scalar_functions import (
     ScalarFunction,
     ScalarFunctionKey,
@@ -150,44 +228,6 @@ _C_KEYWORDS = {
     "while",
 }
 
-@dataclass(frozen=True)
-class BinaryOp:
-    input0: str
-    input1: str
-    output: str
-    function: ScalarFunction
-    operator_kind: OperatorKind
-    input0_shape: tuple[int, ...]
-    input1_shape: tuple[int, ...]
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class MultiInputBinaryOp:
-    inputs: tuple[str, ...]
-    output: str
-    function: ScalarFunction
-    operator_kind: OperatorKind
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class WhereOp:
-    condition: str
-    input_x: str
-    input_y: str
-    output: str
-    condition_shape: tuple[int, ...]
-    x_shape: tuple[int, ...]
-    y_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-
-
 @dataclass(frozen=True)
 class NodeInfo:
     op_type: str
@@ -197,905 +237,6 @@ class NodeInfo:
     attrs: dict[str, object]
 
 
-@dataclass(frozen=True)
-class UnaryOp:
-    input0: str
-    output: str
-    function: ScalarFunction
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-    params: tuple[float, ...] = ()
-
-
-@dataclass(frozen=True)
-class ClipOp:
-    input0: str
-    input_min: str | None
-    input_max: str | None
-    output: str
-    input_shape: tuple[int, ...]
-    min_shape: tuple[int, ...] | None
-    max_shape: tuple[int, ...] | None
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-
-
-
-
-@dataclass(frozen=True)
-class CastOp:
-    input0: str
-    output: str
-    shape: tuple[int, ...]
-    input_dtype: ScalarType
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class MatMulOp:
-    input0: str
-    input1: str
-    output: str
-    input0_shape: tuple[int, ...]
-    input1_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    batch_shape: tuple[int, ...]
-    input0_batch_shape: tuple[int, ...]
-    input1_batch_shape: tuple[int, ...]
-    m: int
-    n: int
-    k: int
-    left_vector: bool
-    right_vector: bool
-    dtype: ScalarType
-
-
-class EinsumKind(str, Enum):
-    REDUCE_ALL = "reduce_all"
-    SUM_J = "sum_j"
-    TRANSPOSE = "transpose"
-    DOT = "dot"
-    BATCH_MATMUL = "batch_matmul"
-    BATCH_DIAGONAL = "batch_diagonal"
-
-
-@dataclass(frozen=True)
-class EinsumOp:
-    inputs: tuple[str, ...]
-    output: str
-    kind: EinsumKind
-    input_shapes: tuple[tuple[int, ...], ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class GemmOp:
-    input_a: str
-    input_b: str
-    input_c: str | None
-    output: str
-    m: int
-    n: int
-    k: int
-    trans_a: bool
-    trans_b: bool
-    alpha: float | int
-    beta: float | int
-    c_shape: tuple[int, ...] | None
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class AttentionOp:
-    input_q: str
-    input_k: str
-    input_v: str
-    input_attn_mask: str | None
-    input_past_key: str | None
-    input_past_value: str | None
-    input_nonpad_kv_seqlen: str | None
-    output: str
-    output_present_key: str | None
-    output_present_value: str | None
-    output_qk_matmul: str | None
-    batch: int
-    q_heads: int
-    kv_heads: int
-    q_seq: int
-    kv_seq: int
-    total_seq: int
-    past_seq: int
-    qk_head_size: int
-    v_head_size: int
-    q_hidden_size: int | None
-    k_hidden_size: int | None
-    v_hidden_size: int | None
-    scale: float
-    is_causal: bool
-    softcap: float
-    qk_matmul_output_mode: int
-    q_rank: int
-    k_rank: int
-    v_rank: int
-    output_rank: int
-    mask_shape: tuple[int, ...] | None
-    mask_is_bool: bool
-    mask_rank: int | None
-    mask_broadcast_batch: bool
-    mask_broadcast_heads: bool
-    mask_broadcast_q_seq: bool
-    mask_q_seq: int | None
-    mask_kv_seq: int | None
-    head_group_size: int
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ConvOp:
-    input0: str
-    weights: str
-    bias: str | None
-    output: str
-    batch: int
-    in_channels: int
-    out_channels: int
-    spatial_rank: int
-    in_spatial: tuple[int, ...]
-    out_spatial: tuple[int, ...]
-    kernel_shape: tuple[int, ...]
-    strides: tuple[int, ...]
-    pads: tuple[int, ...]
-    dilations: tuple[int, ...]
-    group: int
-    dtype: ScalarType
-
-    @property
-    def out_h(self) -> int:
-        if self.spatial_rank < 1:
-            raise ValueError("Conv output height is undefined for spatial_rank < 1")
-        return self.out_spatial[0]
-
-    @property
-    def out_w(self) -> int:
-        if self.spatial_rank < 2:
-            raise ValueError("Conv output width is undefined for spatial_rank < 2")
-        return self.out_spatial[1]
-
-
-@dataclass(frozen=True)
-class ConvTransposeOp:
-    input0: str
-    weights: str
-    bias: str | None
-    output: str
-    batch: int
-    in_channels: int
-    out_channels: int
-    spatial_rank: int
-    in_spatial: tuple[int, ...]
-    out_spatial: tuple[int, ...]
-    kernel_shape: tuple[int, ...]
-    strides: tuple[int, ...]
-    pads: tuple[int, ...]
-    dilations: tuple[int, ...]
-    output_padding: tuple[int, ...]
-    group: int
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class AveragePoolOp:
-    input0: str
-    output: str
-    batch: int
-    channels: int
-    in_h: int
-    in_w: int
-    out_h: int
-    out_w: int
-    kernel_h: int
-    kernel_w: int
-    stride_h: int
-    stride_w: int
-    pad_top: int
-    pad_left: int
-    pad_bottom: int
-    pad_right: int
-    count_include_pad: bool
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class LpPoolOp:
-    input0: str
-    output: str
-    batch: int
-    channels: int
-    in_h: int
-    in_w: int
-    out_h: int
-    out_w: int
-    kernel_h: int
-    kernel_w: int
-    stride_h: int
-    stride_w: int
-    pad_top: int
-    pad_left: int
-    pad_bottom: int
-    pad_right: int
-    p: int
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class QuantizeLinearOp:
-    input0: str
-    scale: str
-    zero_point: str | None
-    output: str
-    input_shape: tuple[int, ...]
-    axis: int | None
-    dtype: ScalarType
-    input_dtype: ScalarType
-    scale_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class SoftmaxOp:
-    input0: str
-    output: str
-    outer: int
-    axis_size: int
-    inner: int
-    axis: int
-    shape: tuple[int, ...]
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class LogSoftmaxOp:
-    input0: str
-    output: str
-    outer: int
-    axis_size: int
-    inner: int
-    axis: int
-    shape: tuple[int, ...]
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class HardmaxOp:
-    input0: str
-    output: str
-    outer: int
-    axis_size: int
-    inner: int
-    axis: int
-    shape: tuple[int, ...]
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class NegativeLogLikelihoodLossOp:
-    input0: str
-    target: str
-    weight: str | None
-    output: str
-    input_shape: tuple[int, ...]
-    target_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    n: int
-    c: int
-    d: int
-    reduction: str
-    ignore_index: int
-    input_dtype: ScalarType
-    weight_dtype: ScalarType | None
-    weight_shape: tuple[int, ...] | None
-    dtype: ScalarType
-    target_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class SoftmaxCrossEntropyLossOp:
-    input0: str
-    target: str
-    weight: str | None
-    output: str
-    log_prob: str | None
-    input_shape: tuple[int, ...]
-    target_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    log_prob_shape: tuple[int, ...] | None
-    n: int
-    c: int
-    d: int
-    reduction: str
-    ignore_index: int | None
-    input_dtype: ScalarType
-    weight_dtype: ScalarType | None
-    weight_shape: tuple[int, ...] | None
-    dtype: ScalarType
-    target_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class BatchNormOp:
-    input0: str
-    scale: str
-    bias: str
-    mean: str
-    variance: str
-    output: str
-    shape: tuple[int, ...]
-    channels: int
-    epsilon: float
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class LpNormalizationOp:
-    input0: str
-    output: str
-    shape: tuple[int, ...]
-    axis: int
-    p: int
-    outer: int
-    axis_size: int
-    inner: int
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class InstanceNormalizationOp:
-    input0: str
-    scale: str
-    bias: str
-    output: str
-    shape: tuple[int, ...]
-    channels: int
-    spatial_size: int
-    epsilon: float
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class GroupNormalizationOp:
-    input0: str
-    scale: str
-    bias: str
-    output: str
-    shape: tuple[int, ...]
-    channels: int
-    num_groups: int
-    group_size: int
-    spatial_size: int
-    epsilon: float
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class LayerNormalizationOp:
-    input0: str
-    scale: str
-    bias: str | None
-    output: str
-    mean_output: str | None
-    invstd_output: str | None
-    shape: tuple[int, ...]
-    normalized_shape: tuple[int, ...]
-    scale_shape: tuple[int, ...]
-    bias_shape: tuple[int, ...] | None
-    outer: int
-    inner: int
-    axis: int
-    epsilon: float
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class MeanVarianceNormalizationOp:
-    input0: str
-    output: str
-    shape: tuple[int, ...]
-    axes: tuple[int, ...]
-    non_axes: tuple[int, ...]
-    reduce_count: int
-    epsilon: float
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class RMSNormalizationOp:
-    input0: str
-    scale: str
-    output: str
-    shape: tuple[int, ...]
-    normalized_shape: tuple[int, ...]
-    scale_shape: tuple[int, ...]
-    outer: int
-    inner: int
-    axis: int
-    epsilon: float
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class LrnOp:
-    input0: str
-    output: str
-    shape: tuple[int, ...]
-    channels: int
-    size: int
-    half: int
-    alpha: float
-    beta: float
-    bias: float
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class LstmOp:
-    input_x: str
-    input_w: str
-    input_r: str
-    input_b: str | None
-    input_sequence_lens: str | None
-    input_initial_h: str | None
-    input_initial_c: str | None
-    input_p: str | None
-    output_y: str | None
-    output_y_h: str | None
-    output_y_c: str | None
-    seq_length: int
-    batch_size: int
-    input_size: int
-    hidden_size: int
-    num_directions: int
-    direction: str
-    layout: int
-    input_forget: int
-    clip: float | None
-    activation_kinds: tuple[int, ...]
-    activation_alphas: tuple[float, ...]
-    activation_betas: tuple[float, ...]
-    dtype: ScalarType
-    sequence_lens_dtype: ScalarType | None
-
-
-@dataclass(frozen=True)
-class MaxPoolOp:
-    input0: str
-    output: str
-    indices: str | None
-    batch: int
-    channels: int
-    spatial_rank: int
-    in_spatial: tuple[int, ...]
-    out_spatial: tuple[int, ...]
-    kernel_shape: tuple[int, ...]
-    strides: tuple[int, ...]
-    pads: tuple[int, ...]
-    dilations: tuple[int, ...]
-    ceil_mode: bool
-    storage_order: int
-    dtype: ScalarType
-    indices_dtype: ScalarType | None
-
-
-@dataclass(frozen=True)
-class ConcatOp:
-    inputs: tuple[str, ...]
-    output: str
-    axis: int
-    input_shapes: tuple[tuple[int, ...], ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class GatherElementsOp:
-    data: str
-    indices: str
-    output: str
-    axis: int
-    data_shape: tuple[int, ...]
-    indices_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    indices_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class GatherOp:
-    data: str
-    indices: str
-    output: str
-    axis: int
-    data_shape: tuple[int, ...]
-    indices_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    indices_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class GatherNDOp:
-    data: str
-    indices: str
-    output: str
-    batch_dims: int
-    data_shape: tuple[int, ...]
-    indices_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    indices_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ScatterNDOp:
-    data: str
-    indices: str
-    updates: str
-    output: str
-    data_shape: tuple[int, ...]
-    indices_shape: tuple[int, ...]
-    updates_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    reduction: str
-    dtype: ScalarType
-    indices_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class TransposeOp:
-    input0: str
-    output: str
-    perm: tuple[int, ...]
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ReshapeOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class IdentityOp:
-    input0: str
-    output: str
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class EyeLikeOp:
-    input0: str
-    output: str
-    output_shape: tuple[int, ...]
-    k: int
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class TriluOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    upper: bool
-    k_value: int
-    k_input: str | None
-    k_input_shape: tuple[int, ...] | None
-    k_input_dtype: ScalarType | None
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class TileOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    repeats: tuple[int, ...]
-    input_strides: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class PadOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    pads_begin: tuple[int, ...] | None
-    pads_end: tuple[int, ...] | None
-    pads_input: str | None
-    pads_shape: tuple[int, ...] | None
-    pads_dtype: ScalarType | None
-    pads_axis_map: tuple[int | None, ...] | None
-    pads_values: tuple[int, ...] | None
-    axes_input: str | None
-    axes_shape: tuple[int, ...] | None
-    axes_dtype: ScalarType | None
-    mode: str
-    value: float | int | bool
-    value_input: str | None
-    value_shape: tuple[int, ...] | None
-    dtype: ScalarType
-    input_dtype: ScalarType
-    input_strides: tuple[int, ...]
-
-
-@dataclass(frozen=True)
-class DepthToSpaceOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    blocksize: int
-    mode: str
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class SpaceToDepthOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    blocksize: int
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class SliceOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    starts: tuple[int, ...] | None
-    steps: tuple[int, ...] | None
-    axes: tuple[int, ...] | None
-    starts_input: str | None
-    ends_input: str | None
-    axes_input: str | None
-    steps_input: str | None
-    starts_shape: tuple[int, ...] | None
-    ends_shape: tuple[int, ...] | None
-    axes_shape: tuple[int, ...] | None
-    steps_shape: tuple[int, ...] | None
-    starts_dtype: ScalarType | None
-    ends_dtype: ScalarType | None
-    axes_dtype: ScalarType | None
-    steps_dtype: ScalarType | None
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ResizeOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    scales: tuple[float, ...]
-    scales_input: str | None
-    sizes_input: str | None
-    roi_input: str | None
-    axes: tuple[int, ...]
-    scales_shape: tuple[int, ...] | None
-    sizes_shape: tuple[int, ...] | None
-    roi_shape: tuple[int, ...] | None
-    scales_dtype: ScalarType | None
-    sizes_dtype: ScalarType | None
-    roi_dtype: ScalarType | None
-    scales_axes: tuple[int, ...] | None
-    sizes_axes: tuple[int, ...] | None
-    roi_axes: tuple[int, ...] | None
-    mode: str
-    coordinate_transformation_mode: str
-    nearest_mode: str
-    cubic_coeff_a: float
-    exclude_outside: bool
-    extrapolation_value: float
-    antialias: bool
-    keep_aspect_ratio_policy: str
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class GridSampleOp:
-    input0: str
-    grid: str
-    output: str
-    input_shape: tuple[int, ...]
-    grid_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    spatial_rank: int
-    input_spatial: tuple[int, ...]
-    output_spatial: tuple[int, ...]
-    mode: str
-    padding_mode: str
-    align_corners: bool
-    dtype: ScalarType
-    grid_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ReduceOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    axes: tuple[int, ...]
-    axes_input: str | None
-    axes_input_shape: tuple[int, ...] | None
-    axes_input_dtype: ScalarType | None
-    keepdims: bool
-    noop_with_empty_axes: bool
-    reduce_kind: str
-    reduce_count: int | None
-    dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ArgReduceOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    axis: int
-    keepdims: bool
-    select_last_index: bool
-    reduce_kind: str
-    input_dtype: ScalarType
-    output_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class TopKOp:
-    input0: str
-    output_values: str
-    output_indices: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    axis: int
-    k: int
-    largest: bool
-    sorted: bool
-    input_dtype: ScalarType
-    output_values_dtype: ScalarType
-    output_indices_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ConstantOfShapeOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    shape: tuple[int, ...]
-    value: float | int | bool
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ShapeOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    values: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class SizeOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    value: int
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class NonZeroOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class ExpandOp:
-    input0: str
-    output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    input_shape_padded: tuple[int, ...]
-    input_strides: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class CumSumOp:
-    input0: str
-    axis_input: str | None
-    axis_input_dtype: ScalarType | None
-    axis: int | None
-    output: str
-    input_shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-    exclusive: bool
-    reverse: bool
-
-
-@dataclass(frozen=True)
-class RangeOp:
-    start: str
-    limit: str
-    delta: str
-    output: str
-    output_shape: tuple[int, ...]
-    length: int
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class OneHotOp:
-    indices: str
-    depth: str
-    values: str
-    output: str
-    axis: int
-    indices_shape: tuple[int, ...]
-    values_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    depth_dim: int
-    dtype: ScalarType
-    indices_dtype: ScalarType
-    depth_dtype: ScalarType
-
-
-@dataclass(frozen=True)
-class SplitOp:
-    input0: str
-    outputs: tuple[str, ...]
-    input_shape: tuple[int, ...]
-    output_shapes: tuple[tuple[int, ...], ...]
-    axis: int
-    split_sizes: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
-
-
 @dataclass(frozen=True)
 class ConstTensor:
     name: str
@@ -1135,78 +276,29 @@ class ModelHeader:
 
 @dataclass(frozen=True)
 class LoweredModel:
-    name: str
-    input_names: tuple[str, ...]
-    input_shapes: tuple[tuple[int, ...], ...]
-    input_dtypes: tuple[ScalarType, ...]
-    output_names: tuple[str, ...]
-    output_shapes: tuple[tuple[int, ...], ...]
-    output_dtypes: tuple[ScalarType, ...]
-    constants: tuple[ConstTensor, ...]
-    ops: tuple[
-        BinaryOp
-        | MultiInputBinaryOp
-        | WhereOp
-        | UnaryOp
-        | ClipOp
-        | CastOp
-        | QuantizeLinearOp
-        | MatMulOp
-        | EinsumOp
-        | GemmOp
-        | AttentionOp
-        | ConvOp
-        | ConvTransposeOp
-        | AveragePoolOp
-        | LpPoolOp
-        | BatchNormOp
-        | LpNormalizationOp
-        | InstanceNormalizationOp
-        | GroupNormalizationOp
-        | LayerNormalizationOp
-        | MeanVarianceNormalizationOp
-        | RMSNormalizationOp
-        | LrnOp
-        | LstmOp
-        | SoftmaxOp
-        | LogSoftmaxOp
-        | HardmaxOp
-        | NegativeLogLikelihoodLossOp
-        | SoftmaxCrossEntropyLossOp
-        | MaxPoolOp
-        | ConcatOp
-        | GatherElementsOp
-        | GatherOp
-        | GatherNDOp
-        | ScatterNDOp
-        | TransposeOp
-        | ReshapeOp
-        | IdentityOp
-        | EyeLikeOp
-        | TriluOp
-        | TileOp
-        | PadOp
-        | DepthToSpaceOp
-        | SpaceToDepthOp
-        | SliceOp
-        | ResizeOp
-        | GridSampleOp
-        | ReduceOp
-        | ArgReduceOp
-        | TopKOp
-        | ConstantOfShapeOp
-        | ShapeOp
-        | SizeOp
-        | NonZeroOp
-        | ExpandOp
-        | CumSumOp
-        | RangeOp
-        | OneHotOp
-        | SplitOp,
-        ...,
-    ]
+    name: str
+    input_names: tuple[str, ...]
+    input_shapes: tuple[tuple[int, ...], ...]
+    input_dtypes: tuple[ScalarType, ...]
+    output_names: tuple[str, ...]
+    output_shapes: tuple[tuple[int, ...], ...]
+    output_dtypes: tuple[ScalarType, ...]
+    constants: tuple[ConstTensor, ...]
+    ops: tuple[OpBase, ...]
     node_infos: tuple[NodeInfo, ...]
     header: ModelHeader
+    op_context: OpContext
+
+
+@dataclass
+class _EmitState:
+    model: LoweredModel
+    templates: dict[str, Template]
+    scalar_registry: ScalarFunctionRegistry
+    dim_args: str
+    tensor_dim_names: Mapping[str, Mapping[int, str]]
+    op_context: OpContext
+    value_name_map: Mapping[str, str]
 
 
 class CEmitter:
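
The hunk above is the core of the 0.3.0 → 0.3.2 refactor: LoweredModel.ops collapses from a roughly sixty-member union type into tuple[OpBase, ...], and per-value metadata moves into a shared OpContext carried on the model. Below is a minimal sketch of a context exposing the interface this diff relies on (shape, dtype, require_derived, copy_derived); the storage layout and keying are assumptions for illustration, not the actual emx_onnx_cgen/ir/op_context.py implementation:

    # Sketch only; the real OpContext lives in emx_onnx_cgen/ir/op_context.py.
    from dataclasses import dataclass, field

    @dataclass
    class OpContextSketch:
        # value name -> static shape / scalar dtype, recorded during lowering
        shapes: dict[str, tuple[int, ...]] = field(default_factory=dict)
        dtypes: dict[str, object] = field(default_factory=dict)
        # op-specific derived facts, keyed per op instance
        derived: dict[tuple[int, str], object] = field(default_factory=dict)

        def shape(self, name: str) -> tuple[int, ...]:
            return self.shapes[name]

        def dtype(self, name: str) -> object:
            return self.dtypes[name]

        def require_derived(self, op: object, key: str) -> object:
            return self.derived[(id(op), key)]

        def copy_derived(self, source_op: object, target_op: object) -> None:
            # Sanitization rebuilds each frozen op as a new object, so derived
            # facts must be re-keyed onto the replacement instance.
            for (op_id, key), value in list(self.derived.items()):
                if op_id == id(source_op):
                    self.derived[(id(target_op), key)] = value
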
@@ -1235,6 +327,7 @@ class CEmitter:
         if large_weight_threshold < 0:
             raise CodegenError("large_weight_threshold must be >= 0")
         self._large_weight_threshold = large_weight_threshold
+        self._emit_state: _EmitState | None = None
 
     @staticmethod
     def _sanitize_identifier(name: str) -> str:
@@ -1297,6 +390,26 @@ class CEmitter:
             mapped[key] = unique
         return mapped
 
+    def _ctx_name(self, name: str) -> str:
+        if self._emit_state is None:
+            raise CodegenError("Emitter state not initialized")
+        return self._emit_state.value_name_map.get(name, name)
+
+    def _ctx_shape(self, name: str) -> tuple[int, ...]:
+        if self._emit_state is None:
+            raise CodegenError("Emitter state not initialized")
+        return self._emit_state.op_context.shape(self._ctx_name(name))
+
+    def _ctx_dtype(self, name: str) -> ScalarType:
+        if self._emit_state is None:
+            raise CodegenError("Emitter state not initialized")
+        return self._emit_state.op_context.dtype(self._ctx_name(name))
+
+    def _derived(self, op: OpBase, key: str) -> object:
+        if self._emit_state is None:
+            raise CodegenError("Emitter state not initialized")
+        return self._emit_state.op_context.require_derived(op, key)
+
     @staticmethod
     def _build_param_decls(
         specs: Sequence[tuple[str | None, str, str, bool]]
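
These _ctx_* helpers are only usable mid-emit: each raises until _emit_state is populated, then resolves an emitter-local (sanitized or temporary) name back to the original graph name before querying the OpContext. Hypothetical usage from inside an op renderer (the tensor name and derived key are invented for illustration):

    # Assumes emission is underway, i.e. self._emit_state has been set.
    shape = self._ctx_shape("input_0")     # static shape, e.g. (1, 3, 224, 224)
    dtype = self._ctx_dtype("input_0")     # the tensor's ScalarType
    plan = self._derived(op, "tile_plan")  # "tile_plan" is a hypothetical key
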
@@ -1334,10 +447,12 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
+        | RotaryEmbeddingOp
         | ConvOp
         | AveragePoolOp
         | BatchNormOp
@@ -1349,6 +464,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -1360,6 +476,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -1379,6 +496,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -1409,6 +527,18 @@ class CEmitter:
                 names.append(op.zero_point)
             names.append(op.output)
             return tuple(names)
+        if isinstance(op, QLinearMatMulOp):
+            return (
+                op.input0,
+                op.input0_scale,
+                op.input0_zero_point,
+                op.input1,
+                op.input1_scale,
+                op.input1_zero_point,
+                op.output_scale,
+                op.output_zero_point,
+                op.output,
+            )
         if isinstance(op, MatMulOp):
             return (op.input0, op.input1, op.output)
         if isinstance(op, EinsumOp):
@@ -1437,6 +567,12 @@ class CEmitter:
             if op.output_qk_matmul is not None:
                 names.append(op.output_qk_matmul)
             return tuple(names)
+        if isinstance(op, RotaryEmbeddingOp):
+            names = [op.input0, op.cos_cache, op.sin_cache]
+            if op.position_ids is not None:
+                names.append(op.position_ids)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, ConvOp):
             names = [op.input0, op.weights]
             if op.bias is not None:
@@ -1494,6 +630,16 @@ class CEmitter:
             if op.output_y_c is not None:
                 names.append(op.output_y_c)
             return tuple(names)
+        if isinstance(op, AdagradOp):
+            return (
+                op.rate,
+                op.timestep,
+                *op.inputs,
+                *op.gradients,
+                *op.accumulators,
+                *op.outputs,
+                *op.accumulator_outputs,
+            )
         if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
             return (op.input0, op.output)
         if isinstance(op, NegativeLogLikelihoodLossOp):
@@ -1523,6 +669,12 @@ class CEmitter:
             return (op.data, op.indices, op.output)
         if isinstance(op, ScatterNDOp):
             return (op.data, op.indices, op.updates, op.output)
+        if isinstance(op, TensorScatterOp):
+            names = [op.past_cache, op.update]
+            if op.write_indices is not None:
+                names.append(op.write_indices)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, ConcatOp):
             return (*op.inputs, op.output)
         if isinstance(op, ConstantOfShapeOp):
@@ -1533,6 +685,16 @@ class CEmitter:
             return (op.input0, op.output)
         if isinstance(op, NonZeroOp):
             return (op.input0, op.output)
+        if isinstance(op, NonMaxSuppressionOp):
+            names = [op.boxes, op.scores]
+            if op.max_output_boxes_per_class is not None:
+                names.append(op.max_output_boxes_per_class)
+            if op.iou_threshold is not None:
+                names.append(op.iou_threshold)
+            if op.score_threshold is not None:
+                names.append(op.score_threshold)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, ExpandOp):
             return (op.input0, op.output)
         if isinstance(op, CumSumOp):
@@ -1653,10 +815,12 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
+        | RotaryEmbeddingOp
         | ConvOp
         | ConvTransposeOp
         | AveragePoolOp
@@ -1670,6 +834,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -1681,6 +846,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -1700,6 +866,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -1714,10 +881,12 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
+        | RotaryEmbeddingOp
         | ConvOp
         | ConvTransposeOp
         | AveragePoolOp
@@ -1731,6 +900,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -1742,6 +912,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -1761,6 +932,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -1844,6 +1016,47 @@ class CEmitter:
                 input_dtype=op.input_dtype,
                 scale_dtype=op.scale_dtype,
             )
+        if isinstance(op, QLinearMatMulOp):
+            return QLinearMatMulOp(
+                input0=name_map.get(op.input0, op.input0),
+                input0_scale=name_map.get(op.input0_scale, op.input0_scale),
+                input0_zero_point=name_map.get(
+                    op.input0_zero_point, op.input0_zero_point
+                ),
+                input1=name_map.get(op.input1, op.input1),
+                input1_scale=name_map.get(op.input1_scale, op.input1_scale),
+                input1_zero_point=name_map.get(
+                    op.input1_zero_point, op.input1_zero_point
+                ),
+                output_scale=name_map.get(op.output_scale, op.output_scale),
+                output_zero_point=name_map.get(
+                    op.output_zero_point, op.output_zero_point
+                ),
+                output=name_map.get(op.output, op.output),
+                input0_shape=op.input0_shape,
+                input1_shape=op.input1_shape,
+                output_shape=op.output_shape,
+                batch_shape=op.batch_shape,
+                input0_batch_shape=op.input0_batch_shape,
+                input1_batch_shape=op.input1_batch_shape,
+                m=op.m,
+                n=op.n,
+                k=op.k,
+                left_vector=op.left_vector,
+                right_vector=op.right_vector,
+                input0_dtype=op.input0_dtype,
+                input1_dtype=op.input1_dtype,
+                dtype=op.dtype,
+                input0_scale_dtype=op.input0_scale_dtype,
+                input1_scale_dtype=op.input1_scale_dtype,
+                output_scale_dtype=op.output_scale_dtype,
+                input0_scale_shape=op.input0_scale_shape,
+                input1_scale_shape=op.input1_scale_shape,
+                output_scale_shape=op.output_scale_shape,
+                input0_zero_shape=op.input0_zero_shape,
+                input1_zero_shape=op.input1_zero_shape,
+                output_zero_shape=op.output_zero_shape,
+            )
         if isinstance(op, MatMulOp):
             return MatMulOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1946,6 +1159,30 @@ class CEmitter:
                 head_group_size=op.head_group_size,
                 dtype=op.dtype,
             )
+        if isinstance(op, RotaryEmbeddingOp):
+            return RotaryEmbeddingOp(
+                input0=name_map.get(op.input0, op.input0),
+                cos_cache=name_map.get(op.cos_cache, op.cos_cache),
+                sin_cache=name_map.get(op.sin_cache, op.sin_cache),
+                position_ids=self._map_optional_name(
+                    name_map, op.position_ids
+                ),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                cos_shape=op.cos_shape,
+                sin_shape=op.sin_shape,
+                position_ids_shape=op.position_ids_shape,
+                dtype=op.dtype,
+                position_ids_dtype=op.position_ids_dtype,
+                rotary_dim=op.rotary_dim,
+                rotary_dim_half=op.rotary_dim_half,
+                head_size=op.head_size,
+                num_heads=op.num_heads,
+                seq_len=op.seq_len,
+                batch=op.batch,
+                input_rank=op.input_rank,
+                interleaved=op.interleaved,
+            )
         if isinstance(op, ConvOp):
             return ConvOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2168,6 +1405,33 @@ class CEmitter:
                 dtype=op.dtype,
                 sequence_lens_dtype=op.sequence_lens_dtype,
             )
+        if isinstance(op, AdagradOp):
+            return AdagradOp(
+                rate=name_map.get(op.rate, op.rate),
+                timestep=name_map.get(op.timestep, op.timestep),
+                inputs=tuple(name_map.get(name, name) for name in op.inputs),
+                gradients=tuple(
+                    name_map.get(name, name) for name in op.gradients
+                ),
+                accumulators=tuple(
+                    name_map.get(name, name) for name in op.accumulators
+                ),
+                outputs=tuple(name_map.get(name, name) for name in op.outputs),
+                accumulator_outputs=tuple(
+                    name_map.get(name, name)
+                    for name in op.accumulator_outputs
+                ),
+                rate_shape=op.rate_shape,
+                timestep_shape=op.timestep_shape,
+                tensor_shapes=op.tensor_shapes,
+                output_shapes=op.output_shapes,
+                dtype=op.dtype,
+                rate_dtype=op.rate_dtype,
+                timestep_dtype=op.timestep_dtype,
+                norm_coefficient=op.norm_coefficient,
+                epsilon=op.epsilon,
+                decay_factor=op.decay_factor,
+            )
         if isinstance(op, SoftmaxOp):
             return SoftmaxOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2323,6 +1587,25 @@ class CEmitter:
                 dtype=op.dtype,
                 indices_dtype=op.indices_dtype,
             )
+        if isinstance(op, TensorScatterOp):
+            return TensorScatterOp(
+                past_cache=name_map.get(op.past_cache, op.past_cache),
+                update=name_map.get(op.update, op.update),
+                write_indices=(
+                    name_map.get(op.write_indices, op.write_indices)
+                    if op.write_indices is not None
+                    else None
+                ),
+                output=name_map.get(op.output, op.output),
+                past_cache_shape=op.past_cache_shape,
+                update_shape=op.update_shape,
+                output_shape=op.output_shape,
+                write_indices_shape=op.write_indices_shape,
+                axis=op.axis,
+                mode=op.mode,
+                dtype=op.dtype,
+                write_indices_dtype=op.write_indices_dtype,
+            )
         if isinstance(op, TransposeOp):
             return TransposeOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2583,6 +1866,33 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, NonMaxSuppressionOp):
+            return NonMaxSuppressionOp(
+                boxes=name_map.get(op.boxes, op.boxes),
+                scores=name_map.get(op.scores, op.scores),
+                max_output_boxes_per_class=self._map_optional_name(
+                    name_map, op.max_output_boxes_per_class
+                ),
+                iou_threshold=self._map_optional_name(
+                    name_map, op.iou_threshold
+                ),
+                score_threshold=self._map_optional_name(
+                    name_map, op.score_threshold
+                ),
+                output=name_map.get(op.output, op.output),
+                boxes_shape=op.boxes_shape,
+                scores_shape=op.scores_shape,
+                output_shape=op.output_shape,
+                center_point_box=op.center_point_box,
+                boxes_dtype=op.boxes_dtype,
+                output_dtype=op.output_dtype,
+                max_output_dtype=op.max_output_dtype,
+                max_output_shape=op.max_output_shape,
+                iou_threshold_dtype=op.iou_threshold_dtype,
+                iou_threshold_shape=op.iou_threshold_shape,
+                score_threshold_dtype=op.score_threshold_dtype,
+                score_threshold_shape=op.score_threshold_shape,
+            )
         if isinstance(op, ExpandOp):
             return ExpandOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2684,12 +1994,34 @@ class CEmitter:
             ops=ops,
             node_infos=model.node_infos,
             header=model.header,
+            op_context=model.op_context,
         )
         return sanitized, name_map
 
     def _sanitize_model_names(self, model: LoweredModel) -> LoweredModel:
         return self._sanitize_model_names_with_map(model)[0]
 
+    @staticmethod
+    def _copy_derived(
+        op_context: OpContext,
+        source_ops: Sequence[OpBase],
+        target_ops: Sequence[OpBase],
+    ) -> None:
+        for source_op, target_op in zip(source_ops, target_ops):
+            op_context.copy_derived(source_op, target_op)
+
+    @staticmethod
+    def _build_value_name_map(
+        name_map: Mapping[str, str],
+        temp_name_map: Mapping[str, str],
+    ) -> dict[str, str]:
+        reverse_name_map = {sanitized: original for original, sanitized in name_map.items()}
+        value_name_map = dict(reverse_name_map)
+        for sanitized_name, temp_name in temp_name_map.items():
+            original_name = reverse_name_map.get(sanitized_name, sanitized_name)
+            value_name_map[temp_name] = original_name
+        return value_name_map
+
     @staticmethod
     def _sanitize_testbench_inputs(
         testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None,
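
_build_value_name_map composes the emitter's two renaming passes so that any emitter-side name resolves back to the original graph name. A worked example (all names invented):

    name_map      = {"layer1/add:0": "layer1_add_0"}  # original -> sanitized
    temp_name_map = {"layer1_add_0": "tmp_7"}         # sanitized -> temp buffer
    # reverse_name_map == {"layer1_add_0": "layer1/add:0"}
    # result == {"layer1_add_0": "layer1/add:0", "tmp_7": "layer1/add:0"}

This is the mapping _ctx_name consults, which is why _ctx_shape and _ctx_dtype accept either a sanitized or a temporary name.
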
@@ -2716,10 +2048,16 @@ class CEmitter:
             "quantize_linear": self._env.get_template(
                 "quantize_linear_op.c.j2"
             ),
+            "qlinear_matmul": self._env.get_template(
+                "qlinear_matmul_op.c.j2"
+            ),
             "matmul": self._env.get_template("matmul_op.c.j2"),
             "einsum": self._env.get_template("einsum_op.c.j2"),
             "gemm": self._env.get_template("gemm_op.c.j2"),
             "attention": self._env.get_template("attention_op.c.j2"),
+            "rotary_embedding": self._env.get_template(
+                "rotary_embedding_op.c.j2"
+            ),
             "conv": self._env.get_template("conv_op.c.j2"),
             "conv_transpose": self._env.get_template(
                 "conv_transpose_op.c.j2"
@@ -2743,6 +2081,7 @@ class CEmitter:
             "rms_norm": self._env.get_template("rms_normalization_op.c.j2"),
             "lrn": self._env.get_template("lrn_op.c.j2"),
             "lstm": self._env.get_template("lstm_op.c.j2"),
+            "adagrad": self._env.get_template("adagrad_op.c.j2"),
             "softmax": self._env.get_template("softmax_op.c.j2"),
             "logsoftmax": self._env.get_template("logsoftmax_op.c.j2"),
             "hardmax": self._env.get_template("hardmax_op.c.j2"),
@@ -2758,6 +2097,9 @@ class CEmitter:
|
|
|
2758
2097
|
"gather": self._env.get_template("gather_op.c.j2"),
|
|
2759
2098
|
"gather_nd": self._env.get_template("gather_nd_op.c.j2"),
|
|
2760
2099
|
"scatter_nd": self._env.get_template("scatter_nd_op.c.j2"),
|
|
2100
|
+
"tensor_scatter": self._env.get_template(
|
|
2101
|
+
"tensor_scatter_op.c.j2"
|
|
2102
|
+
),
|
|
2761
2103
|
"transpose": self._env.get_template("transpose_op.c.j2"),
|
|
2762
2104
|
"reshape": self._env.get_template("reshape_op.c.j2"),
|
|
2763
2105
|
"identity": self._env.get_template("identity_op.c.j2"),
|
|
@@ -2785,6 +2127,9 @@ class CEmitter:
|
|
|
2785
2127
|
"shape": self._env.get_template("shape_op.c.j2"),
|
|
2786
2128
|
"size": self._env.get_template("size_op.c.j2"),
|
|
2787
2129
|
"nonzero": self._env.get_template("nonzero_op.c.j2"),
|
|
2130
|
+
"nonmax_suppression": self._env.get_template(
|
|
2131
|
+
"nonmax_suppression_op.c.j2"
|
|
2132
|
+
),
|
|
2788
2133
|
"expand": self._env.get_template("expand_op.c.j2"),
|
|
2789
2134
|
"cumsum": self._env.get_template("cumsum_op.c.j2"),
|
|
2790
2135
|
"range": self._env.get_template("range_op.c.j2"),
|
|
@@ -2806,7 +2151,9 @@ class CEmitter:
|
|
|
2806
2151
|
variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
|
|
2807
2152
|
variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
|
|
2808
2153
|
) -> str:
|
|
2154
|
+
original_model = model
|
|
2809
2155
|
model, name_map = self._sanitize_model_names_with_map(model)
|
|
2156
|
+
self._copy_derived(model.op_context, original_model.ops, model.ops)
|
|
2810
2157
|
testbench_inputs = self._sanitize_testbench_inputs(
|
|
2811
2158
|
testbench_inputs, name_map
|
|
2812
2159
|
)
|
|
@@ -2832,68 +2179,17 @@ class CEmitter:
         self._env.globals["dim_args"] = dim_args
         templates = self._load_templates(emit_testbench)
         scalar_registry = ScalarFunctionRegistry()
-        binary_template = templates["binary"]
-        multi_input_template = templates["multi_input"]
-        where_template = templates["where"]
-        unary_template = templates["unary"]
-        clip_template = templates["clip"]
-        cast_template = templates["cast"]
-        quantize_linear_template = templates["quantize_linear"]
-        matmul_template = templates["matmul"]
-        einsum_template = templates["einsum"]
-        gemm_template = templates["gemm"]
-        attention_template = templates["attention"]
-        conv_template = templates["conv"]
-        conv_transpose_template = templates["conv_transpose"]
-        avg_pool_template = templates["avg_pool"]
-        lp_pool_template = templates["lp_pool"]
-        batch_norm_template = templates["batch_norm"]
-        lp_norm_template = templates["lp_norm"]
-        instance_norm_template = templates["instance_norm"]
-        group_norm_template = templates["group_norm"]
-        layer_norm_template = templates["layer_norm"]
-        mean_variance_norm_template = templates["mean_variance_norm"]
-        rms_norm_template = templates["rms_norm"]
-        lrn_template = templates["lrn"]
-        lstm_template = templates["lstm"]
-        softmax_template = templates["softmax"]
-        logsoftmax_template = templates["logsoftmax"]
-        hardmax_template = templates["hardmax"]
-        nllloss_template = templates["nllloss"]
-        softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
-        maxpool_template = templates["maxpool"]
-        concat_template = templates["concat"]
-        gather_elements_template = templates["gather_elements"]
-        gather_template = templates["gather"]
-        gather_nd_template = templates["gather_nd"]
-        scatter_nd_template = templates["scatter_nd"]
-        transpose_template = templates["transpose"]
-        reshape_template = templates["reshape"]
-        identity_template = templates["identity"]
-        eye_like_template = templates["eye_like"]
-        trilu_template = templates["trilu"]
-        tile_template = templates["tile"]
-        pad_template = templates["pad"]
-        depth_to_space_template = templates["depth_to_space"]
-        space_to_depth_template = templates["space_to_depth"]
-        slice_template = templates["slice"]
-        slice_dynamic_template = templates["slice_dynamic"]
-        resize_template = templates["resize"]
-        grid_sample_template = templates["grid_sample"]
-        reduce_template = templates["reduce"]
-        reduce_dynamic_template = templates["reduce_dynamic"]
-        arg_reduce_template = templates["arg_reduce"]
-        topk_template = templates["topk"]
-        constant_of_shape_template = templates["constant_of_shape"]
-        shape_template = templates["shape"]
-        size_template = templates["size"]
-        nonzero_template = templates["nonzero"]
-        expand_template = templates["expand"]
-        cumsum_template = templates["cumsum"]
-        range_template = templates["range"]
-        one_hot_template = templates["one_hot"]
-        split_template = templates["split"]
         testbench_template = templates.get("testbench")
+        initial_name_map = self._build_value_name_map(name_map, {})
+        self._emit_state = _EmitState(
+            model=model,
+            templates=templates,
+            scalar_registry=scalar_registry,
+            dim_args=dim_args,
+            tensor_dim_names=tensor_dim_names,
+            op_context=model.op_context,
+            value_name_map=initial_name_map,
+        )
         reserved_names = {
             model.name,
             *model.input_names,
@@ -2905,83 +2201,12 @@ class CEmitter:
             original: buffer.name for original, buffer in temp_buffers.items()
         }
         resolved_ops = [self._resolve_op(op, temp_name_map) for op in model.ops]
+        self._copy_derived(model.op_context, model.ops, resolved_ops)
+        value_name_map = self._build_value_name_map(name_map, temp_name_map)
+        self._emit_state.value_name_map = value_name_map
         self._propagate_tensor_dim_names(resolved_ops, tensor_dim_names)
         operator_fns = "\n\n".join(
-            self._render_op(
-                model,
-                op,
-                index,
-                array_suffix="",
-                loop_vars=(),
-                c_type=self._op_output_dtype(op).c_type,
-                zero_literal=self._op_output_dtype(op).zero_literal,
-                min_literal=self._op_output_dtype(op).min_literal,
-                max_literal=self._op_output_dtype(op).max_literal,
-                binary_template=binary_template,
-                multi_input_template=multi_input_template,
-                where_template=where_template,
-                unary_template=unary_template,
-                clip_template=clip_template,
-                cast_template=cast_template,
-                quantize_linear_template=quantize_linear_template,
-                matmul_template=matmul_template,
-                einsum_template=einsum_template,
-                gemm_template=gemm_template,
-                attention_template=attention_template,
-                conv_template=conv_template,
-                conv_transpose_template=conv_transpose_template,
-                avg_pool_template=avg_pool_template,
-                lp_pool_template=lp_pool_template,
-                batch_norm_template=batch_norm_template,
-                lp_norm_template=lp_norm_template,
-                instance_norm_template=instance_norm_template,
-                group_norm_template=group_norm_template,
-                layer_norm_template=layer_norm_template,
-                mean_variance_norm_template=mean_variance_norm_template,
-                rms_norm_template=rms_norm_template,
-                lrn_template=lrn_template,
-                lstm_template=lstm_template,
-                softmax_template=softmax_template,
-                logsoftmax_template=logsoftmax_template,
-                hardmax_template=hardmax_template,
-                nllloss_template=nllloss_template,
-                softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
-                maxpool_template=maxpool_template,
-                concat_template=concat_template,
-                gather_elements_template=gather_elements_template,
-                gather_template=gather_template,
-                gather_nd_template=gather_nd_template,
-                scatter_nd_template=scatter_nd_template,
-                transpose_template=transpose_template,
-                reshape_template=reshape_template,
-                identity_template=identity_template,
-                eye_like_template=eye_like_template,
-                trilu_template=trilu_template,
-                tile_template=tile_template,
-                pad_template=pad_template,
-                depth_to_space_template=depth_to_space_template,
-                space_to_depth_template=space_to_depth_template,
-                slice_template=slice_template,
-                slice_dynamic_template=slice_dynamic_template,
-                resize_template=resize_template,
-                grid_sample_template=grid_sample_template,
-                reduce_template=reduce_template,
-                reduce_dynamic_template=reduce_dynamic_template,
-                arg_reduce_template=arg_reduce_template,
-                topk_template=topk_template,
-                constant_of_shape_template=constant_of_shape_template,
-                shape_template=shape_template,
-                size_template=size_template,
-                nonzero_template=nonzero_template,
-                expand_template=expand_template,
-                cumsum_template=cumsum_template,
-                range_template=range_template,
-                one_hot_template=one_hot_template,
-                split_template=split_template,
-                scalar_registry=scalar_registry,
-                dim_args=dim_args,
-                tensor_dim_names=tensor_dim_names,
-            )
+            op.emit(self, EmitContext(op_index=index))
             for index, op in enumerate(resolved_ops)
         )
         wrapper_fn = self._emit_model_wrapper(
@@ -3073,7 +2298,9 @@ class CEmitter:
         variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
         variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
     ) -> tuple[str, str]:
+        original_model = model
         model, name_map = self._sanitize_model_names_with_map(model)
+        self._copy_derived(model.op_context, original_model.ops, model.ops)
         testbench_inputs = self._sanitize_testbench_inputs(
             testbench_inputs, name_map
         )
@@ -3099,68 +2326,17 @@ class CEmitter:
         self._env.globals["dim_args"] = dim_args
         templates = self._load_templates(emit_testbench)
         scalar_registry = ScalarFunctionRegistry()
-        binary_template = templates["binary"]
-        multi_input_template = templates["multi_input"]
-        where_template = templates["where"]
-        unary_template = templates["unary"]
-        clip_template = templates["clip"]
-        cast_template = templates["cast"]
-        quantize_linear_template = templates["quantize_linear"]
-        matmul_template = templates["matmul"]
-        einsum_template = templates["einsum"]
-        gemm_template = templates["gemm"]
-        attention_template = templates["attention"]
-        conv_template = templates["conv"]
-        conv_transpose_template = templates["conv_transpose"]
-        avg_pool_template = templates["avg_pool"]
-        lp_pool_template = templates["lp_pool"]
-        batch_norm_template = templates["batch_norm"]
-        lp_norm_template = templates["lp_norm"]
-        instance_norm_template = templates["instance_norm"]
-        group_norm_template = templates["group_norm"]
-        layer_norm_template = templates["layer_norm"]
-        mean_variance_norm_template = templates["mean_variance_norm"]
-        rms_norm_template = templates["rms_norm"]
-        lrn_template = templates["lrn"]
-        lstm_template = templates["lstm"]
-        softmax_template = templates["softmax"]
-        logsoftmax_template = templates["logsoftmax"]
-        hardmax_template = templates["hardmax"]
-        nllloss_template = templates["nllloss"]
-        softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
-        maxpool_template = templates["maxpool"]
-        concat_template = templates["concat"]
-        gather_elements_template = templates["gather_elements"]
-        gather_template = templates["gather"]
-        gather_nd_template = templates["gather_nd"]
-        scatter_nd_template = templates["scatter_nd"]
-        transpose_template = templates["transpose"]
-        reshape_template = templates["reshape"]
-        identity_template = templates["identity"]
-        eye_like_template = templates["eye_like"]
-        trilu_template = templates["trilu"]
-        tile_template = templates["tile"]
-        pad_template = templates["pad"]
-        depth_to_space_template = templates["depth_to_space"]
-        space_to_depth_template = templates["space_to_depth"]
-        slice_template = templates["slice"]
-        slice_dynamic_template = templates["slice_dynamic"]
-        resize_template = templates["resize"]
-        grid_sample_template = templates["grid_sample"]
-        reduce_template = templates["reduce"]
-        reduce_dynamic_template = templates["reduce_dynamic"]
-        arg_reduce_template = templates["arg_reduce"]
-        topk_template = templates["topk"]
-        constant_of_shape_template = templates["constant_of_shape"]
-        shape_template = templates["shape"]
-        size_template = templates["size"]
-        nonzero_template = templates["nonzero"]
-        expand_template = templates["expand"]
-        cumsum_template = templates["cumsum"]
-        range_template = templates["range"]
-        one_hot_template = templates["one_hot"]
-        split_template = templates["split"]
         testbench_template = templates.get("testbench")
+        initial_name_map = self._build_value_name_map(name_map, {})
+        self._emit_state = _EmitState(
+            model=model,
+            templates=templates,
+            scalar_registry=scalar_registry,
+            dim_args=dim_args,
+            tensor_dim_names=tensor_dim_names,
+            op_context=model.op_context,
+            value_name_map=initial_name_map,
+        )
         reserved_names = {
             model.name,
             *model.input_names,
@@ -3172,83 +2348,12 @@ class CEmitter:
             original: buffer.name for original, buffer in temp_buffers.items()
         }
         resolved_ops = [self._resolve_op(op, temp_name_map) for op in model.ops]
+        self._copy_derived(model.op_context, model.ops, resolved_ops)
+        value_name_map = self._build_value_name_map(name_map, temp_name_map)
+        self._emit_state.value_name_map = value_name_map
         self._propagate_tensor_dim_names(resolved_ops, tensor_dim_names)
         operator_fns = "\n\n".join(
-            self._render_op(
-                model,
-                op,
-                index,
-                array_suffix="",
-                loop_vars=(),
-                c_type=self._op_output_dtype(op).c_type,
-                zero_literal=self._op_output_dtype(op).zero_literal,
-                min_literal=self._op_output_dtype(op).min_literal,
-                max_literal=self._op_output_dtype(op).max_literal,
-                binary_template=binary_template,
-                multi_input_template=multi_input_template,
-                where_template=where_template,
-                unary_template=unary_template,
-                clip_template=clip_template,
-                cast_template=cast_template,
-                quantize_linear_template=quantize_linear_template,
-                matmul_template=matmul_template,
-                einsum_template=einsum_template,
-                gemm_template=gemm_template,
-                attention_template=attention_template,
-                conv_template=conv_template,
-                conv_transpose_template=conv_transpose_template,
-                avg_pool_template=avg_pool_template,
-                lp_pool_template=lp_pool_template,
-                batch_norm_template=batch_norm_template,
-                lp_norm_template=lp_norm_template,
-                instance_norm_template=instance_norm_template,
-                group_norm_template=group_norm_template,
-                layer_norm_template=layer_norm_template,
-                mean_variance_norm_template=mean_variance_norm_template,
-                rms_norm_template=rms_norm_template,
-                lrn_template=lrn_template,
-                lstm_template=lstm_template,
-                softmax_template=softmax_template,
-                logsoftmax_template=logsoftmax_template,
-                hardmax_template=hardmax_template,
-                nllloss_template=nllloss_template,
-                softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
-                maxpool_template=maxpool_template,
-                concat_template=concat_template,
-                gather_elements_template=gather_elements_template,
-                gather_template=gather_template,
-                gather_nd_template=gather_nd_template,
-                scatter_nd_template=scatter_nd_template,
-                transpose_template=transpose_template,
-                reshape_template=reshape_template,
-                identity_template=identity_template,
-                eye_like_template=eye_like_template,
-                trilu_template=trilu_template,
-                tile_template=tile_template,
-                pad_template=pad_template,
-                depth_to_space_template=depth_to_space_template,
-                space_to_depth_template=space_to_depth_template,
-                slice_template=slice_template,
-                slice_dynamic_template=slice_dynamic_template,
-                resize_template=resize_template,
-                grid_sample_template=grid_sample_template,
-                reduce_template=reduce_template,
-                reduce_dynamic_template=reduce_dynamic_template,
-                arg_reduce_template=arg_reduce_template,
-                topk_template=topk_template,
-                constant_of_shape_template=constant_of_shape_template,
-                shape_template=shape_template,
-                size_template=size_template,
-                nonzero_template=nonzero_template,
-                expand_template=expand_template,
-                cumsum_template=cumsum_template,
-                range_template=range_template,
-                one_hot_template=one_hot_template,
-                split_template=split_template,
-                scalar_registry=scalar_registry,
-                dim_args=dim_args,
-                tensor_dim_names=tensor_dim_names,
-            )
+            op.emit(self, EmitContext(op_index=index))
             for index, op in enumerate(resolved_ops)
         )
         wrapper_fn = self._emit_model_wrapper(
@@ -3536,6 +2641,8 @@ class CEmitter:
         ScalarFunction.SCALED_TANH,
         ScalarFunction.THRESHOLDED_RELU,
         ScalarFunction.LOGICAL_XOR,
+        ScalarFunction.ISNEGINF,
+        ScalarFunction.ISPOSINF,
     }
     if function in {ScalarFunction.MAXIMUM, ScalarFunction.MINIMUM}:
         if dtype in {ScalarType.F32, ScalarType.F64}:
@@ -3598,6 +2705,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
@@ -3615,6 +2723,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -3626,6 +2735,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
        | IdentityOp
@@ -3644,6 +2754,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -3830,6 +2941,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
@@ -3847,6 +2959,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -3858,6 +2971,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -3876,6 +2990,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -3948,6 +3063,7 @@ class CEmitter:
             RMSNormalizationOp,
             LrnOp,
             LstmOp,
+            AdagradOp,
             SoftmaxOp,
             LogSoftmaxOp,
             SoftmaxCrossEntropyLossOp,
@@ -3977,7 +3093,7 @@ class CEmitter:
         ):
             return True
         if any(
-            isinstance(op, (LpPoolOp, QuantizeLinearOp))
+            isinstance(op, (LpPoolOp, QuantizeLinearOp, QLinearMatMulOp))
            for op in resolved_ops
         ):
             return True
@@ -3991,6 +3107,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
@@ -4036,6 +3153,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -4070,10 +3188,13 @@ class CEmitter:
         ):
             return True
         if any(
-            isinstance(op, QuantizeLinearOp)
+            isinstance(op, (QuantizeLinearOp, QLinearMatMulOp))
+            and op.dtype.is_integer
            for op in resolved_ops
         ):
             return True
+        if any(isinstance(op, NonMaxSuppressionOp) for op in resolved_ops):
+            return True
         return False

     def _emit_model_wrapper(
@@ -4086,6 +3207,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
@@ -4131,6 +3253,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -4195,10 +3318,12 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
+        | RotaryEmbeddingOp
         | ConvOp
         | ConvTransposeOp
         | AveragePoolOp
@@ -4212,6 +3337,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -4223,6 +3349,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -4242,6 +3369,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -4261,6 +3389,21 @@ class CEmitter:
         if isinstance(op, WhereOp):
             args.extend([op.condition, op.input_x, op.input_y, op.output])
             return ", ".join(args)
+        if isinstance(op, QLinearMatMulOp):
+            args.extend(
+                [
+                    op.input0,
+                    op.input0_scale,
+                    op.input0_zero_point,
+                    op.input1,
+                    op.input1_scale,
+                    op.input1_zero_point,
+                    op.output_scale,
+                    op.output_zero_point,
+                    op.output,
+                ]
+            )
+            return ", ".join(args)
         if isinstance(op, MatMulOp):
             args.extend([op.input0, op.input1, op.output])
             return ", ".join(args)
@@ -4380,6 +3523,19 @@ class CEmitter:
                 call_parts.append(op.output_y_c)
             args.extend(call_parts)
             return ", ".join(args)
+        if isinstance(op, AdagradOp):
+            args.extend(
+                [
+                    op.rate,
+                    op.timestep,
+                    *op.inputs,
+                    *op.gradients,
+                    *op.accumulators,
+                    *op.outputs,
+                    *op.accumulator_outputs,
+                ]
+            )
+            return ", ".join(args)
         if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
             args.extend([op.input0, op.output])
             return ", ".join(args)
@@ -4417,6 +3573,12 @@ class CEmitter:
         if isinstance(op, ScatterNDOp):
             args.extend([op.data, op.indices, op.updates, op.output])
             return ", ".join(args)
+        if isinstance(op, TensorScatterOp):
+            args.extend([op.past_cache, op.update])
+            if op.write_indices is not None:
+                args.append(op.write_indices)
+            args.append(op.output)
+            return ", ".join(args)
         if isinstance(op, ConcatOp):
             args.extend([*op.inputs, op.output])
             return ", ".join(args)
@@ -4432,6 +3594,17 @@ class CEmitter:
         if isinstance(op, NonZeroOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
+        if isinstance(op, NonMaxSuppressionOp):
+            call_parts = [op.boxes, op.scores]
+            if op.max_output_boxes_per_class is not None:
+                call_parts.append(op.max_output_boxes_per_class)
+            if op.iou_threshold is not None:
+                call_parts.append(op.iou_threshold)
+            if op.score_threshold is not None:
+                call_parts.append(op.score_threshold)
+            call_parts.append(op.output)
+            args.extend(call_parts)
+            return ", ".join(args)
         if isinstance(op, ExpandOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
@@ -4566,10 +3739,12 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
+        | RotaryEmbeddingOp
         | ConvOp
         | ConvTransposeOp
         | AveragePoolOp
@@ -4583,6 +3758,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -4594,6 +3770,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -4612,6 +3789,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -4626,6 +3804,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
@@ -4643,6 +3822,7 @@ class CEmitter:
         | RMSNormalizationOp
         | LrnOp
         | LstmOp
+        | AdagradOp
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
@@ -4654,6 +3834,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -4672,6 +3853,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -4791,6 +3973,47 @@ class CEmitter:
             input_dtype=op.input_dtype,
             scale_dtype=op.scale_dtype,
         )
+        if isinstance(op, QLinearMatMulOp):
+            return QLinearMatMulOp(
+                input0=temp_map.get(op.input0, op.input0),
+                input0_scale=temp_map.get(op.input0_scale, op.input0_scale),
+                input0_zero_point=temp_map.get(
+                    op.input0_zero_point, op.input0_zero_point
+                ),
+                input1=temp_map.get(op.input1, op.input1),
+                input1_scale=temp_map.get(op.input1_scale, op.input1_scale),
+                input1_zero_point=temp_map.get(
+                    op.input1_zero_point, op.input1_zero_point
+                ),
+                output_scale=temp_map.get(op.output_scale, op.output_scale),
+                output_zero_point=temp_map.get(
+                    op.output_zero_point, op.output_zero_point
+                ),
+                output=temp_map.get(op.output, op.output),
+                input0_shape=op.input0_shape,
+                input1_shape=op.input1_shape,
+                output_shape=op.output_shape,
+                batch_shape=op.batch_shape,
+                input0_batch_shape=op.input0_batch_shape,
+                input1_batch_shape=op.input1_batch_shape,
+                m=op.m,
+                n=op.n,
+                k=op.k,
+                left_vector=op.left_vector,
+                right_vector=op.right_vector,
+                input0_dtype=op.input0_dtype,
+                input1_dtype=op.input1_dtype,
+                dtype=op.dtype,
+                input0_scale_dtype=op.input0_scale_dtype,
+                input1_scale_dtype=op.input1_scale_dtype,
+                output_scale_dtype=op.output_scale_dtype,
+                input0_scale_shape=op.input0_scale_shape,
+                input1_scale_shape=op.input1_scale_shape,
+                output_scale_shape=op.output_scale_shape,
+                input0_zero_shape=op.input0_zero_shape,
+                input1_zero_shape=op.input1_zero_shape,
+                output_zero_shape=op.output_zero_shape,
+            )
         if isinstance(op, GemmOp):
             return GemmOp(
                 input_a=temp_map.get(op.input_a, op.input_a),
@@ -4885,6 +4108,32 @@ class CEmitter:
             head_group_size=op.head_group_size,
             dtype=op.dtype,
         )
+        if isinstance(op, RotaryEmbeddingOp):
+            return RotaryEmbeddingOp(
+                input0=temp_map.get(op.input0, op.input0),
+                cos_cache=temp_map.get(op.cos_cache, op.cos_cache),
+                sin_cache=temp_map.get(op.sin_cache, op.sin_cache),
+                position_ids=(
+                    temp_map.get(op.position_ids, op.position_ids)
+                    if op.position_ids is not None
+                    else None
+                ),
+                output=temp_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                cos_shape=op.cos_shape,
+                sin_shape=op.sin_shape,
+                position_ids_shape=op.position_ids_shape,
+                dtype=op.dtype,
+                position_ids_dtype=op.position_ids_dtype,
+                rotary_dim=op.rotary_dim,
+                rotary_dim_half=op.rotary_dim_half,
+                head_size=op.head_size,
+                num_heads=op.num_heads,
+                seq_len=op.seq_len,
+                batch=op.batch,
+                input_rank=op.input_rank,
+                interleaved=op.interleaved,
+            )
         if isinstance(op, LstmOp):
             return LstmOp(
                 input_x=temp_map.get(op.input_x, op.input_x),
@@ -4945,6 +4194,33 @@ class CEmitter:
             dtype=op.dtype,
             sequence_lens_dtype=op.sequence_lens_dtype,
         )
+        if isinstance(op, AdagradOp):
+            return AdagradOp(
+                rate=temp_map.get(op.rate, op.rate),
+                timestep=temp_map.get(op.timestep, op.timestep),
+                inputs=tuple(temp_map.get(name, name) for name in op.inputs),
+                gradients=tuple(
+                    temp_map.get(name, name) for name in op.gradients
+                ),
+                accumulators=tuple(
+                    temp_map.get(name, name) for name in op.accumulators
+                ),
+                outputs=tuple(temp_map.get(name, name) for name in op.outputs),
+                accumulator_outputs=tuple(
+                    temp_map.get(name, name)
+                    for name in op.accumulator_outputs
+                ),
+                rate_shape=op.rate_shape,
+                timestep_shape=op.timestep_shape,
+                tensor_shapes=op.tensor_shapes,
+                output_shapes=op.output_shapes,
+                dtype=op.dtype,
+                rate_dtype=op.rate_dtype,
+                timestep_dtype=op.timestep_dtype,
+                norm_coefficient=op.norm_coefficient,
+                epsilon=op.epsilon,
+                decay_factor=op.decay_factor,
+            )
         if isinstance(op, ConvOp):
             return ConvOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -5301,6 +4577,25 @@ class CEmitter:
             dtype=op.dtype,
             indices_dtype=op.indices_dtype,
         )
+        if isinstance(op, TensorScatterOp):
+            return TensorScatterOp(
+                past_cache=temp_map.get(op.past_cache, op.past_cache),
+                update=temp_map.get(op.update, op.update),
+                write_indices=(
+                    temp_map.get(op.write_indices, op.write_indices)
+                    if op.write_indices is not None
+                    else None
+                ),
+                output=temp_map.get(op.output, op.output),
+                past_cache_shape=op.past_cache_shape,
+                update_shape=op.update_shape,
+                output_shape=op.output_shape,
+                write_indices_shape=op.write_indices_shape,
+                axis=op.axis,
+                mode=op.mode,
+                dtype=op.dtype,
+                write_indices_dtype=op.write_indices_dtype,
+            )
         if isinstance(op, ConcatOp):
             return ConcatOp(
                 inputs=tuple(temp_map.get(name, name) for name in op.inputs),
@@ -5349,6 +4644,33 @@ class CEmitter:
             dtype=op.dtype,
             input_dtype=op.input_dtype,
         )
+        if isinstance(op, NonMaxSuppressionOp):
+            return NonMaxSuppressionOp(
+                boxes=temp_map.get(op.boxes, op.boxes),
+                scores=temp_map.get(op.scores, op.scores),
+                max_output_boxes_per_class=CEmitter._map_optional_name(
+                    temp_map, op.max_output_boxes_per_class
+                ),
+                iou_threshold=CEmitter._map_optional_name(
+                    temp_map, op.iou_threshold
+                ),
+                score_threshold=CEmitter._map_optional_name(
+                    temp_map, op.score_threshold
+                ),
+                output=temp_map.get(op.output, op.output),
+                boxes_shape=op.boxes_shape,
+                scores_shape=op.scores_shape,
+                output_shape=op.output_shape,
+                center_point_box=op.center_point_box,
+                boxes_dtype=op.boxes_dtype,
+                output_dtype=op.output_dtype,
+                max_output_dtype=op.max_output_dtype,
+                max_output_shape=op.max_output_shape,
+                iou_threshold_dtype=op.iou_threshold_dtype,
+                iou_threshold_shape=op.iou_threshold_shape,
+                score_threshold_dtype=op.score_threshold_dtype,
+                score_threshold_shape=op.score_threshold_shape,
+            )
         if isinstance(op, ExpandOp):
             return ExpandOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -5673,67 +4995,98 @@ class CEmitter:
             dtype=op.dtype,
         )

+    def render_op(self, op: OpBase, ctx: EmitContext) -> str:
+        if self._emit_state is None:
+            raise CodegenError("Emitter state not initialized")
+        state = self._emit_state
+        dtype = self._op_output_dtype(op)
+        templates = state.templates
+        return self._render_op(
+            state.model,
+            op,
+            ctx.op_index,
+            array_suffix="",
+            loop_vars=(),
+            c_type=dtype.c_type,
+            zero_literal=dtype.zero_literal,
+            min_literal=dtype.min_literal,
+            max_literal=dtype.max_literal,
+            binary_template=templates["binary"],
+            multi_input_template=templates["multi_input"],
+            where_template=templates["where"],
+            unary_template=templates["unary"],
+            clip_template=templates["clip"],
+            cast_template=templates["cast"],
+            quantize_linear_template=templates["quantize_linear"],
+            qlinear_matmul_template=templates["qlinear_matmul"],
+            matmul_template=templates["matmul"],
+            einsum_template=templates["einsum"],
+            gemm_template=templates["gemm"],
+            attention_template=templates["attention"],
+            rotary_embedding_template=templates["rotary_embedding"],
+            conv_template=templates["conv"],
+            conv_transpose_template=templates["conv_transpose"],
+            avg_pool_template=templates["avg_pool"],
+            lp_pool_template=templates["lp_pool"],
+            batch_norm_template=templates["batch_norm"],
+            lp_norm_template=templates["lp_norm"],
+            instance_norm_template=templates["instance_norm"],
+            group_norm_template=templates["group_norm"],
+            layer_norm_template=templates["layer_norm"],
+            mean_variance_norm_template=templates["mean_variance_norm"],
+            rms_norm_template=templates["rms_norm"],
+            lrn_template=templates["lrn"],
+            lstm_template=templates["lstm"],
+            adagrad_template=templates["adagrad"],
+            softmax_template=templates["softmax"],
+            logsoftmax_template=templates["logsoftmax"],
+            hardmax_template=templates["hardmax"],
+            nllloss_template=templates["nllloss"],
+            softmax_cross_entropy_loss_template=templates[
+                "softmax_cross_entropy_loss"
+            ],
+            maxpool_template=templates["maxpool"],
+            concat_template=templates["concat"],
+            gather_elements_template=templates["gather_elements"],
+            gather_template=templates["gather"],
+            gather_nd_template=templates["gather_nd"],
+            scatter_nd_template=templates["scatter_nd"],
+            transpose_template=templates["transpose"],
+            reshape_template=templates["reshape"],
+            identity_template=templates["identity"],
+            eye_like_template=templates["eye_like"],
+            trilu_template=templates["trilu"],
+            tile_template=templates["tile"],
+            pad_template=templates["pad"],
+            depth_to_space_template=templates["depth_to_space"],
+            space_to_depth_template=templates["space_to_depth"],
+            slice_template=templates["slice"],
+            slice_dynamic_template=templates["slice_dynamic"],
+            resize_template=templates["resize"],
+            grid_sample_template=templates["grid_sample"],
+            reduce_template=templates["reduce"],
+            reduce_dynamic_template=templates["reduce_dynamic"],
+            arg_reduce_template=templates["arg_reduce"],
+            topk_template=templates["topk"],
+            constant_of_shape_template=templates["constant_of_shape"],
+            shape_template=templates["shape"],
+            size_template=templates["size"],
+            nonzero_template=templates["nonzero"],
+            nonmax_suppression_template=templates["nonmax_suppression"],
+            expand_template=templates["expand"],
+            cumsum_template=templates["cumsum"],
+            range_template=templates["range"],
+            one_hot_template=templates["one_hot"],
+            split_template=templates["split"],
+            scalar_registry=state.scalar_registry,
+            dim_args=state.dim_args,
+            tensor_dim_names=state.tensor_dim_names,
+        )
+
     def _render_op(
         self,
         model: LoweredModel,
-        op:
-        | MultiInputBinaryOp
-        | WhereOp
-        | UnaryOp
-        | ClipOp
-        | CastOp
-        | QuantizeLinearOp
-        | MatMulOp
-        | EinsumOp
-        | GemmOp
-        | AttentionOp
-        | ConvOp
-        | ConvTransposeOp
-        | AveragePoolOp
-        | LpPoolOp
-        | BatchNormOp
-        | LpNormalizationOp
-        | InstanceNormalizationOp
-        | GroupNormalizationOp
-        | LayerNormalizationOp
-        | MeanVarianceNormalizationOp
-        | RMSNormalizationOp
-        | LrnOp
-        | LstmOp
-        | SoftmaxOp
-        | LogSoftmaxOp
-        | HardmaxOp
-        | NegativeLogLikelihoodLossOp
-        | SoftmaxCrossEntropyLossOp
-        | MaxPoolOp
-        | ConcatOp
-        | GatherElementsOp
-        | GatherOp
-        | GatherNDOp
-        | ScatterNDOp
-        | TransposeOp
-        | ReshapeOp
-        | IdentityOp
-        | EyeLikeOp
-        | TriluOp
-        | TileOp
-        | DepthToSpaceOp
-        | SpaceToDepthOp
-        | SliceOp
-        | ResizeOp
-        | GridSampleOp
-        | ReduceOp
-        | ArgReduceOp
-        | TopKOp
-        | ConstantOfShapeOp
-        | ShapeOp
-        | SizeOp
-        | NonZeroOp
-        | ExpandOp
-        | CumSumOp
-        | RangeOp
-        | OneHotOp
-        | SplitOp,
+        op: OpBase,
         index: int,
         *,
         array_suffix: str,
@@ -5749,10 +5102,12 @@ class CEmitter:
         clip_template,
         cast_template,
         quantize_linear_template,
+        qlinear_matmul_template,
         matmul_template,
         einsum_template,
         gemm_template,
         attention_template,
+        rotary_embedding_template,
         conv_template,
         conv_transpose_template,
         avg_pool_template,
@@ -5766,6 +5121,7 @@ class CEmitter:
         rms_norm_template,
         lrn_template,
         lstm_template,
+        adagrad_template,
        softmax_template,
         logsoftmax_template,
         hardmax_template,
@@ -5798,6 +5154,7 @@ class CEmitter:
         shape_template,
         size_template,
         nonzero_template,
+        nonmax_suppression_template,
         expand_template,
         cumsum_template,
         range_template,
@@ -5819,6 +5176,11 @@ class CEmitter:
             return f"{node_comment}\n{_format_c_indentation(rendered)}"

         if isinstance(op, BinaryOp):
+            input0_shape = self._ctx_shape(op.input0)
+            input1_shape = self._ctx_shape(op.input1)
+            output_shape = self._ctx_shape(op.output)
+            input_dtype = self._ctx_dtype(op.input0)
+            output_dtype = self._ctx_dtype(op.output)
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -5832,11 +5194,11 @@ class CEmitter:
                 and op.function not in COMPARE_FUNCTIONS
             ):
                 scalar_operator = self._scalar_function_name(
-                    op.function,
+                    op.function, input_dtype, scalar_registry
                 )
                 op_spec = binary_op_symbol(
                     op.function,
-                    dtype=
+                    dtype=input_dtype,
                     validate_attrs=False,
                 )
                 if op_spec is None:
@@ -5844,17 +5206,19 @@ class CEmitter:
                 f"Unsupported binary operator for rendering: {op.function.value}"
             )
             output_dim_names = _dim_names_for(op.output)
-            shape = CEmitter._shape_dim_exprs(
-            loop_vars = CEmitter._loop_vars(
-            output_suffix = self._param_array_suffix(
+            shape = CEmitter._shape_dim_exprs(output_shape, output_dim_names)
+            loop_vars = CEmitter._loop_vars(output_shape)
+            output_suffix = self._param_array_suffix(
+                output_shape, output_dim_names
+            )
             input0_suffix = self._param_array_suffix(
-
+                input0_shape, _dim_names_for(op.input0)
             )
             input1_suffix = self._param_array_suffix(
-
+                input1_shape, _dim_names_for(op.input1)
             )
-            input_c_type =
-            output_c_type =
+            input_c_type = input_dtype.c_type
+            output_c_type = output_dtype.c_type
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], input_c_type, input0_suffix, True),
@@ -5877,14 +5241,14 @@ class CEmitter:
             }
             left_expr = CEmitter._broadcast_index_expr(
                 params["input0"],
-
-
+                input0_shape,
+                output_shape,
                 loop_vars,
             )
             right_expr = CEmitter._broadcast_index_expr(
                 params["input1"],
-
-
+                input1_shape,
+                output_shape,
                 loop_vars,
             )
             operator_expr = None
@@ -5910,6 +5274,9 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, MultiInputBinaryOp):
+            output_shape = self._ctx_shape(op.output)
+            input_dtype = self._ctx_dtype(op.inputs[0])
+            output_dtype = self._ctx_dtype(op.output)
             params = self._shared_param_map(
                 [
                     *((f"input{idx}", name) for idx, name in enumerate(op.inputs)),
@@ -5923,11 +5290,11 @@ class CEmitter:
                 and op.function != ScalarFunction.MEAN
             ):
                 scalar_operator = self._scalar_function_name(
-                    op.function,
+                    op.function, input_dtype, scalar_registry
                 )
                 op_spec = binary_op_symbol(
                     op.function,
-                    dtype=
+                    dtype=input_dtype,
                     validate_attrs=False,
                 )
                 if op_spec is None:
@@ -5936,11 +5303,13 @@ class CEmitter:
                 f"{op.function.value}"
             )
             output_dim_names = _dim_names_for(op.output)
-            shape = CEmitter._shape_dim_exprs(
-            loop_vars = CEmitter._loop_vars(
-            array_suffix = self._param_array_suffix(
-
-
+            shape = CEmitter._shape_dim_exprs(output_shape, output_dim_names)
+            loop_vars = CEmitter._loop_vars(output_shape)
+            array_suffix = self._param_array_suffix(
+                output_shape, output_dim_names
+            )
+            input_c_type = input_dtype.c_type
+            output_c_type = output_dtype.c_type
             input_names = [
                 params[f"input{idx}"] for idx in range(len(op.inputs))
             ]
@@ -5999,6 +5368,11 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, WhereOp):
+            output_shape_raw = self._ctx_shape(op.output)
+            condition_shape = self._ctx_shape(op.condition)
+            x_shape = self._ctx_shape(op.input_x)
+            y_shape = self._ctx_shape(op.input_y)
+            output_dtype = self._ctx_dtype(op.output)
             params = self._shared_param_map(
                 [
                     ("condition", op.condition),
@@ -6009,32 +5383,32 @@ class CEmitter:
             )
             output_dim_names = _dim_names_for(op.output)
             output_shape = CEmitter._shape_dim_exprs(
-
+                output_shape_raw, output_dim_names
             )
-            loop_vars = CEmitter._loop_vars(
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
             output_array_suffix = self._param_array_suffix(
-
+                output_shape_raw, output_dim_names
             )
             condition_array_suffix = self._param_array_suffix(
-
+                condition_shape, _dim_names_for(op.condition)
             )
             x_array_suffix = self._param_array_suffix(
-
+                x_shape, _dim_names_for(op.input_x)
             )
             y_array_suffix = self._param_array_suffix(
-
+                y_shape, _dim_names_for(op.input_y)
             )
             condition_expr = CEmitter._broadcast_index_expr(
                 params["condition"],
-
-
+                condition_shape,
+                output_shape_raw,
                 loop_vars,
             )
             x_expr = CEmitter._broadcast_index_expr(
-                params["input_x"],
+                params["input_x"], x_shape, output_shape_raw, loop_vars
             )
             y_expr = CEmitter._broadcast_index_expr(
-                params["input_y"],
+                params["input_y"], y_shape, output_shape_raw, loop_vars
             )
             output_expr = f"{params['output']}" + "".join(
                 f"[{var}]" for var in loop_vars
@@ -6047,11 +5421,11 @@ class CEmitter:
                     condition_array_suffix,
                     True,
                 ),
-                (params["input_x"],
-                (params["input_y"],
+                (params["input_x"], output_dtype.c_type, x_array_suffix, True),
+                (params["input_y"], output_dtype.c_type, y_array_suffix, True),
                 (
                     params["output"],
-
+                    output_dtype.c_type,
                     output_array_suffix,
                     False,
                 ),
@@ -6074,8 +5448,8 @@ class CEmitter:
                 x_expr=x_expr,
                 y_expr=y_expr,
                 output_expr=output_expr,
-                input_c_type=
-                output_c_type=
+                input_c_type=output_dtype.c_type,
+                output_c_type=output_dtype.c_type,
                 condition_c_type=ScalarType.BOOL.c_type,
                 dim_args=dim_args,
                 params=param_decls,
@@ -6363,6 +5737,17 @@ class CEmitter:
|
|
|
6363
5737
|
).rstrip()
|
|
6364
5738
|
return with_node_comment(rendered)
|
|
6365
5739
|
if isinstance(op, AttentionOp):
|
|
5740
|
+
if scalar_registry is None:
|
|
5741
|
+
raise CodegenError(
|
|
5742
|
+
"Scalar function registry is required for Attention codegen."
|
|
5743
|
+
)
|
|
5744
|
+
max_fn = self._scalar_function_name(
|
|
5745
|
+
ScalarFunction.MAXIMUM, op.dtype, scalar_registry
|
|
5746
|
+
)
|
|
5747
|
+
if max_fn is None:
|
|
5748
|
+
raise CodegenError(
|
|
5749
|
+
"Failed to resolve scalar maximum function for Attention."
|
|
5750
|
+
)
|
|
6366
5751
|
params = self._shared_param_map(
|
|
6367
5752
|
[
|
|
6368
5753
|
("input_q", op.input_q),
@@ -6543,6 +5928,7 @@ class CEmitter:
                 scale_literal=CEmitter._format_floating(op.scale, op.dtype),
                 softcap_literal=CEmitter._format_floating(op.softcap, op.dtype),
                 one_literal=CEmitter._format_literal(op.dtype, 1),
+                max_fn=max_fn,
                 exp_fn=CEmitter._math_fn(op.dtype, "expf", "exp"),
                 tanh_fn=CEmitter._math_fn(op.dtype, "tanhf", "tanh"),
                 is_causal=int(op.is_causal),
@@ -6580,9 +5966,74 @@ class CEmitter:
                 input_past_value_suffix=input_past_value_suffix,
                 input_nonpad_suffix=input_nonpad_suffix,
                 output_suffix=output_suffix,
-                output_present_key_suffix=output_present_key_suffix,
-                output_present_value_suffix=output_present_value_suffix,
-                output_qk_matmul_suffix=output_qk_matmul_suffix,
+                output_present_key_suffix=output_present_key_suffix,
+                output_present_value_suffix=output_present_value_suffix,
+                output_qk_matmul_suffix=output_qk_matmul_suffix,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, RotaryEmbeddingOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("cos_cache", op.cos_cache),
+                    ("sin_cache", op.sin_cache),
+                    ("position_ids", op.position_ids),
+                    ("output", op.output),
+                ]
+            )
+            input_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.input0)
+            )
+            cos_suffix = self._param_array_suffix(op.cos_shape)
+            sin_suffix = self._param_array_suffix(op.sin_shape)
+            position_suffix = (
+                self._param_array_suffix(op.position_ids_shape)
+                if op.position_ids_shape is not None
+                else ""
+            )
+            output_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.output)
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], c_type, input_suffix, True),
+                    (params["cos_cache"], c_type, cos_suffix, True),
+                    (params["sin_cache"], c_type, sin_suffix, True),
+                    (
+                        params["position_ids"],
+                        op.position_ids_dtype.c_type,
+                        position_suffix,
+                        True,
+                    )
+                    if params["position_ids"]
+                    else (None, "", "", True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = rotary_embedding_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                cos_cache=params["cos_cache"],
+                sin_cache=params["sin_cache"],
+                position_ids=params["position_ids"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                input_suffix=input_suffix,
+                cos_suffix=cos_suffix,
+                sin_suffix=sin_suffix,
+                position_suffix=position_suffix,
+                output_suffix=output_suffix,
+                batch=op.batch,
+                seq_len=op.seq_len,
+                num_heads=op.num_heads,
+                head_size=op.head_size,
+                rotary_dim=op.rotary_dim,
+                rotary_dim_half=op.rotary_dim_half,
+                input_rank=op.input_rank,
+                interleaved=int(op.interleaved),
+                has_position_ids=int(op.position_ids is not None),
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ConvOp):
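Note: the new RotaryEmbedding lowering above forwards batch/seq_len/num_heads/head_size, a rotary_dim split, and an interleaved flag to a C template. As a rough reference for what those parameters mean (an illustrative NumPy sketch of the non-interleaved rotation, not code shipped in this package):

import numpy as np

def rotary_reference(x, cos_cache, sin_cache, position_ids, rotary_dim):
    # x: (batch, num_heads, seq_len, head_size); caches: (max_pos, rotary_dim // 2)
    half = rotary_dim // 2
    cos = cos_cache[position_ids][:, None, :, :]  # broadcast over heads
    sin = sin_cache[position_ids][:, None, :, :]
    x1, x2 = x[..., :half], x[..., half:rotary_dim]
    out = x.copy()
    out[..., :half] = x1 * cos - x2 * sin            # rotate first half
    out[..., half:rotary_dim] = x2 * cos + x1 * sin  # rotate second half
    return out  # lanes past rotary_dim pass through unchanged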
@@ -7432,15 +6883,142 @@ class CEmitter:
                 activation_functions=activation_functions,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, AdagradOp):
+            params = self._shared_param_map(
+                [
+                    ("rate", op.rate),
+                    ("timestep", op.timestep),
+                    *(
+                        (f"input{idx}", name)
+                        for idx, name in enumerate(op.inputs)
+                    ),
+                    *(
+                        (f"grad{idx}", name)
+                        for idx, name in enumerate(op.gradients)
+                    ),
+                    *(
+                        (f"acc{idx}", name)
+                        for idx, name in enumerate(op.accumulators)
+                    ),
+                    *(
+                        (f"output{idx}", name)
+                        for idx, name in enumerate(op.outputs)
+                    ),
+                    *(
+                        (f"acc_output{idx}", name)
+                        for idx, name in enumerate(op.accumulator_outputs)
+                    ),
+                ]
+            )
+            rate_suffix = self._param_array_suffix(
+                op.rate_shape, _dim_names_for(op.rate)
+            )
+            timestep_suffix = self._param_array_suffix(
+                op.timestep_shape, _dim_names_for(op.timestep)
+            )
+            param_specs = [
+                (params["rate"], op.rate_dtype.c_type, rate_suffix, True),
+                (
+                    params["timestep"],
+                    op.timestep_dtype.c_type,
+                    timestep_suffix,
+                    True,
+                ),
+            ]
+            tensor_specs = []
+            for idx, shape in enumerate(op.output_shapes):
+                input_suffix = self._param_array_suffix(
+                    op.tensor_shapes[idx], _dim_names_for(op.inputs[idx])
+                )
+                grad_suffix = self._param_array_suffix(
+                    op.tensor_shapes[idx], _dim_names_for(op.gradients[idx])
+                )
+                acc_suffix = self._param_array_suffix(
+                    op.tensor_shapes[idx], _dim_names_for(op.accumulators[idx])
+                )
+                output_suffix = self._param_array_suffix(
+                    op.output_shapes[idx], _dim_names_for(op.outputs[idx])
+                )
+                acc_output_suffix = self._param_array_suffix(
+                    op.output_shapes[idx],
+                    _dim_names_for(op.accumulator_outputs[idx]),
+                )
+                param_specs.extend(
+                    [
+                        (params[f"input{idx}"], c_type, input_suffix, True),
+                        (params[f"grad{idx}"], c_type, grad_suffix, True),
+                        (params[f"acc{idx}"], c_type, acc_suffix, True),
+                        (params[f"output{idx}"], c_type, output_suffix, False),
+                        (
+                            params[f"acc_output{idx}"],
+                            c_type,
+                            acc_output_suffix,
+                            False,
+                        ),
+                    ]
+                )
+                output_dim_names = _dim_names_for(op.outputs[idx])
+                shape_exprs = CEmitter._shape_dim_exprs(
+                    shape, output_dim_names
+                )
+                loop_vars = CEmitter._loop_vars(shape)
+                index_suffix = "".join(f"[{var}]" for var in loop_vars)
+                tensor_specs.append(
+                    {
+                        "shape": shape_exprs,
+                        "loop_vars": loop_vars,
+                        "input_expr": f"{params[f'input{idx}']}{index_suffix}",
+                        "grad_expr": f"{params[f'grad{idx}']}{index_suffix}",
+                        "acc_expr": f"{params[f'acc{idx}']}{index_suffix}",
+                        "output_expr": f"{params[f'output{idx}']}{index_suffix}",
+                        "acc_output_expr": f"{params[f'acc_output{idx}']}{index_suffix}",
+                    }
+                )
+            param_decls = self._build_param_decls(param_specs)
+            rendered = adagrad_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                rate=params["rate"],
+                timestep=params["timestep"],
+                params=param_decls,
+                c_type=c_type,
+                one_literal=CEmitter._format_literal(op.dtype, 1),
+                decay_factor_literal=CEmitter._format_floating(
+                    op.decay_factor, op.dtype
+                ),
+                norm_coefficient_literal=CEmitter._format_floating(
+                    op.norm_coefficient, op.dtype
+                ),
+                epsilon_literal=CEmitter._format_floating(op.epsilon, op.dtype),
+                sqrt_fn=CEmitter._math_fn(op.dtype, "sqrtf", "sqrt"),
+                tensors=tensor_specs,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, SoftmaxOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for Softmax rendering."
+                )
+            output_shape = self._ctx_shape(op.output)
+            output_dtype = self._ctx_dtype(op.output)
+            outer = self._derived(op, "outer")
+            axis_size = self._derived(op, "axis_size")
+            inner = self._derived(op, "inner")
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, output_dtype, scalar_registry
+            )
+            if max_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar maximum function for Softmax."
+                )
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            array_suffix = self._param_array_suffix(
+            array_suffix = self._param_array_suffix(output_shape)
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"], c_type, array_suffix, True),
-                    (params["output"], c_type, array_suffix, False),
+                    (params["input0"], output_dtype.c_type, array_suffix, True),
+                    (params["output"], output_dtype.c_type, array_suffix, False),
                 ]
             )
             rendered = softmax_template.render(
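The new AdagradOp branch renders one update loop per parameter tensor from the rate/timestep inputs and the decay/norm-coefficient/epsilon literals. For orientation only, a sketch of the ONNX Adagrad update the emitted C is expected to implement (not the generated code itself):

import numpy as np

def adagrad_step(x, g, acc, rate, timestep,
                 decay_factor=0.0, norm_coefficient=0.0, epsilon=1e-6):
    # Decayed learning rate for this step (timestep counts from 0 in ONNX Adagrad).
    r = rate / (1.0 + timestep * decay_factor)
    # Optional L2 regularization folded into the gradient.
    g_reg = norm_coefficient * x + g
    # Accumulate squared gradients, then scale the step by 1/sqrt(acc).
    acc_new = acc + g_reg * g_reg
    x_new = x - r * g_reg / (np.sqrt(acc_new) + epsilon)
    return x_new, acc_new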
@@ -7449,23 +7027,40 @@ class CEmitter:
                 input0=params["input0"],
                 output=params["output"],
                 params=param_decls,
-                c_type=c_type,
+                c_type=output_dtype.c_type,
                 array_suffix=array_suffix,
-                outer=
-                axis_size=
-                inner=
-
+                outer=outer,
+                axis_size=axis_size,
+                inner=inner,
+                max_fn=max_fn,
+                exp_fn=CEmitter._math_fn(output_dtype, "expf", "exp"),
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, LogSoftmaxOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for LogSoftmax rendering."
+                )
+            output_shape = self._ctx_shape(op.output)
+            output_dtype = self._ctx_dtype(op.output)
+            outer = self._derived(op, "outer")
+            axis_size = self._derived(op, "axis_size")
+            inner = self._derived(op, "inner")
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, output_dtype, scalar_registry
+            )
+            if max_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar maximum function for LogSoftmax."
+                )
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            array_suffix = self._param_array_suffix(
+            array_suffix = self._param_array_suffix(output_shape)
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"], c_type, array_suffix, True),
-                    (params["output"], c_type, array_suffix, False),
+                    (params["input0"], output_dtype.c_type, array_suffix, True),
+                    (params["output"], output_dtype.c_type, array_suffix, False),
                 ]
             )
             rendered = logsoftmax_template.render(
@@ -7474,24 +7069,41 @@ class CEmitter:
                 input0=params["input0"],
                 output=params["output"],
                 params=param_decls,
-                c_type=c_type,
+                c_type=output_dtype.c_type,
                 array_suffix=array_suffix,
-                outer=
-                axis_size=
-                inner=
-
-
+                outer=outer,
+                axis_size=axis_size,
+                inner=inner,
+                max_fn=max_fn,
+                exp_fn=CEmitter._math_fn(output_dtype, "expf", "exp"),
+                log_fn=CEmitter._math_fn(output_dtype, "logf", "log"),
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, HardmaxOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for Hardmax rendering."
+                )
+            output_shape = self._ctx_shape(op.output)
+            output_dtype = self._ctx_dtype(op.output)
+            outer = self._derived(op, "outer")
+            axis_size = self._derived(op, "axis_size")
+            inner = self._derived(op, "inner")
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, output_dtype, scalar_registry
+            )
+            if max_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar maximum function for Hardmax."
+                )
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            array_suffix = self._param_array_suffix(
+            array_suffix = self._param_array_suffix(output_shape)
            param_decls = self._build_param_decls(
                 [
-                    (params["input0"], c_type, array_suffix, True),
-                    (params["output"], c_type, array_suffix, False),
+                    (params["input0"], output_dtype.c_type, array_suffix, True),
+                    (params["output"], output_dtype.c_type, array_suffix, False),
                 ]
             )
             rendered = hardmax_template.render(
@@ -7500,13 +7112,14 @@ class CEmitter:
                 input0=params["input0"],
                 output=params["output"],
                 params=param_decls,
-                c_type=c_type,
+                c_type=output_dtype.c_type,
                 array_suffix=array_suffix,
-                outer=
-                axis_size=
-                inner=
+                outer=outer,
+                axis_size=axis_size,
+                inner=inner,
                 zero_literal=zero_literal,
-                one_literal=CEmitter._format_literal(
+                one_literal=CEmitter._format_literal(output_dtype, 1),
+                max_fn=max_fn,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, NegativeLogLikelihoodLossOp):
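Softmax, LogSoftmax, and Hardmax all pass the same outer/axis_size/inner decomposition to their templates: the tensor is viewed as (outer, axis_size, inner) around the reduction axis. A hedged sketch of that loop layout in Python (illustrative only; the emitted code is C):

import math

def softmax_over_axis(x, outer, axis_size, inner):
    # x is flat with length outer * axis_size * inner; element (o, a, i)
    # lives at (o * axis_size + a) * inner + i.
    y = [0.0] * len(x)
    for o in range(outer):
        for i in range(inner):
            idx = lambda a: (o * axis_size + a) * inner + i
            m = max(x[idx(a)] for a in range(axis_size))             # max_fn pass
            e = [math.exp(x[idx(a)] - m) for a in range(axis_size)]  # exp_fn pass
            s = sum(e)
            for a in range(axis_size):
                y[idx(a)] = e[a] / s
    return y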
@@ -7576,6 +7189,17 @@ class CEmitter:
                 if op.dtype in {ScalarType.F16, ScalarType.F32}
                 else op.dtype
             )
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for SoftmaxCrossEntropyLoss."
+                )
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, acc_dtype, scalar_registry
+            )
+            if max_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar maximum function for SoftmaxCrossEntropyLoss."
+                )
             acc_type = acc_dtype.c_type
             acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
             acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
@@ -7652,9 +7276,21 @@ class CEmitter:
                 acc_one_literal=acc_one_literal,
                 acc_exp_fn=acc_exp_fn,
                 acc_log_fn=acc_log_fn,
+                max_fn=max_fn,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, MaxPoolOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for MaxPool rendering."
+                )
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, op.dtype, scalar_registry
+            )
+            if max_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar maximum function for MaxPool."
+                )
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -7699,6 +7335,7 @@ class CEmitter:
                 output_suffix=output_suffix,
                 indices_suffix=indices_suffix,
                 indices_c_type=indices_c_type,
+                max_fn=max_fn,
                 batch=op.batch,
                 channels=op.channels,
                 spatial_rank=op.spatial_rank,
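A pattern repeated across these hunks: every max-using op now resolves a shared scalar helper from scalar_registry instead of inlining a comparison, and fails fast when the registry is absent. A minimal sketch of what such a registry lookup could look like — the class and method names here are hypothetical, since the package's actual registry API is not shown in this diff:

from enum import Enum, auto

class ScalarFunction(Enum):
    MINIMUM = auto()
    MAXIMUM = auto()

class ScalarRegistry:
    """Maps (function, c_type) to one emitted C helper, deduplicating definitions."""
    def __init__(self):
        self._names = {}

    def resolve(self, fn, c_type):
        key = (fn, c_type)
        if key not in self._names:
            # e.g. emit once: static inline float maximum_float(float a, float b) {...}
            self._names[key] = f"{fn.name.lower()}_{c_type.replace(' ', '_')}"
        return self._names[key]

reg = ScalarRegistry()
print(reg.resolve(ScalarFunction.MAXIMUM, "float"))  # maximum_float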
@@ -8032,21 +7669,133 @@ class CEmitter:
                 reduction=op.reduction,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, TensorScatterOp):
+            param_pairs = [
+                ("past_cache", op.past_cache),
+                ("update", op.update),
+                ("output", op.output),
+            ]
+            if op.write_indices is not None:
+                param_pairs.insert(2, ("write_indices", op.write_indices))
+            params = self._shared_param_map(param_pairs)
+            output_dim_names = _dim_names_for(op.output)
+            update_dim_names = _dim_names_for(op.update)
+            past_dim_names = _dim_names_for(op.past_cache)
+            write_indices_dim_names = (
+                _dim_names_for(op.write_indices) if op.write_indices else None
+            )
+            output_shape = CEmitter._shape_dim_exprs(
+                op.output_shape, output_dim_names
+            )
+            update_shape = CEmitter._shape_dim_exprs(
+                op.update_shape, update_dim_names
+            )
+            prefix_shape = output_shape[: op.axis]
+            prefix_loop_vars = (
+                CEmitter._loop_vars(op.output_shape[: op.axis])
+                if op.output_shape[: op.axis]
+                else ()
+            )
+            tail_shape = output_shape[op.axis + 1 :]
+            tail_loop_vars = (
+                tuple(
+                    f"t{index}"
+                    for index in range(len(op.output_shape[op.axis + 1 :]))
+                )
+                if op.output_shape[op.axis + 1 :]
+                else ()
+            )
+            output_loop_vars = CEmitter._loop_vars(op.output_shape)
+            sequence_loop_var = "seq"
+            cache_index_var = "cache_index"
+            write_index_var = "write_index"
+            index_vars = (*prefix_loop_vars, cache_index_var, *tail_loop_vars)
+            output_index_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in index_vars
+            )
+            update_index_vars = (
+                *prefix_loop_vars,
+                sequence_loop_var,
+                *tail_loop_vars,
+            )
+            update_index_expr = f"{params['update']}" + "".join(
+                f"[{var}]" for var in update_index_vars
+            )
+            past_suffix = self._param_array_suffix(
+                op.past_cache_shape, past_dim_names
+            )
+            update_suffix = self._param_array_suffix(
+                op.update_shape, update_dim_names
+            )
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
+            param_decls = [
+                (params["past_cache"], c_type, past_suffix, True),
+                (params["update"], c_type, update_suffix, True),
+            ]
+            if op.write_indices is not None and op.write_indices_dtype is not None:
+                write_indices_suffix = self._param_array_suffix(
+                    op.write_indices_shape or (), write_indices_dim_names
+                )
+                param_decls.append(
+                    (
+                        params["write_indices"],
+                        op.write_indices_dtype.c_type,
+                        write_indices_suffix,
+                        True,
+                    )
+                )
+            param_decls.append((params["output"], c_type, output_suffix, False))
+            param_decls_rendered = self._build_param_decls(param_decls)
+            rendered = tensor_scatter_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                past_cache=params["past_cache"],
+                update=params["update"],
+                write_indices=(
+                    params.get("write_indices") if op.write_indices else None
+                ),
+                output=params["output"],
+                params=param_decls_rendered,
+                c_type=c_type,
+                output_shape=output_shape,
+                output_loop_vars=output_loop_vars,
+                prefix_shape=prefix_shape,
+                prefix_loop_vars=prefix_loop_vars,
+                sequence_dim=update_shape[op.axis],
+                sequence_loop_var=sequence_loop_var,
+                tail_shape=tail_shape,
+                tail_loop_vars=tail_loop_vars,
+                output_index_expr=output_index_expr,
+                update_index_expr=update_index_expr,
+                max_sequence_length=output_shape[op.axis],
+                write_indices_present=op.write_indices is not None,
+                batch_index_var=prefix_loop_vars[0]
+                if prefix_loop_vars
+                else "0",
+                write_index_var=write_index_var,
+                cache_index_var=cache_index_var,
+                circular=op.mode == "circular",
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, TransposeOp):
+            input_shape = self._ctx_shape(op.input0)
+            output_shape_raw = self._ctx_shape(op.output)
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            output_shape = CEmitter._codegen_shape(
+            output_shape = CEmitter._codegen_shape(output_shape_raw)
             loop_vars = CEmitter._loop_vars(output_shape)
             output_suffix = self._param_array_suffix(output_shape)
-            input_suffix = self._param_array_suffix(
+            input_suffix = self._param_array_suffix(input_shape)
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], c_type, input_suffix, True),
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
-            if not
+            if not input_shape:
                 input_indices = [loop_vars[0]]
             else:
                 input_indices = [None] * len(op.perm)
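The new TensorScatterOp lowering copies a past cache into the output and then writes each update sequence at a per-batch write index, with an optional circular mode that wraps around the cache length. A hedged Python sketch of that semantics for a (batch, sequence, feature) cache with axis=1 (illustrative, not the generated C):

def tensor_scatter(past_cache, update, write_indices=None, circular=False):
    # past_cache: [B][S][F], update: [B][U][F]; scatter along the sequence axis.
    B, S = len(past_cache), len(past_cache[0])
    U = len(update[0])
    out = [[row[:] for row in batch] for batch in past_cache]  # start from the cache
    for b in range(B):
        start = write_indices[b] if write_indices is not None else 0
        for seq in range(U):
            cache_index = start + seq
            if circular:
                cache_index %= S  # wrap around in circular mode
            out[b][cache_index] = update[b][seq][:]
    return out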
@@ -8067,19 +7816,21 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ReshapeOp):
+            input_shape = self._ctx_shape(op.input0)
+            output_shape_raw = self._ctx_shape(op.output)
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            input_suffix = self._param_array_suffix(
-            output_shape = CEmitter._codegen_shape(
-            output_suffix = self._param_array_suffix(
+            input_suffix = self._param_array_suffix(input_shape)
+            output_shape = CEmitter._codegen_shape(output_shape_raw)
+            output_suffix = self._param_array_suffix(output_shape_raw)
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], c_type, input_suffix, True),
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
-            loop_vars = CEmitter._loop_vars(
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
             rendered = reshape_template.render(
                 model_name=model.name,
                 op_name=op_name,
@@ -8089,20 +7840,27 @@ class CEmitter:
                 c_type=c_type,
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
-                element_count=CEmitter._element_count(
+                element_count=CEmitter._element_count(output_shape_raw),
                 output_shape=output_shape,
                 loop_vars=loop_vars,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, IdentityOp):
+            output_shape_raw = self._ctx_shape(op.output)
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
             output_dim_names = _dim_names_for(op.output)
-            shape = CEmitter._shape_dim_exprs(
-
-
-
+            shape = CEmitter._shape_dim_exprs(
+                output_shape_raw, output_dim_names
+            )
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
+            output_suffix = self._param_array_suffix(
+                output_shape_raw, output_dim_names
+            )
+            input_suffix = self._param_array_suffix(
+                output_shape_raw, _dim_names_for(op.input0)
+            )
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], c_type, input_suffix, True),
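Reshape and Identity lower to flat copies: the template only needs an element count (or a loop nest over the output shape), because a row-major reshape does not move data. Roughly, as a sketch:

def reshape_copy(inp, element_count):
    # Row-major reshape preserves flat element order, so the generated C
    # can be a single loop of element_count assignments.
    out = [0] * element_count
    for i in range(element_count):
        out[i] = inp[i]
    return out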
@@ -8704,39 +8462,41 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ReduceOp) and op.axes_input is None:
+            input_shape = self._ctx_shape(op.input0)
+            output_shape_raw = self._ctx_shape(op.output)
+            axes = self._derived(op, "axes")
+            output_dtype = self._ctx_dtype(op.output)
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            output_shape = CEmitter._codegen_shape(
+            output_shape = CEmitter._codegen_shape(output_shape_raw)
             output_loop_vars = CEmitter._loop_vars(output_shape)
-            if not
+            if not input_shape:
                 reduce_loop_vars = ("r0",)
                 reduce_dims = (1,)
             else:
-                reduce_loop_vars = tuple(
-
-
-                reduce_dims = tuple(op.input_shape[axis] for axis in op.axes)
-            if not op.input_shape:
+                reduce_loop_vars = tuple(f"r{idx}" for idx in range(len(axes)))
+                reduce_dims = tuple(input_shape[axis] for axis in axes)
+            if not input_shape:
                 input_indices = [reduce_loop_vars[0]]
             elif op.keepdims:
                 input_indices = [
-                    reduce_loop_vars[
-                    if axis in
+                    reduce_loop_vars[axes.index(axis)]
+                    if axis in axes
                     else output_loop_vars[axis]
-                    for axis in range(len(
+                    for axis in range(len(input_shape))
                 ]
             else:
                 kept_axes = [
                     axis
-                    for axis in range(len(
-                    if axis not in
+                    for axis in range(len(input_shape))
+                    if axis not in axes
                 ]
                 input_indices = [
-                    reduce_loop_vars[
-                    if axis in
+                    reduce_loop_vars[axes.index(axis)]
+                    if axis in axes
                     else output_loop_vars[kept_axes.index(axis)]
-                    for axis in range(len(
+                    for axis in range(len(input_shape))
                 ]
             input_index_expr = "".join(f"[{var}]" for var in input_indices)
             output_index_expr = "".join(
@@ -8748,16 +8508,16 @@ class CEmitter:
             final_expr = "acc"
             use_kahan = False
             kahan_value_expr = None
-            fabs_fn = CEmitter._math_fn(
-            exp_fn = CEmitter._math_fn(
-            log_fn = CEmitter._math_fn(
-            sqrt_fn = CEmitter._math_fn(
+            fabs_fn = CEmitter._math_fn(output_dtype, "fabsf", "fabs")
+            exp_fn = CEmitter._math_fn(output_dtype, "expf", "exp")
+            log_fn = CEmitter._math_fn(output_dtype, "logf", "log")
+            sqrt_fn = CEmitter._math_fn(output_dtype, "sqrtf", "sqrt")
             if op.reduce_kind == "sum":
                 init_literal = zero_literal
                 update_expr = f"acc += {value_expr};"
             elif op.reduce_kind == "mean":
                 count_literal = CEmitter._format_literal(
-
+                    output_dtype, op.reduce_count
                 )
                 init_literal = zero_literal
                 update_expr = f"acc += {value_expr};"
@@ -8769,7 +8529,7 @@ class CEmitter:
                 init_literal = max_literal
                 update_expr = f"if ({value_expr} < acc) acc = {value_expr};"
             elif op.reduce_kind == "prod":
-                init_literal = CEmitter._format_literal(
+                init_literal = CEmitter._format_literal(output_dtype, 1)
                 update_expr = f"acc *= {value_expr};"
             elif op.reduce_kind == "l1":
                 init_literal = zero_literal
@@ -8793,7 +8553,7 @@ class CEmitter:
                 raise CodegenError(
                     f"Unsupported reduce kind {op.reduce_kind}"
                 )
-            if
+            if output_dtype in {ScalarType.F16, ScalarType.F32} and op.reduce_kind in {
                 "sum",
                 "mean",
                 "logsum",
@@ -8811,8 +8571,8 @@ class CEmitter:
                 kahan_value_expr = f"{value_expr} * {value_expr}"
             else:
                 kahan_value_expr = value_expr
-            input_suffix = self._param_array_suffix(
-            output_suffix = self._param_array_suffix(
+            input_suffix = self._param_array_suffix(input_shape)
+            output_suffix = self._param_array_suffix(output_shape_raw)
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], c_type, input_suffix, True),
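The reduce path enables Kahan compensated summation for f16/f32 sum-like kinds (the use_kahan / kahan_value_expr machinery above). A short sketch of the compensation the generated loop applies (illustrative Python; the emitted code is C):

def kahan_sum(values):
    acc = 0.0
    comp = 0.0  # running compensation for lost low-order bits
    for v in values:
        y = v - comp          # remove the previously lost error
        t = acc + y           # low-order digits of y may be lost here
        comp = (t - acc) - y  # recover what was lost
        acc = t
    return acc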
@@ -8842,33 +8602,40 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ArgReduceOp):
+            input_shape = self._ctx_shape(op.input0)
+            output_shape_raw = self._ctx_shape(op.output)
+            axis = self._derived(op, "axis")
+            input_dtype = self._ctx_dtype(op.input0)
+            output_dtype = self._ctx_dtype(op.output)
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            output_shape = CEmitter._codegen_shape(
+            output_shape = CEmitter._codegen_shape(output_shape_raw)
             output_loop_vars = CEmitter._loop_vars(output_shape)
             reduce_var = "r0"
-            reduce_dim =
+            reduce_dim = input_shape[axis]
             if op.keepdims:
                 input_indices = [
-                    reduce_var
-
+                    reduce_var
+                    if axis_index == axis
+                    else output_loop_vars[axis_index]
+                    for axis_index in range(len(input_shape))
                 ]
             else:
                 kept_axes = [
-
-                    for
-                    if
+                    axis_index
+                    for axis_index in range(len(input_shape))
+                    if axis_index != axis
                 ]
                 input_indices = [
                     reduce_var
-                    if
-                    else output_loop_vars[kept_axes.index(
-                    for
+                    if axis_index == axis
+                    else output_loop_vars[kept_axes.index(axis_index)]
+                    for axis_index in range(len(input_shape))
                 ]
             init_indices = [
-                "0" if
-                for
+                "0" if axis_index == axis else input_indices[axis_index]
+                for axis_index in range(len(input_shape))
             ]
             input_index_expr = "".join(f"[{var}]" for var in input_indices)
             init_index_expr = "".join(f"[{var}]" for var in init_indices)
@@ -8883,12 +8650,12 @@ class CEmitter:
                 raise CodegenError(
                     f"Unsupported arg reduce kind {op.reduce_kind}"
                 )
-            input_suffix = self._param_array_suffix(
-            output_suffix = self._param_array_suffix(
+            input_suffix = self._param_array_suffix(input_shape)
+            output_suffix = self._param_array_suffix(output_shape_raw)
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"],
-                    (params["output"],
+                    (params["input0"], input_dtype.c_type, input_suffix, True),
+                    (params["output"], output_dtype.c_type, output_suffix, False),
                 ]
             )
             rendered = arg_reduce_template.render(
@@ -8897,8 +8664,8 @@ class CEmitter:
                 input0=params["input0"],
                 output=params["output"],
                 params=param_decls,
-                input_c_type=
-                output_c_type=
+                input_c_type=input_dtype.c_type,
+                output_c_type=output_dtype.c_type,
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
                 output_shape=output_shape,
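The keepdims branches above map output loop variables back to input indices: the reduced axis is driven by r0, and every other axis either reuses its own output variable (keepdims=1) or its compacted kept_axes position (keepdims=0). The same idea for a 2-D argmax along axis 1, as a sketch:

def argmax_axis1(x):
    # x: list of rows; reduce along axis 1. With keepdims the output keeps a
    # size-1 axis, so the row loop variable indexes both input and output.
    out = []
    for row in x:                      # output loop var drives the kept axis
        best_r = 0
        for r0 in range(1, len(row)):  # r0 drives the reduced axis
            if row[r0] > row[best_r]:
                best_r = r0
        out.append(best_r)
    return out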
@@ -8913,6 +8680,11 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, TopKOp):
+            input_shape = self._ctx_shape(op.input0)
+            output_shape_raw = self._ctx_shape(op.output_values)
+            input_dtype = self._ctx_dtype(op.input0)
+            output_values_dtype = self._ctx_dtype(op.output_values)
+            output_indices_dtype = self._ctx_dtype(op.output_indices)
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -8920,7 +8692,7 @@ class CEmitter:
                     ("output_indices", op.output_indices),
                 ]
             )
-            output_shape = CEmitter._codegen_shape(
+            output_shape = CEmitter._codegen_shape(output_shape_raw)
             outer_shape = tuple(
                 dim for axis, dim in enumerate(output_shape) if axis != op.axis
             )
@@ -8930,7 +8702,7 @@ class CEmitter:
             input_indices: list[str] = []
             output_indices: list[str] = []
             outer_index = 0
-            for axis in range(len(
+            for axis in range(len(input_shape)):
                 if axis == op.axis:
                     input_indices.append(reduce_var)
                     output_indices.append(k_var)
@@ -8945,20 +8717,20 @@ class CEmitter:
                 if op.largest
                 else "(a < b) || ((a == b) && (ai < bi))"
             )
-            input_suffix = self._param_array_suffix(
-            output_suffix = self._param_array_suffix(
+            input_suffix = self._param_array_suffix(input_shape)
+            output_suffix = self._param_array_suffix(output_shape_raw)
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"],
+                    (params["input0"], input_dtype.c_type, input_suffix, True),
                     (
                         params["output_values"],
-
+                        output_values_dtype.c_type,
                         output_suffix,
                         False,
                     ),
                     (
                         params["output_indices"],
-
+                        output_indices_dtype.c_type,
                         output_suffix,
                         False,
                     ),
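The TopK comparator string above orders primarily by value and breaks ties toward the lower index, matching ONNX TopK. The same predicate expressed in Python (a sketch):

def topk(values, k, largest=True):
    # Sort by (value, index): value descending for `largest`,
    # ties resolved toward the smaller original index.
    order = sorted(
        range(len(values)),
        key=lambda i: ((-values[i]) if largest else values[i], i),
    )
    picked = order[:k]
    return [values[i] for i in picked], picked

print(topk([1.0, 3.0, 3.0, 2.0], 2))  # ([3.0, 3.0], [1, 2])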
@@ -9216,27 +8988,150 @@ class CEmitter:
             )
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
-                    (params["output"], c_type, output_suffix, False),
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            input_expr = f"{params['input0']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            rendered = nonzero_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                params=param_decls,
+                input_c_type=op.input_dtype.c_type,
+                output_c_type=c_type,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                input_shape=input_shape,
+                loop_vars=loop_vars,
+                input_expr=input_expr,
+                zero_literal=op.input_dtype.zero_literal,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, NonMaxSuppressionOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for NonMaxSuppression."
+                )
+            min_fn = self._scalar_function_name(
+                ScalarFunction.MINIMUM, op.boxes_dtype, scalar_registry
+            )
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, op.boxes_dtype, scalar_registry
+            )
+            if min_fn is None or max_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar min/max functions for NonMaxSuppression."
+                )
+            params = self._shared_param_map(
+                [
+                    ("boxes", op.boxes),
+                    ("scores", op.scores),
+                    ("max_output_boxes_per_class", op.max_output_boxes_per_class),
+                    ("iou_threshold", op.iou_threshold),
+                    ("score_threshold", op.score_threshold),
+                    ("output", op.output),
+                ]
+            )
+            boxes_suffix = self._param_array_suffix(
+                op.boxes_shape, _dim_names_for(op.boxes)
+            )
+            scores_suffix = self._param_array_suffix(
+                op.scores_shape, _dim_names_for(op.scores)
+            )
+            output_suffix = self._param_array_suffix(
+                op.output_shape, _dim_names_for(op.output)
+            )
+            max_output_suffix = (
+                self._param_array_suffix(
+                    op.max_output_shape,
+                    _dim_names_for(op.max_output_boxes_per_class or ""),
+                )
+                if op.max_output_shape is not None
+                else ""
+            )
+            iou_threshold_suffix = (
+                self._param_array_suffix(
+                    op.iou_threshold_shape,
+                    _dim_names_for(op.iou_threshold or ""),
+                )
+                if op.iou_threshold_shape is not None
+                else ""
+            )
+            score_threshold_suffix = (
+                self._param_array_suffix(
+                    op.score_threshold_shape,
+                    _dim_names_for(op.score_threshold or ""),
+                )
+                if op.score_threshold_shape is not None
+                else ""
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["boxes"], op.boxes_dtype.c_type, boxes_suffix, True),
+                    (params["scores"], op.boxes_dtype.c_type, scores_suffix, True),
+                    (
+                        params["max_output_boxes_per_class"],
+                        op.max_output_dtype.c_type if op.max_output_dtype else "",
+                        max_output_suffix,
+                        True,
+                    )
+                    if params["max_output_boxes_per_class"]
+                    else (None, "", "", True),
+                    (
+                        params["iou_threshold"],
+                        (
+                            op.iou_threshold_dtype.c_type
+                            if op.iou_threshold_dtype
+                            else ""
+                        ),
+                        iou_threshold_suffix,
+                        True,
+                    )
+                    if params["iou_threshold"]
+                    else (None, "", "", True),
+                    (
+                        params["score_threshold"],
+                        (
+                            op.score_threshold_dtype.c_type
+                            if op.score_threshold_dtype
+                            else ""
+                        ),
+                        score_threshold_suffix,
+                        True,
+                    )
+                    if params["score_threshold"]
+                    else (None, "", "", True),
+                    (params["output"], op.output_dtype.c_type, output_suffix, False),
                 ]
             )
-
-                f"[{var}]" for var in loop_vars
-            )
-            rendered = nonzero_template.render(
+            rendered = nonmax_suppression_template.render(
                 model_name=model.name,
                 op_name=op_name,
-
+                boxes=params["boxes"],
+                scores=params["scores"],
+                max_output_boxes_per_class=params["max_output_boxes_per_class"],
+                iou_threshold=params["iou_threshold"],
+                score_threshold=params["score_threshold"],
                 output=params["output"],
                 params=param_decls,
-                input_c_type=op.
-                output_c_type=c_type,
-
-
-
-
-
-
+                input_c_type=op.boxes_dtype.c_type,
+                output_c_type=op.output_dtype.c_type,
+                compute_type=op.boxes_dtype.c_type,
+                output_capacity=op.output_shape[0],
+                num_batches=op.boxes_shape[0],
+                num_boxes=op.boxes_shape[1],
+                num_classes=op.scores_shape[1],
+                center_point_box=op.center_point_box,
+                min_fn=min_fn,
+                max_fn=max_fn,
+                iou_threshold_default=op.boxes_dtype.zero_literal,
+                score_threshold_default=op.boxes_dtype.zero_literal,
+                score_threshold_enabled=op.score_threshold is not None,
+                dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ExpandOp):
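The NonMaxSuppression template receives min_fn/max_fn because IoU needs clamped intersection extents, plus a center_point_box flag selecting the box encoding. A hedged sketch of the per-pair test such a loop performs, for corner-format boxes (illustrative Python, not the emitted C):

def iou(a, b):
    # a, b: (y1, x1, y2, x2) corner boxes, as in center_point_box=0.
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0.0 else 0.0

def nms(boxes, scores, iou_threshold, score_threshold=None, max_out=None):
    order = sorted(range(len(boxes)), key=lambda i: -scores[i])
    kept = []
    for i in order:
        if score_threshold is not None and scores[i] <= score_threshold:
            continue  # score gate, active only when the input is present
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in kept):
            kept.append(i)
        if max_out is not None and len(kept) >= max_out:
            break
    return kept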
@@ -9476,17 +9371,24 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, CastOp):
+            input_dtype = self._ctx_dtype(op.input0)
+            output_dtype = self._ctx_dtype(op.output)
+            output_shape_raw = self._ctx_shape(op.output)
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
             output_dim_names = _dim_names_for(op.output)
-            shape = CEmitter._shape_dim_exprs(
-
-
+            shape = CEmitter._shape_dim_exprs(
+                output_shape_raw, output_dim_names
+            )
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
+            array_suffix = self._param_array_suffix(
+                output_shape_raw, output_dim_names
+            )
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"],
-                    (params["output"],
+                    (params["input0"], input_dtype.c_type, array_suffix, True),
+                    (params["output"], output_dtype.c_type, array_suffix, False),
                 ]
             )
             rendered = cast_template.render(
@@ -9495,8 +9397,8 @@ class CEmitter:
                 input0=params["input0"],
                 output=params["output"],
                 params=param_decls,
-                input_c_type=
-                output_c_type=
+                input_c_type=input_dtype.c_type,
+                output_c_type=output_dtype.c_type,
                 array_suffix=array_suffix,
                 shape=shape,
                 loop_vars=loop_vars,
@@ -9504,6 +9406,10 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, QuantizeLinearOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for QuantizeLinear."
+                )
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -9545,6 +9451,21 @@ class CEmitter:
                 ]
             )
             compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
+            compute_dtype = (
+                ScalarType.F64
+                if compute_type == "double"
+                else ScalarType.F32
+            )
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
+            )
+            min_fn = self._scalar_function_name(
+                ScalarFunction.MINIMUM, compute_dtype, scalar_registry
+            )
+            if max_fn is None or min_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar min/max functions for QuantizeLinear."
+                )
             round_fn = CEmitter._math_fn(
                 op.input_dtype, "nearbyintf", "nearbyint"
             )
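QuantizeLinear now clamps through the shared min/max helpers and rounds with nearbyint (round-half-to-even in the default FP environment). The per-element arithmetic, as an illustrative sketch:

def quantize_linear(x, scale, zero_point, qmin=-128, qmax=127):
    # y = saturate(round(x / scale) + zero_point); Python's round() is
    # round-half-to-even, matching C nearbyint in the default rounding mode.
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))  # the generated code uses min_fn/max_fn here

print(quantize_linear(0.7, 0.01, 0))  # 70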
@@ -9580,10 +9501,221 @@ class CEmitter:
                 round_fn=round_fn,
                 min_literal=op.dtype.min_literal,
                 max_literal=op.dtype.max_literal,
+                min_fn=min_fn,
+                max_fn=max_fn,
+                dim_args=dim_args,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, QLinearMatMulOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for QLinearMatMul."
+                )
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("input0_scale", op.input0_scale),
+                    ("input0_zero_point", op.input0_zero_point),
+                    ("input1", op.input1),
+                    ("input1_scale", op.input1_scale),
+                    ("input1_zero_point", op.input1_zero_point),
+                    ("output_scale", op.output_scale),
+                    ("output_zero_point", op.output_zero_point),
+                    ("output", op.output),
+                ]
+            )
+            output_shape = CEmitter._codegen_shape(op.output_shape)
+            output_loop_vars = CEmitter._loop_vars(output_shape)
+            output_index_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in output_loop_vars
+            )
+            batch_rank = len(op.batch_shape)
+            batch_vars = output_loop_vars[:batch_rank]
+            if op.left_vector and op.right_vector:
+                row_var = None
+                col_var = None
+            elif op.left_vector:
+                row_var = None
+                col_var = output_loop_vars[-1]
+            elif op.right_vector:
+                row_var = output_loop_vars[-1]
+                col_var = None
+            else:
+                row_var = output_loop_vars[-2]
+                col_var = output_loop_vars[-1]
+            input0_index_expr, input1_index_expr = CEmitter._matmul_index_exprs(
+                op,
+                batch_vars,
+                row_var,
+                col_var,
+                batch_rank,
+                input0=params["input0"],
+                input1=params["input1"],
+            )
+            input0_suffix = self._param_array_suffix(op.input0_shape)
+            input1_suffix = self._param_array_suffix(op.input1_shape)
+            input0_scale_suffix = self._param_array_suffix(
+                op.input0_scale_shape
+            )
+            input1_scale_suffix = self._param_array_suffix(
+                op.input1_scale_shape
+            )
+            output_scale_suffix = self._param_array_suffix(
+                op.output_scale_shape
+            )
+            input0_zero_suffix = self._param_array_suffix(op.input0_zero_shape)
+            input1_zero_suffix = self._param_array_suffix(op.input1_zero_shape)
+            output_zero_suffix = self._param_array_suffix(op.output_zero_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (
+                        params["input0"],
+                        op.input0_dtype.c_type,
+                        input0_suffix,
+                        True,
+                    ),
+                    (
+                        params["input0_scale"],
+                        op.input0_scale_dtype.c_type,
+                        input0_scale_suffix,
+                        True,
+                    ),
+                    (
+                        params["input0_zero_point"],
+                        op.input0_dtype.c_type,
+                        input0_zero_suffix,
+                        True,
+                    ),
+                    (
+                        params["input1"],
+                        op.input1_dtype.c_type,
+                        input1_suffix,
+                        True,
+                    ),
+                    (
+                        params["input1_scale"],
+                        op.input1_scale_dtype.c_type,
+                        input1_scale_suffix,
+                        True,
+                    ),
+                    (
+                        params["input1_zero_point"],
+                        op.input1_dtype.c_type,
+                        input1_zero_suffix,
+                        True,
+                    ),
+                    (
+                        params["output_scale"],
+                        op.output_scale_dtype.c_type,
+                        output_scale_suffix,
+                        True,
+                    ),
+                    (
+                        params["output_zero_point"],
+                        op.dtype.c_type,
+                        output_zero_suffix,
+                        True,
+                    ),
+                    (
+                        params["output"],
+                        op.dtype.c_type,
+                        output_suffix,
+                        False,
+                    ),
+                ]
+            )
+            compute_dtype = (
+                ScalarType.F64
+                if ScalarType.F64
+                in {
+                    op.input0_scale_dtype,
+                    op.input1_scale_dtype,
+                    op.output_scale_dtype,
+                }
+                else ScalarType.F32
+            )
+            compute_type = (
+                "double" if compute_dtype == ScalarType.F64 else "float"
+            )
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
+            )
+            min_fn = self._scalar_function_name(
+                ScalarFunction.MINIMUM, compute_dtype, scalar_registry
+            )
+            if max_fn is None or min_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar min/max functions for QLinearMatMul."
+                )
+            round_fn = CEmitter._math_fn(
+                compute_dtype, "nearbyintf", "nearbyint"
+            )
+            scale_index = "0"
+            rendered = qlinear_matmul_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                input1=params["input1"],
+                input0_scale=params["input0_scale"],
+                input0_zero_point=params["input0_zero_point"],
+                input1_scale=params["input1_scale"],
+                input1_zero_point=params["input1_zero_point"],
+                output_scale=params["output_scale"],
+                output_zero_point=params["output_zero_point"],
+                output=params["output"],
+                params=param_decls,
+                compute_type=compute_type,
+                output_c_type=op.dtype.c_type,
+                input0_index_expr=input0_index_expr,
+                input1_index_expr=input1_index_expr,
+                input0_scale_expr=f"{params['input0_scale']}[{scale_index}]",
+                input1_scale_expr=f"{params['input1_scale']}[{scale_index}]",
+                output_scale_expr=f"{params['output_scale']}[{scale_index}]",
+                input0_zero_expr=f"{params['input0_zero_point']}[{scale_index}]",
+                input1_zero_expr=f"{params['input1_zero_point']}[{scale_index}]",
+                output_zero_expr=f"{params['output_zero_point']}[{scale_index}]",
+                output_loop_vars=output_loop_vars,
+                output_loop_bounds=output_shape,
+                output_index_expr=output_index_expr,
+                k=op.k,
+                round_fn=round_fn,
+                min_literal=op.dtype.min_literal,
+                max_literal=op.dtype.max_literal,
+                min_fn=min_fn,
+                max_fn=max_fn,
                 dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ClipOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for Clip rendering."
+                )
+            input_shape = self._ctx_shape(op.input0)
+            output_shape_raw = self._ctx_shape(op.output)
+            input_dtype = self._ctx_dtype(op.input0)
+            output_dtype = self._ctx_dtype(op.output)
+            min_shape = (
+                self._ctx_shape(op.input_min)
+                if op.input_min is not None
+                else None
+            )
+            max_shape = (
+                self._ctx_shape(op.input_max)
+                if op.input_max is not None
+                else None
+            )
+            min_fn = self._scalar_function_name(
+                ScalarFunction.MINIMUM, input_dtype, scalar_registry
+            )
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, input_dtype, scalar_registry
+            )
+            if min_fn is None or max_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar min/max functions for Clip."
+                )
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
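The new QLinearMatMul lowering dequantizes both operands via their scale/zero_point, accumulates in float or double, and requantizes through round plus clamp; the fixed scale_index "0" indicates per-tensor (scalar) quantization parameters are assumed. A sketch of the per-output-element computation (illustrative Python):

def qlinear_dot(a_row, b_col, a_scale, a_zp, b_scale, b_zp,
                out_scale, out_zp, qmin=-128, qmax=127):
    # Dequantize, accumulate in float, then requantize to the output type.
    acc = 0.0
    for a_q, b_q in zip(a_row, b_col):
        acc += (a_q - a_zp) * a_scale * (b_q - b_zp) * b_scale
    q = round(acc / out_scale) + out_zp  # nearbyint-style rounding
    return max(qmin, min(qmax, q))       # min_fn/max_fn clamp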
@@ -9594,61 +9726,61 @@ class CEmitter:
             )
             output_dim_names = _dim_names_for(op.output)
             output_shape = CEmitter._shape_dim_exprs(
-
+                output_shape_raw, output_dim_names
             )
-            loop_vars = CEmitter._loop_vars(
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
             input_expr = CEmitter._broadcast_index_expr(
                 params["input0"],
-
-
+                input_shape,
+                output_shape_raw,
                 loop_vars,
             )
             min_expr = (
                 CEmitter._broadcast_index_expr(
                     params["input_min"],
-
-
+                    min_shape,
+                    output_shape_raw,
                     loop_vars,
                 )
                 if op.input_min is not None
-                else
+                else output_dtype.min_literal
             )
             max_expr = (
                 CEmitter._broadcast_index_expr(
                     params["input_max"],
-
-
+                    max_shape,
+                    output_shape_raw,
                     loop_vars,
                 )
                 if op.input_max is not None
-                else
+                else output_dtype.max_literal
             )
             input_suffix = self._param_array_suffix(
-
+                input_shape, _dim_names_for(op.input0)
             )
             min_suffix = (
                 self._param_array_suffix(
-
+                    min_shape, _dim_names_for(op.input_min)
                 )
-                if
+                if min_shape is not None
                 else ""
             )
             max_suffix = (
                 self._param_array_suffix(
-
+                    max_shape, _dim_names_for(op.input_max)
                 )
-                if
+                if max_shape is not None
                 else ""
             )
             output_suffix = self._param_array_suffix(
-
+                output_shape_raw, output_dim_names
             )
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"],
+                    (params["input0"], input_dtype.c_type, input_suffix, True),
                     (
                         params["input_min"],
-
+                        input_dtype.c_type,
                         min_suffix,
                         True,
                     )
@@ -9656,13 +9788,13 @@ class CEmitter:
                     else (None, "", "", True),
                     (
                         params["input_max"],
-
+                        input_dtype.c_type,
                         max_suffix,
                         True,
                     )
                     if params["input_max"]
                     else (None, "", "", True),
-                    (params["output"],
+                    (params["output"], output_dtype.c_type, output_suffix, False),
                 ]
             )
             rendered = clip_template.render(
@@ -9673,8 +9805,8 @@ class CEmitter:
                 input_max=params["input_max"],
                 output=params["output"],
                 params=param_decls,
-                input_c_type=
-                output_c_type=
+                input_c_type=input_dtype.c_type,
+                output_c_type=output_dtype.c_type,
                 input_suffix=input_suffix,
                 min_suffix=min_suffix,
                 max_suffix=max_suffix,
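Clip falls back to the output dtype's min/max literals when the optional bound inputs are absent, and applies the bounds with the shared min/max helpers. Per element that is simply (a sketch):

def clip(x, lo=None, hi=None, dtype_min=float("-inf"), dtype_max=float("inf")):
    lo = dtype_min if lo is None else lo  # missing min input -> dtype min literal
    hi = dtype_max if hi is None else hi  # missing max input -> dtype max literal
    return min(max(x, lo), hi)            # max_fn, then min_fn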
@@ -9684,30 +9816,51 @@ class CEmitter:
                 input_expr=input_expr,
                 min_expr=min_expr,
                 max_expr=max_expr,
+                min_fn=min_fn,
+                max_fn=max_fn,
                 dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, UnaryOp):
+            input_dtype = self._ctx_dtype(op.input0)
+            output_dtype = self._ctx_dtype(op.output)
+            output_shape_raw = self._ctx_shape(op.output)
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
             scalar_operator = None
             if scalar_registry is not None:
                 scalar_operator = self._scalar_function_name(
-                    op.function,
+                    op.function, input_dtype, scalar_registry, params=op.params
                 )
             output_dim_names = _dim_names_for(op.output)
-            shape = CEmitter._shape_dim_exprs(
-
-
+            shape = CEmitter._shape_dim_exprs(
+                output_shape_raw, output_dim_names
+            )
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
+            array_suffix = self._param_array_suffix(
+                output_shape_raw, output_dim_names
+            )
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"],
-                    (params["output"],
+                    (params["input0"], input_dtype.c_type, array_suffix, True),
+                    (params["output"], output_dtype.c_type, array_suffix, False),
                 ]
             )
-            operator_symbol = unary_op_symbol(op.function, dtype=
-            if op.function
+            operator_symbol = unary_op_symbol(op.function, dtype=output_dtype)
+            if op.function == ScalarFunction.ISINF and len(op.params) == 2:
+                detect_negative, detect_positive = op.params
+                detect_negative = int(detect_negative)
+                detect_positive = int(detect_positive)
+                if detect_negative and detect_positive:
+                    operator_symbol = "isinf"
+                elif detect_negative:
+                    operator_symbol = "isneginf"
+                elif detect_positive:
+                    operator_symbol = "isposinf"
+                else:
+                    operator_symbol = "zero"
+            elif op.function in {ScalarFunction.ISINF, ScalarFunction.ISNAN}:
                 operator_symbol = (
                     "isinf" if op.function == ScalarFunction.ISINF else "isnan"
                 )
@@ -9722,8 +9875,8 @@ class CEmitter:
                 "array_suffix": array_suffix,
                 "shape": shape,
                 "loop_vars": loop_vars,
-                "input_c_type":
-                "output_c_type":
+                "input_c_type": input_dtype.c_type,
+                "output_c_type": output_dtype.c_type,
                 "zero_literal": zero_literal,
                 "dim_args": dim_args,
                 "params": param_decls,
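IsInf's detect_negative/detect_positive attributes select one of four operator symbols above; with both flags cleared nothing can match, hence the "zero" operator. In Python terms (a sketch):

import math

def isinf_with_flags(x, detect_negative=1, detect_positive=1):
    if detect_negative and detect_positive:
        return math.isinf(x)
    if detect_negative:
        return math.isinf(x) and x < 0  # "isneginf"
    if detect_positive:
        return math.isinf(x) and x > 0  # "isposinf"
    return False                         # "zero": always false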
@@ -9774,6 +9927,7 @@ class CEmitter:
|
|
|
9774
9927
|
| GatherOp
|
|
9775
9928
|
| GatherNDOp
|
|
9776
9929
|
| ScatterNDOp
|
|
9930
|
+
| TensorScatterOp
|
|
9777
9931
|
| TransposeOp
|
|
9778
9932
|
| ReshapeOp
|
|
9779
9933
|
| IdentityOp
|
|
@@ -9803,8 +9957,8 @@ class CEmitter:
|
|
|
9803
9957
|
return op.output_values
|
|
9804
9958
|
return op.output
|
|
9805
9959
|
|
|
9806
|
-
@staticmethod
|
|
9807
9960
|
def _op_inputs(
|
|
9961
|
+
self,
|
|
9808
9962
|
op: BinaryOp
|
|
9809
9963
|
| MultiInputBinaryOp
|
|
9810
9964
|
| WhereOp
|
|
@@ -9840,6 +9994,7 @@ class CEmitter:
|
|
|
9840
9994
|
| GatherOp
|
|
9841
9995
|
| GatherNDOp
|
|
9842
9996
|
| ScatterNDOp
|
|
9997
|
+
| TensorScatterOp
|
|
9843
9998
|
| TransposeOp
|
|
9844
9999
|
| ReshapeOp
|
|
9845
10000
|
| IdentityOp
|
|
@@ -9865,18 +10020,24 @@ class CEmitter:
     ) -> tuple[tuple[str, tuple[int, ...]], ...]:
         if isinstance(op, BinaryOp):
             return (
-                (op.input0, op.
-                (op.input1, op.
+                (op.input0, self._ctx_shape(op.input0)),
+                (op.input1, self._ctx_shape(op.input1)),
             )
         if isinstance(op, MultiInputBinaryOp):
-            return tuple((name,
+            return tuple((name, self._ctx_shape(name)) for name in op.inputs)
+        if isinstance(op, WhereOp):
+            return (
+                (op.condition, self._ctx_shape(op.condition)),
+                (op.input_x, self._ctx_shape(op.input_x)),
+                (op.input_y, self._ctx_shape(op.input_y)),
+            )
         if isinstance(op, EinsumOp):
             return tuple(
                 (name, shape)
                 for name, shape in zip(op.inputs, op.input_shapes)
             )
         if isinstance(op, UnaryOp):
-            return ((op.input0, op.
+            return ((op.input0, self._ctx_shape(op.input0)),)
         if isinstance(op, LpNormalizationOp):
             return ((op.input0, op.shape),)
         if isinstance(op, InstanceNormalizationOp):
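Note: `_op_inputs` (and the sibling `_op_outputs`, `_op_output_shape`, and `_op_output_dtype` methods further down) lose their `@staticmethod` decorator in this release because shapes and dtypes now come from a shared per-graph context via `self._ctx_shape`/`self._ctx_dtype` rather than from fields duplicated on every op. A rough sketch of that lookup pattern, with hypothetical stand-in types (the real definitions live in the new `emx_onnx_cgen/ir/context.py`, whose exact API is not shown in this diff):

from dataclasses import dataclass, field

@dataclass
class TensorInfo:                      # stand-in, not the package's class
    shape: tuple[int, ...]
    dtype: str

@dataclass
class GraphContext:                    # stand-in, not the package's class
    tensors: dict[str, TensorInfo] = field(default_factory=dict)

class Emitter:
    def __init__(self, ctx: GraphContext) -> None:
        self._ctx = ctx

    def _ctx_shape(self, name: str) -> tuple[int, ...]:
        # One authoritative lookup keyed by tensor name.
        return self._ctx.tensors[name].shape

ctx = GraphContext({"x": TensorInfo((2, 3), "float")})
assert Emitter(ctx)._ctx_shape("x") == (2, 3)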
@@ -9901,32 +10062,57 @@ class CEmitter:
         if isinstance(op, RMSNormalizationOp):
             return ((op.input0, op.shape), (op.scale, op.scale_shape))
         if isinstance(op, ClipOp):
-            inputs = [(op.input0, op.
-            if op.input_min is not None
-                inputs.append((op.input_min, op.
-            if op.input_max is not None
-                inputs.append((op.input_max, op.
+            inputs = [(op.input0, self._ctx_shape(op.input0))]
+            if op.input_min is not None:
+                inputs.append((op.input_min, self._ctx_shape(op.input_min)))
+            if op.input_max is not None:
+                inputs.append((op.input_max, self._ctx_shape(op.input_max)))
             return tuple(inputs)
         if isinstance(op, CastOp):
-            return ((op.input0, op.
+            return ((op.input0, self._ctx_shape(op.input0)),)
         if isinstance(op, NonZeroOp):
             return ((op.input0, op.input_shape),)
+        if isinstance(op, NonMaxSuppressionOp):
+            inputs = [
+                (op.boxes, op.boxes_shape),
+                (op.scores, op.scores_shape),
+            ]
+            if (
+                op.max_output_boxes_per_class is not None
+                and op.max_output_shape is not None
+            ):
+                inputs.append(
+                    (op.max_output_boxes_per_class, op.max_output_shape)
+                )
+            if (
+                op.iou_threshold is not None
+                and op.iou_threshold_shape is not None
+            ):
+                inputs.append((op.iou_threshold, op.iou_threshold_shape))
+            if (
+                op.score_threshold is not None
+                and op.score_threshold_shape is not None
+            ):
+                inputs.append(
+                    (op.score_threshold, op.score_threshold_shape)
+                )
+            return tuple(inputs)
         if isinstance(op, QuantizeLinearOp):
             scale_shape = (
                 ()
                 if op.axis is None
-                else (op.
+                else (self._ctx_shape(op.input0)[op.axis],)
             )
-            inputs = [(op.input0, op.
+            inputs = [(op.input0, self._ctx_shape(op.input0)), (op.scale, scale_shape)]
             if op.zero_point is not None:
                 inputs.append((op.zero_point, scale_shape))
             return tuple(inputs)
         if isinstance(op, IdentityOp):
-            return ((op.input0, op.
+            return ((op.input0, self._ctx_shape(op.input0)),)
         if isinstance(op, EyeLikeOp):
             return ((op.input0, op.output_shape),)
         if isinstance(op, TriluOp):
-            inputs = [(op.input0, op.
+            inputs = [(op.input0, self._ctx_shape(op.input0))]
             if op.k_input is not None and op.k_input_shape is not None:
                 inputs.append((op.k_input, op.k_input_shape))
             return tuple(inputs)
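Note: ONNX NonMaxSuppression declares its last three inputs (max_output_boxes_per_class, iou_threshold, score_threshold) as optional, so the new branch above appends a (name, shape) pair only when both the tensor name and its resolved shape are present. The same guarded-append idiom, reduced to a runnable sketch (the `NMSOp` class is an illustrative stand-in, not the lowering's actual record):

from dataclasses import dataclass
from typing import Optional

@dataclass
class NMSOp:
    boxes: str
    boxes_shape: tuple[int, ...]
    scores: str
    scores_shape: tuple[int, ...]
    iou_threshold: Optional[str] = None
    iou_threshold_shape: Optional[tuple[int, ...]] = None

def nms_inputs(op: NMSOp) -> tuple[tuple[str, tuple[int, ...]], ...]:
    inputs = [(op.boxes, op.boxes_shape), (op.scores, op.scores_shape)]
    # Optional input is emitted only when both name and shape resolved.
    if op.iou_threshold is not None and op.iou_threshold_shape is not None:
        inputs.append((op.iou_threshold, op.iou_threshold_shape))
    return tuple(inputs)

assert len(nms_inputs(NMSOp("b", (1, 6, 4), "s", (1, 1, 6)))) == 2
assert len(nms_inputs(NMSOp("b", (1, 6, 4), "s", (1, 1, 6), "iou", ()))) == 3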
@@ -9943,6 +10129,14 @@ class CEmitter:
             return tuple(inputs)
         if isinstance(op, ScatterNDOp):
             return ((op.data, op.data_shape),)
+        if isinstance(op, TensorScatterOp):
+            inputs = [
+                (op.past_cache, op.past_cache_shape),
+                (op.update, op.update_shape),
+            ]
+            if op.write_indices is not None and op.write_indices_shape is not None:
+                inputs.append((op.write_indices, op.write_indices_shape))
+            return tuple(inputs)
         if isinstance(op, CumSumOp):
             return ((op.input0, op.input_shape),)
         if isinstance(op, RangeOp):
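Note: the new TensorScatterOp describes a cache update: an `update` block is written into `past_cache`, optionally at per-batch positions given by `write_indices`. A rough, list-based sketch of those semantics under the assumption that writes happen along the sequence axis (the generated C emits loops over the same three tensors):

def tensor_scatter(past_cache, update, write_indices):
    # past_cache: [batch][seq][feat]; update: [batch][upd_seq][feat]
    out = [[row[:] for row in batch] for batch in past_cache]  # copy cache
    for b, start in enumerate(write_indices):
        for s, row in enumerate(update[b]):
            out[b][start + s] = row[:]  # overwrite cached steps
    return out

cache = [[[0.0, 0.0] for _ in range(4)]]   # batch=1, seq=4, feat=2
upd = [[[1.0, 1.0]]]                       # one new step per batch
assert tensor_scatter(cache, upd, [2])[0][2] == [1.0, 1.0]
assert tensor_scatter(cache, upd, [2])[0][3] == [0.0, 0.0]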
@@ -9956,7 +10150,9 @@ class CEmitter:
         if isinstance(op, SplitOp):
             return ((op.input0, op.input_shape),)
         if isinstance(op, TopKOp):
-            return ((op.input0, op.
+            return ((op.input0, self._ctx_shape(op.input0)),)
+        if isinstance(op, (TransposeOp, ReshapeOp, ReduceOp, ArgReduceOp)):
+            return ((op.input0, self._ctx_shape(op.input0)),)
         return ()

     def _propagate_tensor_dim_names(
@@ -10014,6 +10210,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | RangeOp
         | OneHotOp
@@ -10031,8 +10228,8 @@ class CEmitter:
                 tensor_dim_names[output_name] = dict(dim_names)
                 break

-    @staticmethod
     def _op_outputs(
+        self,
         op: BinaryOp
         | MultiInputBinaryOp
         | WhereOp
@@ -10068,6 +10265,7 @@ class CEmitter:
         | GatherOp
         | GatherNDOp
         | ScatterNDOp
+        | TensorScatterOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
@@ -10086,14 +10284,40 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | RangeOp
         | OneHotOp
         | SplitOp,
-    ) -> tuple[tuple[str, tuple[int, ...],
+    ) -> tuple[tuple[str, tuple[int, ...], ScalarType], ...]:
+        if isinstance(
+            op,
+            (
+                BinaryOp,
+                MultiInputBinaryOp,
+                WhereOp,
+                UnaryOp,
+                ClipOp,
+                CastOp,
+                TransposeOp,
+                ReshapeOp,
+                IdentityOp,
+                SoftmaxOp,
+                LogSoftmaxOp,
+                HardmaxOp,
+                ReduceOp,
+            ),
+        ):
+            return (
+                (
+                    op.output,
+                    self._op_output_shape(op),
+                    self._ctx_dtype(op.output),
+                ),
+            )
         if isinstance(op, AttentionOp):
-            outputs: list[tuple[str, tuple[int, ...],
-                (op.output,
+            outputs: list[tuple[str, tuple[int, ...], ScalarType]] = [
+                (op.output, self._op_output_shape(op), op.dtype)
             ]
             if op.output_present_key is not None:
                 outputs.append(
@@ -10121,7 +10345,7 @@ class CEmitter:
                 )
             return tuple(outputs)
         if isinstance(op, LstmOp):
-            outputs: list[tuple[str, tuple[int, ...],
+            outputs: list[tuple[str, tuple[int, ...], ScalarType]] = []
             if op.output_y is not None:
                 if op.layout == 0:
                     y_shape = (
@@ -10155,13 +10379,25 @@ class CEmitter:
                     )
                 )
             return tuple(outputs)
+        if isinstance(op, AdagradOp):
+            outputs = [
+                (name, shape, op.dtype)
+                for name, shape in zip(op.outputs, op.output_shapes)
+            ]
+            outputs.extend(
+                (name, shape, op.dtype)
+                for name, shape in zip(
+                    op.accumulator_outputs, op.output_shapes
+                )
+            )
+            return tuple(outputs)
         if isinstance(op, SoftmaxCrossEntropyLossOp):
             outputs = [(op.output, op.output_shape, op.dtype)]
             if op.log_prob is not None and op.log_prob_shape is not None:
                 outputs.append((op.log_prob, op.log_prob_shape, op.dtype))
             return tuple(outputs)
         if isinstance(op, LayerNormalizationOp):
-            outputs: list[tuple[str, tuple[int, ...],
+            outputs: list[tuple[str, tuple[int, ...], ScalarType]] = [
                 (op.output, op.shape, op.dtype)
             ]
             if op.mean_output is not None:
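Note: the AdagradOp branch zips `op.output_shapes` against two name lists because ONNX's (training-preview) Adagrad returns the updated weights followed by the updated squared-gradient accumulators, and each accumulator has the same shape as its weight. A compact sketch of the pairing (function and argument names are illustrative):

def adagrad_outputs(outputs, accumulator_outputs, output_shapes, dtype):
    # Updated weights first, then accumulators, reusing the same shapes.
    result = [
        (name, shape, dtype)
        for name, shape in zip(outputs, output_shapes)
    ]
    result.extend(
        (name, shape, dtype)
        for name, shape in zip(accumulator_outputs, output_shapes)
    )
    return tuple(result)

assert adagrad_outputs(["w1"], ["h1"], [(3,)], "float") == (
    ("w1", (3,), "float"),
    ("h1", (3,), "float"),
)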
@@ -10172,10 +10408,10 @@ class CEmitter:
                 outputs.append((op.invstd_output, invstd_shape, op.dtype))
             return tuple(outputs)
         if isinstance(op, MaxPoolOp):
-            outputs = [(op.output,
+            outputs = [(op.output, self._op_output_shape(op), op.dtype)]
             if op.indices is not None and op.indices_dtype is not None:
                 outputs.append(
-                    (op.indices,
+                    (op.indices, self._op_output_shape(op), op.indices_dtype)
                 )
             return tuple(outputs)
         if isinstance(op, SplitOp):
@@ -10184,30 +10420,40 @@ class CEmitter:
                 for name, shape in zip(op.outputs, op.output_shapes)
             )
         if isinstance(op, ArgReduceOp):
-            return (
+            return (
+                (
+                    op.output,
+                    self._op_output_shape(op),
+                    self._ctx_dtype(op.output),
+                ),
+            )
         if isinstance(op, TopKOp):
             return (
                 (
                     op.output_values,
-
-                    op.
+                    self._op_output_shape(op),
+                    self._ctx_dtype(op.output_values),
                 ),
                 (
                     op.output_indices,
-
-                    op.
+                    self._op_output_shape(op),
+                    self._ctx_dtype(op.output_indices),
                 ),
             )
-
+        if isinstance(op, NonMaxSuppressionOp):
+            return ((op.output, op.output_shape, op.output_dtype),)
+        return ((op.output, self._op_output_shape(op), op.dtype),)

-    @staticmethod
     def _op_output_shape(
+        self,
         op: BinaryOp
         | MultiInputBinaryOp
         | WhereOp
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
+        | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
@@ -10249,6 +10495,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -10257,19 +10504,21 @@ class CEmitter:
         | PadOp,
     ) -> tuple[int, ...]:
         if isinstance(op, BinaryOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, MultiInputBinaryOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, WhereOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, UnaryOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, ClipOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, QuantizeLinearOp):
             return op.input_shape
         if isinstance(op, CastOp):
-            return op.
+            return self._ctx_shape(op.output)
+        if isinstance(op, QLinearMatMulOp):
+            return op.output_shape
         if isinstance(op, MatMulOp):
             return op.output_shape
         if isinstance(op, EinsumOp):
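Note: `_op_output_shape` gains branches for the new quantized ops; per the ONNX spec, QLinearMatMul's output shape follows the same rule as plain MatMul (leading batch dims broadcast, inner dims contract). A small sketch of that rule for operands of equal rank (handling of unequal-rank leading dims is elided; the helper name is illustrative):

def matmul_output_shape(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    # NumPy/ONNX MatMul rule for >=2-D operands: broadcast the leading
    # (batch) dims, contract a's last dim against b's second-to-last.
    assert a[-1] == b[-2], "inner dimensions must match"
    batch = tuple(max(x, y) for x, y in zip(a[:-2], b[:-2]))
    return batch + (a[-2], b[-1])

assert matmul_output_shape((2, 3, 4), (2, 4, 5)) == (2, 3, 5)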
@@ -10301,11 +10550,11 @@ class CEmitter:
         if isinstance(op, LrnOp):
             return op.shape
         if isinstance(op, SoftmaxOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, LogSoftmaxOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, HardmaxOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return op.output_shape
         if isinstance(op, SoftmaxCrossEntropyLossOp):
@@ -10322,12 +10571,14 @@ class CEmitter:
             return op.output_shape
         if isinstance(op, ScatterNDOp):
             return op.output_shape
-        if isinstance(op,
+        if isinstance(op, TensorScatterOp):
             return op.output_shape
+        if isinstance(op, TransposeOp):
+            return self._ctx_shape(op.output)
         if isinstance(op, ReshapeOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, IdentityOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, EyeLikeOp):
             return op.output_shape
         if isinstance(op, TriluOp):
@@ -10347,11 +10598,11 @@ class CEmitter:
         if isinstance(op, GridSampleOp):
             return op.output_shape
         if isinstance(op, ReduceOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, ArgReduceOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, TopKOp):
-            return op.
+            return self._ctx_shape(op.output_values)
         if isinstance(op, ConstantOfShapeOp):
             return op.shape
         if isinstance(op, ShapeOp):
@@ -10360,6 +10611,8 @@ class CEmitter:
             return op.output_shape
         if isinstance(op, NonZeroOp):
             return op.output_shape
+        if isinstance(op, NonMaxSuppressionOp):
+            return op.output_shape
         if isinstance(op, ExpandOp):
             return op.output_shape
         if isinstance(op, CumSumOp):
@@ -10372,8 +10625,8 @@ class CEmitter:
             return (op.batch, op.q_seq, op.q_heads * op.v_head_size)
         return (op.batch, op.q_heads, op.q_seq, op.v_head_size)

-    @staticmethod
     def _op_output_dtype(
+        self,
         op: BinaryOp
         | MultiInputBinaryOp
         | WhereOp
@@ -10399,6 +10652,7 @@ class CEmitter:
         | SoftmaxOp
         | LogSoftmaxOp
         | HardmaxOp
+        | AdagradOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
@@ -10420,6 +10674,7 @@ class CEmitter:
         | ShapeOp
         | SizeOp
         | NonZeroOp
+        | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
@@ -10428,9 +10683,30 @@ class CEmitter:
         | PadOp,
     ) -> ScalarType:
         if isinstance(op, ArgReduceOp):
-            return op.
+            return self._ctx_dtype(op.output)
         if isinstance(op, TopKOp):
-            return op.
+            return self._ctx_dtype(op.output_values)
+        if isinstance(op, NonMaxSuppressionOp):
+            return op.output_dtype
+        if isinstance(
+            op,
+            (
+                BinaryOp,
+                MultiInputBinaryOp,
+                WhereOp,
+                UnaryOp,
+                ClipOp,
+                CastOp,
+                SoftmaxOp,
+                LogSoftmaxOp,
+                HardmaxOp,
+                TransposeOp,
+                ReshapeOp,
+                IdentityOp,
+                ReduceOp,
+            ),
+        ):
+            return self._ctx_dtype(op.output)
         return op.dtype

     @staticmethod
@@ -10815,7 +11091,7 @@ class CEmitter:
         self, constants: tuple[ConstTensor, ...]
     ) -> tuple[tuple[ConstTensor, ...], tuple[ConstTensor, ...]]:
         if self._large_weight_threshold <= 0:
-            return ()
+            return constants, ()
         inline: list[ConstTensor] = []
         large: list[ConstTensor] = []
         for const in constants:
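Note: the one-line change in this last hunk fixes a genuine bug. `_split_constants` promises a pair (inline constants, large constants), and its callers unpack two values; when the large-weight threshold is disabled (<= 0), the old code returned a bare empty tuple, so that unpacking raised ValueError. A distilled reproduction (illustrative stand-in, not the package's code):

def split_constants_old(constants, threshold):
    if threshold <= 0:
        return ()              # bug: not a 2-tuple
    return constants, ()

def split_constants_fixed(constants, threshold):
    if threshold <= 0:
        return constants, ()   # everything stays inline
    return constants, ()       # (real size-based splitting elided)

try:
    inline, large = split_constants_old(("c0",), 0)
except ValueError:
    pass                       # "not enough values to unpack"

inline, large = split_constants_fixed(("c0",), 0)
assert inline == ("c0",) and large == ()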
|