emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/codegen/c_emitter.py:

@@ -9,7 +9,7 @@ import re
 import struct
 from typing import Mapping, Sequence
 
-from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
+from jinja2 import Environment, FileSystemLoader, PackageLoader, Template, select_autoescape
 import numpy as np
 
 from ..errors import CodegenError
@@ -37,15 +37,18 @@ from ..ir.ops import (
     AttentionOp,
     AveragePoolOp,
     BatchNormOp,
+    BernoulliOp,
     BinaryOp,
     CastOp,
     ClipOp,
     ConcatOp,
     ConstantOfShapeOp,
     ConvOp,
+    ConvIntegerOp,
     ConvTransposeOp,
     CumSumOp,
     DepthToSpaceOp,
+    DequantizeLinearOp,
     EinsumKind,
     EinsumOp,
     ExpandOp,
@@ -54,8 +57,10 @@ from ..ir.ops import (
     GatherNDOp,
     GatherOp,
     GemmOp,
+    GruOp,
     GridSampleOp,
     GroupNormalizationOp,
+    HammingWindowOp,
     HardmaxOp,
     IdentityOp,
     InstanceNormalizationOp,
@@ -73,8 +78,11 @@ from ..ir.ops import (
     NonMaxSuppressionOp,
     NonZeroOp,
     OneHotOp,
+    OptionalHasElementOp,
     PadOp,
     QuantizeLinearOp,
+    PowOp,
+    QLinearMulOp,
     QLinearMatMulOp,
     RangeOp,
     ReduceOp,
@@ -91,6 +99,7 @@ from ..ir.ops import (
     SpaceToDepthOp,
     SplitOp,
     TensorScatterOp,
+    TfIdfVectorizerOp,
     TileOp,
     TopKOp,
     TransposeOp,
@@ -278,9 +287,11 @@ class ModelHeader:
 class LoweredModel:
     name: str
     input_names: tuple[str, ...]
+    input_optional_names: tuple[str | None, ...]
     input_shapes: tuple[tuple[int, ...], ...]
     input_dtypes: tuple[ScalarType, ...]
     output_names: tuple[str, ...]
+    output_optional_names: tuple[str | None, ...]
     output_shapes: tuple[tuple[int, ...], ...]
     output_dtypes: tuple[ScalarType, ...]
     constants: tuple[ConstTensor, ...]
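A reading note on the two new fields (my interpretation; the diff itself does not spell it out): they appear to be index-aligned companions to `input_names`/`output_names`, carrying an extra name for optional inputs and outputs and `None` for required ones — later hunks map them through `name_map` and filter out the `None` entries when collecting reserved identifiers. A hypothetical illustration:

# Hypothetical values, purely illustrative.
input_names          = ("x", "initial_h")
input_optional_names = (None, "initial_h_opt")  # None => no optional companion
reserved = [name for name in input_optional_names if name is not None]
assert reserved == ["initial_h_opt"]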
@@ -304,20 +315,37 @@ class _EmitState:
 class CEmitter:
     def __init__(
         self,
-        template_dir: Path,
+        template_dir: Path | None,
         *,
         restrict_arrays: bool = True,
+        fp32_accumulation_strategy: str = "fp64",
+        fp16_accumulation_strategy: str = "fp32",
         truncate_weights_after: int | None = None,
         large_temp_threshold_bytes: int = 1024,
         large_weight_threshold: int = 1024,
     ) -> None:
+        loader = (
+            FileSystemLoader(str(template_dir))
+            if template_dir is not None
+            else PackageLoader("emx_onnx_cgen", "templates")
+        )
         self._env = Environment(
-            loader=FileSystemLoader(str(template_dir)),
+            loader=loader,
             autoescape=select_autoescape(enabled_extensions=()),
             trim_blocks=True,
             lstrip_blocks=True,
         )
         self._restrict_arrays = restrict_arrays
+        if fp32_accumulation_strategy not in {"simple", "fp64"}:
+            raise CodegenError(
+                "fp32_accumulation_strategy must be 'simple' or 'fp64'"
+            )
+        self._fp32_accumulation_strategy = fp32_accumulation_strategy
+        if fp16_accumulation_strategy not in {"simple", "fp32"}:
+            raise CodegenError(
+                "fp16_accumulation_strategy must be 'simple' or 'fp32'"
+            )
+        self._fp16_accumulation_strategy = fp16_accumulation_strategy
         if truncate_weights_after is not None and truncate_weights_after < 1:
             raise CodegenError("truncate_weights_after must be >= 1")
         self._truncate_weights_after = truncate_weights_after
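Two behavioural changes land in this hunk: `template_dir` may now be `None`, falling back to Jinja2's `PackageLoader` so templates resolve from the installed package, and the two accumulation-strategy knobs are validated eagerly. A minimal construction sketch based only on the signature shown above (keyword values are the new defaults, written out):

from emx_onnx_cgen.codegen.c_emitter import CEmitter

# template_dir=None -> PackageLoader("emx_onnx_cgen", "templates");
# passing a Path keeps the old FileSystemLoader behaviour.
emitter = CEmitter(
    None,
    restrict_arrays=True,
    fp32_accumulation_strategy="fp64",  # "simple" keeps f32 accumulators
    fp16_accumulation_strategy="fp32",  # "simple" keeps f16 accumulators
)
# Any other strategy value raises immediately, e.g.:
# CEmitter(None, fp32_accumulation_strategy="fp80")  # -> CodegenError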
@@ -390,6 +418,21 @@ class CEmitter:
             mapped[key] = unique
         return mapped
 
+    def _accumulation_dtype(self, dtype: ScalarType) -> ScalarType:
+        if dtype == ScalarType.F32:
+            return (
+                ScalarType.F32
+                if self._fp32_accumulation_strategy == "simple"
+                else ScalarType.F64
+            )
+        if dtype == ScalarType.F16:
+            return (
+                ScalarType.F16
+                if self._fp16_accumulation_strategy == "simple"
+                else ScalarType.F32
+            )
+        return dtype
+
     def _ctx_name(self, name: str) -> str:
         if self._emit_state is None:
             raise CodegenError("Emitter state not initialized")
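In words: by default f32 values accumulate in f64 and f16 values in f32; the "simple" strategies keep accumulation in the value's own type, and every other dtype accumulates as itself. A stand-alone restatement you can run (stub enum in place of the package's ScalarType):

from enum import Enum

class ScalarType(Enum):  # stub standing in for the package's ScalarType
    F16 = "f16"
    F32 = "f32"
    F64 = "f64"

def accumulation_dtype(dtype, fp32="fp64", fp16="fp32"):
    if dtype is ScalarType.F32:
        return ScalarType.F32 if fp32 == "simple" else ScalarType.F64
    if dtype is ScalarType.F16:
        return ScalarType.F16 if fp16 == "simple" else ScalarType.F32
    return dtype  # every other dtype accumulates as itself

assert accumulation_dtype(ScalarType.F32) is ScalarType.F64
assert accumulation_dtype(ScalarType.F16) is ScalarType.F32
assert accumulation_dtype(ScalarType.F32, fp32="simple") is ScalarType.F32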
@@ -410,6 +453,12 @@ class CEmitter:
             raise CodegenError("Emitter state not initialized")
         return self._emit_state.op_context.require_derived(op, key)
 
+    def _maybe_derived(self, op: OpBase, key: str) -> object | None:
+        if self._emit_state is None:
+            raise CodegenError("Emitter state not initialized")
+        value = self._emit_state.op_context.get_derived(op, key, None)
+        return value
+
     @staticmethod
     def _build_param_decls(
         specs: Sequence[tuple[str | None, str, str, bool]]
@@ -447,6 +496,8 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
@@ -454,6 +505,7 @@ class CEmitter:
         | AttentionOp
         | RotaryEmbeddingOp
         | ConvOp
+        | ConvIntegerOp
         | AveragePoolOp
         | BatchNormOp
         | LpNormalizationOp
@@ -463,6 +515,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | AdagradOp
         | SoftmaxOp
@@ -480,6 +533,7 @@ class CEmitter:
         | TransposeOp
         | ReshapeOp
         | IdentityOp
+        | BernoulliOp
         | EyeLikeOp
         | TriluOp
         | TileOp
@@ -495,11 +549,13 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | OptionalHasElementOp
         | NonZeroOp
         | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp,
     ) -> tuple[str, ...]:
@@ -527,6 +583,24 @@ class CEmitter:
                 names.append(op.zero_point)
             names.append(op.output)
             return tuple(names)
+        if isinstance(op, DequantizeLinearOp):
+            names = [op.input0, op.scale]
+            if op.zero_point is not None:
+                names.append(op.zero_point)
+            names.append(op.output)
+            return tuple(names)
+        if isinstance(op, QLinearMulOp):
+            return (
+                op.input0,
+                op.input0_scale,
+                op.input0_zero_point,
+                op.input1,
+                op.input1_scale,
+                op.input1_zero_point,
+                op.output_scale,
+                op.output_zero_point,
+                op.output,
+            )
         if isinstance(op, QLinearMatMulOp):
             return (
                 op.input0,
@@ -579,6 +653,14 @@ class CEmitter:
                 names.append(op.bias)
             names.append(op.output)
             return tuple(names)
+        if isinstance(op, ConvIntegerOp):
+            names = [op.input0, op.weights]
+            if op.x_zero_point is not None:
+                names.append(op.x_zero_point)
+            if op.w_zero_point is not None:
+                names.append(op.w_zero_point)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, ConvTransposeOp):
             names = [op.input0, op.weights]
             if op.bias is not None:
@@ -611,6 +693,19 @@ class CEmitter:
             return (op.input0, op.output)
         if isinstance(op, RMSNormalizationOp):
             return (op.input0, op.scale, op.output)
+        if isinstance(op, GruOp):
+            names = [op.input_x, op.input_w, op.input_r]
+            if op.input_b is not None:
+                names.append(op.input_b)
+            if op.input_sequence_lens is not None:
+                names.append(op.input_sequence_lens)
+            if op.input_initial_h is not None:
+                names.append(op.input_initial_h)
+            if op.output_y is not None:
+                names.append(op.output_y)
+            if op.output_y_h is not None:
+                names.append(op.output_y_h)
+            return tuple(names)
         if isinstance(op, LstmOp):
             names = [op.input_x, op.input_w, op.input_r]
             if op.input_b is not None:
@@ -705,14 +800,20 @@ class CEmitter:
             return tuple(names)
         if isinstance(op, RangeOp):
             return (op.start, op.limit, op.delta, op.output)
+        if isinstance(op, HammingWindowOp):
+            return (op.size, op.output)
         if isinstance(op, OneHotOp):
             return (op.indices, op.depth, op.values, op.output)
+        if isinstance(op, TfIdfVectorizerOp):
+            return (op.input0, op.output)
         if isinstance(op, SplitOp):
             return (op.input0, *op.outputs)
         if isinstance(op, ReshapeOp):
             return (op.input0, op.output)
         if isinstance(op, IdentityOp):
             return (op.input0, op.output)
+        if isinstance(op, BernoulliOp):
+            return (op.input0, op.output)
         if isinstance(op, EyeLikeOp):
             return (op.input0, op.output)
         if isinstance(op, TriluOp):
@@ -761,7 +862,12 @@ class CEmitter:
         if isinstance(op, GridSampleOp):
             return (op.input0, op.grid, op.output)
         if isinstance(op, TopKOp):
-            return (op.input0, op.output_values, op.output_indices)
+            return (
+                op.input0,
+                op.k_input,
+                op.output_values,
+                op.output_indices,
+            )
         if isinstance(op, ReduceOp):
             names = [op.input0]
             if op.axes_input is not None:
@@ -777,6 +883,12 @@ class CEmitter:
         names = [model.name]
         names.extend(model.input_names)
         names.extend(model.output_names)
+        names.extend(
+            name for name in model.input_optional_names if name is not None
+        )
+        names.extend(
+            name for name in model.output_optional_names if name is not None
+        )
         for op in model.ops:
             names.extend(
                 name for name in self._op_names(op) if name not in constant_names
@@ -809,12 +921,15 @@ class CEmitter:
     def _map_op_names(
         self,
        op: BinaryOp
+        | PowOp
         | MultiInputBinaryOp
         | WhereOp
         | UnaryOp
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
@@ -822,6 +937,7 @@ class CEmitter:
         | AttentionOp
         | RotaryEmbeddingOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -833,6 +949,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | AdagradOp
         | SoftmaxOp
@@ -865,22 +982,28 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | OptionalHasElementOp
         | NonZeroOp
         | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
+        | TfIdfVectorizerOp
         | SplitOp,
         name_map: dict[str, str],
     ) -> (
         BinaryOp
+        | PowOp
         | MultiInputBinaryOp
         | WhereOp
         | UnaryOp
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
@@ -888,6 +1011,7 @@ class CEmitter:
         | AttentionOp
         | RotaryEmbeddingOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -899,6 +1023,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | AdagradOp
         | SoftmaxOp
@@ -931,14 +1056,25 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | OptionalHasElementOp
         | NonZeroOp
         | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp
+        | TfIdfVectorizerOp
     ):
+        if isinstance(op, PowOp):
+            return PowOp(
+                input0=name_map.get(op.input0, op.input0),
+                input1=name_map.get(op.input1, op.input1),
+                output=name_map.get(op.output, op.output),
+                function=op.function,
+                operator_kind=op.operator_kind,
+            )
         if isinstance(op, BinaryOp):
             return BinaryOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -946,11 +1082,6 @@ class CEmitter:
                 output=name_map.get(op.output, op.output),
                 function=op.function,
                 operator_kind=op.operator_kind,
-                input0_shape=op.input0_shape,
-                input1_shape=op.input1_shape,
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
             )
         if isinstance(op, MultiInputBinaryOp):
             return MultiInputBinaryOp(
@@ -968,20 +1099,12 @@ class CEmitter:
                 input_x=name_map.get(op.input_x, op.input_x),
                 input_y=name_map.get(op.input_y, op.input_y),
                 output=name_map.get(op.output, op.output),
-                condition_shape=op.condition_shape,
-                x_shape=op.x_shape,
-                y_shape=op.y_shape,
-                output_shape=op.output_shape,
-                dtype=op.dtype,
             )
         if isinstance(op, UnaryOp):
             return UnaryOp(
                 input0=name_map.get(op.input0, op.input0),
                 output=name_map.get(op.output, op.output),
                 function=op.function,
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
                 params=op.params,
             )
         if isinstance(op, ClipOp):
@@ -990,11 +1113,8 @@ class CEmitter:
                 input_min=self._map_optional_name(name_map, op.input_min),
                 input_max=self._map_optional_name(name_map, op.input_max),
                 output=name_map.get(op.output, op.output),
-                input_shape=op.input_shape,
-                min_shape=op.min_shape,
-                max_shape=op.max_shape,
-                output_shape=op.output_shape,
-                dtype=op.dtype,
+                min_value=op.min_value,
+                max_value=op.max_value,
             )
         if isinstance(op, CastOp):
             return CastOp(
@@ -1016,8 +1136,21 @@ class CEmitter:
                 input_dtype=op.input_dtype,
                 scale_dtype=op.scale_dtype,
             )
-        if isinstance(op, QLinearMatMulOp):
-            return QLinearMatMulOp(
+        if isinstance(op, DequantizeLinearOp):
+            return DequantizeLinearOp(
+                input0=name_map.get(op.input0, op.input0),
+                scale=name_map.get(op.scale, op.scale),
+                zero_point=self._map_optional_name(name_map, op.zero_point),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                axis=op.axis,
+                block_size=op.block_size,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+                scale_dtype=op.scale_dtype,
+            )
+        if isinstance(op, QLinearMulOp):
+            return QLinearMulOp(
                 input0=name_map.get(op.input0, op.input0),
                 input0_scale=name_map.get(op.input0_scale, op.input0_scale),
                 input0_zero_point=name_map.get(
@@ -1036,14 +1169,6 @@ class CEmitter:
                 input0_shape=op.input0_shape,
                 input1_shape=op.input1_shape,
                 output_shape=op.output_shape,
-                batch_shape=op.batch_shape,
-                input0_batch_shape=op.input0_batch_shape,
-                input1_batch_shape=op.input1_batch_shape,
-                m=op.m,
-                n=op.n,
-                k=op.k,
-                left_vector=op.left_vector,
-                right_vector=op.right_vector,
                 input0_dtype=op.input0_dtype,
                 input1_dtype=op.input1_dtype,
                 dtype=op.dtype,
@@ -1057,10 +1182,22 @@ class CEmitter:
                 input1_zero_shape=op.input1_zero_shape,
                 output_zero_shape=op.output_zero_shape,
             )
-        if isinstance(op, MatMulOp):
-            return MatMulOp(
+        if isinstance(op, QLinearMatMulOp):
+            return QLinearMatMulOp(
                 input0=name_map.get(op.input0, op.input0),
+                input0_scale=name_map.get(op.input0_scale, op.input0_scale),
+                input0_zero_point=name_map.get(
+                    op.input0_zero_point, op.input0_zero_point
+                ),
                 input1=name_map.get(op.input1, op.input1),
+                input1_scale=name_map.get(op.input1_scale, op.input1_scale),
+                input1_zero_point=name_map.get(
+                    op.input1_zero_point, op.input1_zero_point
+                ),
+                output_scale=name_map.get(op.output_scale, op.output_scale),
+                output_zero_point=name_map.get(
+                    op.output_zero_point, op.output_zero_point
+                ),
                 output=name_map.get(op.output, op.output),
                 input0_shape=op.input0_shape,
                 input1_shape=op.input1_shape,
@@ -1073,7 +1210,24 @@ class CEmitter:
                 k=op.k,
                 left_vector=op.left_vector,
                 right_vector=op.right_vector,
+                input0_dtype=op.input0_dtype,
+                input1_dtype=op.input1_dtype,
                 dtype=op.dtype,
+                input0_scale_dtype=op.input0_scale_dtype,
+                input1_scale_dtype=op.input1_scale_dtype,
+                output_scale_dtype=op.output_scale_dtype,
+                input0_scale_shape=op.input0_scale_shape,
+                input1_scale_shape=op.input1_scale_shape,
+                output_scale_shape=op.output_scale_shape,
+                input0_zero_shape=op.input0_zero_shape,
+                input1_zero_shape=op.input1_zero_shape,
+                output_zero_shape=op.output_zero_shape,
+            )
+        if isinstance(op, MatMulOp):
+            return MatMulOp(
+                input0=name_map.get(op.input0, op.input0),
+                input1=name_map.get(op.input1, op.input1),
+                output=name_map.get(op.output, op.output),
             )
         if isinstance(op, EinsumOp):
             return EinsumOp(
@@ -1091,15 +1245,10 @@ class CEmitter:
                 input_b=name_map.get(op.input_b, op.input_b),
                 input_c=self._map_optional_name(name_map, op.input_c),
                 output=name_map.get(op.output, op.output),
-                m=op.m,
-                n=op.n,
-                k=op.k,
                 trans_a=op.trans_a,
                 trans_b=op.trans_b,
                 alpha=op.alpha,
                 beta=op.beta,
-                c_shape=op.c_shape,
-                dtype=op.dtype,
             )
         if isinstance(op, AttentionOp):
             return AttentionOp(
@@ -1202,6 +1351,35 @@ class CEmitter:
                 group=op.group,
                 dtype=op.dtype,
             )
+        if isinstance(op, ConvIntegerOp):
+            return ConvIntegerOp(
+                input0=name_map.get(op.input0, op.input0),
+                weights=name_map.get(op.weights, op.weights),
+                x_zero_point=self._map_optional_name(
+                    name_map, op.x_zero_point
+                ),
+                w_zero_point=self._map_optional_name(
+                    name_map, op.w_zero_point
+                ),
+                output=name_map.get(op.output, op.output),
+                batch=op.batch,
+                in_channels=op.in_channels,
+                out_channels=op.out_channels,
+                spatial_rank=op.spatial_rank,
+                in_spatial=op.in_spatial,
+                out_spatial=op.out_spatial,
+                kernel_shape=op.kernel_shape,
+                strides=op.strides,
+                pads=op.pads,
+                dilations=op.dilations,
+                group=op.group,
+                input_dtype=op.input_dtype,
+                weight_dtype=op.weight_dtype,
+                dtype=op.dtype,
+                x_zero_point_shape=op.x_zero_point_shape,
+                w_zero_point_shape=op.w_zero_point_shape,
+                w_zero_point_per_channel=op.w_zero_point_per_channel,
+            )
         if isinstance(op, ConvTransposeOp):
             return ConvTransposeOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1228,16 +1406,26 @@ class CEmitter:
                 output=name_map.get(op.output, op.output),
                 batch=op.batch,
                 channels=op.channels,
+                spatial_rank=op.spatial_rank,
+                in_d=op.in_d,
                 in_h=op.in_h,
                 in_w=op.in_w,
+                out_d=op.out_d,
                 out_h=op.out_h,
                 out_w=op.out_w,
+                kernel_d=op.kernel_d,
                 kernel_h=op.kernel_h,
                 kernel_w=op.kernel_w,
+                dilation_d=op.dilation_d,
+                dilation_h=op.dilation_h,
+                dilation_w=op.dilation_w,
+                stride_d=op.stride_d,
                 stride_h=op.stride_h,
                 stride_w=op.stride_w,
+                pad_front=op.pad_front,
                 pad_top=op.pad_top,
                 pad_left=op.pad_left,
+                pad_back=op.pad_back,
                 pad_bottom=op.pad_bottom,
                 pad_right=op.pad_right,
                 count_include_pad=op.count_include_pad,
@@ -1255,6 +1443,8 @@ class CEmitter:
                 out_w=op.out_w,
                 kernel_h=op.kernel_h,
                 kernel_w=op.kernel_w,
+                dilation_h=op.dilation_h,
+                dilation_w=op.dilation_w,
                 stride_h=op.stride_h,
                 stride_w=op.stride_w,
                 pad_top=op.pad_top,
@@ -1371,6 +1561,35 @@ class CEmitter:
                 bias=op.bias,
                 dtype=op.dtype,
             )
+        if isinstance(op, GruOp):
+            return GruOp(
+                input_x=name_map.get(op.input_x, op.input_x),
+                input_w=name_map.get(op.input_w, op.input_w),
+                input_r=name_map.get(op.input_r, op.input_r),
+                input_b=self._map_optional_name(name_map, op.input_b),
+                input_sequence_lens=self._map_optional_name(
+                    name_map, op.input_sequence_lens
+                ),
+                input_initial_h=self._map_optional_name(
+                    name_map, op.input_initial_h
+                ),
+                output_y=self._map_optional_name(name_map, op.output_y),
+                output_y_h=self._map_optional_name(name_map, op.output_y_h),
+                seq_length=op.seq_length,
+                batch_size=op.batch_size,
+                input_size=op.input_size,
+                hidden_size=op.hidden_size,
+                num_directions=op.num_directions,
+                direction=op.direction,
+                layout=op.layout,
+                linear_before_reset=op.linear_before_reset,
+                clip=op.clip,
+                activation_kinds=op.activation_kinds,
+                activation_alphas=op.activation_alphas,
+                activation_betas=op.activation_betas,
+                dtype=op.dtype,
+                sequence_lens_dtype=op.sequence_lens_dtype,
+            )
         if isinstance(op, LstmOp):
             return LstmOp(
                 input_x=name_map.get(op.input_x, op.input_x),
@@ -1436,34 +1655,19 @@ class CEmitter:
             return SoftmaxOp(
                 input0=name_map.get(op.input0, op.input0),
                 output=name_map.get(op.output, op.output),
-                outer=op.outer,
-                axis_size=op.axis_size,
-                inner=op.inner,
                 axis=op.axis,
-                shape=op.shape,
-                dtype=op.dtype,
             )
         if isinstance(op, LogSoftmaxOp):
             return LogSoftmaxOp(
                 input0=name_map.get(op.input0, op.input0),
                 output=name_map.get(op.output, op.output),
-                outer=op.outer,
-                axis_size=op.axis_size,
-                inner=op.inner,
                 axis=op.axis,
-                shape=op.shape,
-                dtype=op.dtype,
             )
         if isinstance(op, HardmaxOp):
             return HardmaxOp(
                 input0=name_map.get(op.input0, op.input0),
                 output=name_map.get(op.output, op.output),
-                outer=op.outer,
-                axis_size=op.axis_size,
-                inner=op.inner,
                 axis=op.axis,
-                shape=op.shape,
-                dtype=op.dtype,
             )
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return NegativeLogLikelihoodLossOp(
@@ -1624,9 +1828,6 @@ class CEmitter:
             return IdentityOp(
                 input0=name_map.get(op.input0, op.input0),
                 output=name_map.get(op.output, op.output),
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
             )
         if isinstance(op, EyeLikeOp):
             return EyeLikeOp(
@@ -1782,45 +1983,32 @@ class CEmitter:
             return ReduceOp(
                 input0=name_map.get(op.input0, op.input0),
                 output=name_map.get(op.output, op.output),
-                input_shape=op.input_shape,
-                output_shape=op.output_shape,
                 axes=op.axes,
                 axes_input=self._map_optional_name(name_map, op.axes_input),
-                axes_input_shape=op.axes_input_shape,
-                axes_input_dtype=op.axes_input_dtype,
                 keepdims=op.keepdims,
                 noop_with_empty_axes=op.noop_with_empty_axes,
                 reduce_kind=op.reduce_kind,
                 reduce_count=op.reduce_count,
-                dtype=op.dtype,
             )
         if isinstance(op, ArgReduceOp):
             return ArgReduceOp(
                 input0=name_map.get(op.input0, op.input0),
                 output=name_map.get(op.output, op.output),
-                input_shape=op.input_shape,
-                output_shape=op.output_shape,
                 axis=op.axis,
                 keepdims=op.keepdims,
                 select_last_index=op.select_last_index,
                 reduce_kind=op.reduce_kind,
-                input_dtype=op.input_dtype,
-                output_dtype=op.output_dtype,
             )
         if isinstance(op, TopKOp):
             return TopKOp(
                 input0=name_map.get(op.input0, op.input0),
+                k_input=name_map.get(op.k_input, op.k_input),
                 output_values=name_map.get(op.output_values, op.output_values),
                 output_indices=name_map.get(op.output_indices, op.output_indices),
-                input_shape=op.input_shape,
-                output_shape=op.output_shape,
                 axis=op.axis,
                 k=op.k,
                 largest=op.largest,
                 sorted=op.sorted,
-                input_dtype=op.input_dtype,
-                output_values_dtype=op.output_values_dtype,
-                output_indices_dtype=op.output_indices_dtype,
             )
         if isinstance(op, ConstantOfShapeOp):
             return ConstantOfShapeOp(
@@ -1852,6 +2040,11 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, OptionalHasElementOp):
+            return OptionalHasElementOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+            )
         if isinstance(op, NonZeroOp):
             return NonZeroOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1918,6 +2111,25 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, HammingWindowOp):
+            return HammingWindowOp(
+                size=name_map.get(op.size, op.size),
+                output=name_map.get(op.output, op.output),
+                output_shape=op.output_shape,
+                periodic=op.periodic,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
+        if isinstance(op, BernoulliOp):
+            return BernoulliOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                input_dtype=op.input_dtype,
+                dtype=op.dtype,
+                seed=op.seed,
+            )
         if isinstance(op, OneHotOp):
             return OneHotOp(
                 indices=name_map.get(op.indices, op.indices),
@@ -1933,6 +2145,23 @@ class CEmitter:
                 indices_dtype=op.indices_dtype,
                 depth_dtype=op.depth_dtype,
             )
+        if isinstance(op, TfIdfVectorizerOp):
+            return TfIdfVectorizerOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                input_dtype=op.input_dtype,
+                output_dtype=op.output_dtype,
+                min_gram_length=op.min_gram_length,
+                max_gram_length=op.max_gram_length,
+                max_skip_count=op.max_skip_count,
+                mode=op.mode,
+                ngram_counts=op.ngram_counts,
+                ngram_indexes=op.ngram_indexes,
+                pool_int64s=op.pool_int64s,
+                weights=op.weights,
+            )
         if isinstance(op, SplitOp):
             return SplitOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1973,11 +2202,19 @@ class CEmitter:
             input_names=tuple(
                 name_map.get(name, name) for name in model.input_names
             ),
+            input_optional_names=tuple(
+                name_map.get(name, name) if name is not None else None
+                for name in model.input_optional_names
+            ),
             input_shapes=model.input_shapes,
             input_dtypes=model.input_dtypes,
             output_names=tuple(
                 name_map.get(name, name) for name in model.output_names
            ),
+            output_optional_names=tuple(
+                name_map.get(name, name) if name is not None else None
+                for name in model.output_optional_names
+            ),
             output_shapes=model.output_shapes,
             output_dtypes=model.output_dtypes,
             constants=constants,
@@ -2024,6 +2261,18 @@ class CEmitter:
             for name, values in testbench_inputs.items()
         }
 
+    @staticmethod
+    def _sanitize_testbench_optional_inputs(
+        testbench_optional_inputs: Mapping[str, bool] | None,
+        name_map: Mapping[str, str],
+    ) -> Mapping[str, bool] | None:
+        if not testbench_optional_inputs:
+            return None
+        return {
+            name_map.get(name, name): value
+            for name, value in testbench_optional_inputs.items()
+        }
+
     def _load_templates(self, emit_testbench: bool) -> dict[str, Template]:
         try:
             templates = {
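The new helper mirrors `_sanitize_testbench_inputs`: keys are renamed through `name_map`, values pass through unchanged, and an empty or missing mapping collapses to `None`. A stand-alone restatement (names invented):

def sanitize_optional_inputs(flags, name_map):
    # Same behaviour as the static method above, outside the class.
    if not flags:
        return None
    return {name_map.get(k, k): v for k, v in flags.items()}

assert sanitize_optional_inputs({}, {}) is None
assert sanitize_optional_inputs({"h0": True}, {"h0": "t_h0"}) == {"t_h0": True}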
@@ -2038,6 +2287,10 @@ class CEmitter:
                 "quantize_linear": self._env.get_template(
                     "quantize_linear_op.c.j2"
                 ),
+                "dequantize_linear": self._env.get_template(
+                    "dequantize_linear_op.c.j2"
+                ),
+                "qlinear_mul": self._env.get_template("qlinear_mul_op.c.j2"),
                 "qlinear_matmul": self._env.get_template(
                     "qlinear_matmul_op.c.j2"
                 ),
@@ -2049,6 +2302,7 @@ class CEmitter:
                     "rotary_embedding_op.c.j2"
                 ),
                 "conv": self._env.get_template("conv_op.c.j2"),
+                "conv_integer": self._env.get_template("conv_integer_op.c.j2"),
                 "conv_transpose": self._env.get_template(
                     "conv_transpose_op.c.j2"
                 ),
@@ -2070,6 +2324,7 @@ class CEmitter:
                 ),
                 "rms_norm": self._env.get_template("rms_normalization_op.c.j2"),
                 "lrn": self._env.get_template("lrn_op.c.j2"),
+                "gru": self._env.get_template("gru_op.c.j2"),
                 "lstm": self._env.get_template("lstm_op.c.j2"),
                 "adagrad": self._env.get_template("adagrad_op.c.j2"),
                 "softmax": self._env.get_template("softmax_op.c.j2"),
@@ -2093,6 +2348,7 @@ class CEmitter:
                 "transpose": self._env.get_template("transpose_op.c.j2"),
                 "reshape": self._env.get_template("reshape_op.c.j2"),
                 "identity": self._env.get_template("identity_op.c.j2"),
+                "bernoulli": self._env.get_template("bernoulli_op.c.j2"),
                 "eye_like": self._env.get_template("eye_like_op.c.j2"),
                 "trilu": self._env.get_template("trilu_op.c.j2"),
                 "tile": self._env.get_template("tile_op.c.j2"),
@@ -2116,6 +2372,9 @@ class CEmitter:
                 ),
                 "shape": self._env.get_template("shape_op.c.j2"),
                 "size": self._env.get_template("size_op.c.j2"),
+                "optional_has_element": self._env.get_template(
+                    "optional_has_element_op.c.j2"
+                ),
                 "nonzero": self._env.get_template("nonzero_op.c.j2"),
                 "nonmax_suppression": self._env.get_template(
                     "nonmax_suppression_op.c.j2"
@@ -2123,7 +2382,13 @@ class CEmitter:
                 "expand": self._env.get_template("expand_op.c.j2"),
                 "cumsum": self._env.get_template("cumsum_op.c.j2"),
                 "range": self._env.get_template("range_op.c.j2"),
+                "hamming_window": self._env.get_template(
+                    "hamming_window_op.c.j2"
+                ),
                 "one_hot": self._env.get_template("one_hot_op.c.j2"),
+                "tfidf_vectorizer": self._env.get_template(
+                    "tfidf_vectorizer_op.c.j2"
+                ),
                 "split": self._env.get_template("split_op.c.j2"),
             }
             if emit_testbench:
@@ -2138,6 +2403,7 @@ class CEmitter:
         *,
         emit_testbench: bool = False,
         testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
+        testbench_optional_inputs: Mapping[str, bool] | None = None,
         variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
         variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
     ) -> str:
@@ -2147,6 +2413,9 @@ class CEmitter:
         testbench_inputs = self._sanitize_testbench_inputs(
             testbench_inputs, name_map
         )
+        testbench_optional_inputs = self._sanitize_testbench_optional_inputs(
+            testbench_optional_inputs, name_map
+        )
         inline_constants, large_constants = self._partition_constants(
             model.constants
         )
@@ -2184,6 +2453,8 @@ class CEmitter:
             model.name,
             *model.input_names,
             *model.output_names,
+            *(name for name in model.input_optional_names if name is not None),
+            *(name for name in model.output_optional_names if name is not None),
             *(const.name for const in model.constants),
         }
         temp_buffers = self._temp_buffers(model, reserved_names=reserved_names)
@@ -2235,16 +2506,27 @@ class CEmitter:
             *includes,
             "",
             self._emit_index_type_define(),
+            self._emit_unused_define(),
         ]
         if scalar_preamble:
             sections.extend(("", *scalar_preamble))
         sections.append("")
-        constants_section = self._emit_constant_definitions(inline_constants)
+        constants_section = self._emit_constant_declarations(inline_constants)
         if constants_section:
             sections.extend((constants_section.rstrip(), ""))
-        large_constants_section = self._emit_constant_storage_definitions(
+        storage_declarations = self._emit_constant_storage_declarations(
             large_constants
         )
+        if storage_declarations:
+            sections.extend((storage_declarations.rstrip(), ""))
+        constants_section = self._emit_constant_definitions(
+            inline_constants, storage_prefix="const"
+        )
+        if constants_section:
+            sections.extend((constants_section.rstrip(), ""))
+        large_constants_section = self._emit_constant_storage_definitions(
+            large_constants, storage_prefix=""
+        )
         if large_constants_section:
             sections.extend((large_constants_section.rstrip(), ""))
         if scalar_functions:
@@ -2267,6 +2549,7 @@ class CEmitter:
             model,
             testbench_template,
             testbench_inputs=testbench_inputs,
+            testbench_optional_inputs=testbench_optional_inputs,
             dim_order=dim_order,
             dim_values=dim_values,
             weight_data_filename=self._weight_data_filename(model),
@@ -2285,6 +2568,7 @@ class CEmitter:
         *,
         emit_testbench: bool = False,
         testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
+        testbench_optional_inputs: Mapping[str, bool] | None = None,
         variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
         variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
     ) -> tuple[str, str]:
@@ -2294,6 +2578,9 @@ class CEmitter:
         testbench_inputs = self._sanitize_testbench_inputs(
             testbench_inputs, name_map
        )
+        testbench_optional_inputs = self._sanitize_testbench_optional_inputs(
+            testbench_optional_inputs, name_map
+        )
         inline_constants, large_constants = self._partition_constants(
             model.constants
         )
@@ -2331,6 +2618,8 @@ class CEmitter:
             model.name,
             *model.input_names,
             *model.output_names,
+            *(name for name in model.input_optional_names if name is not None),
+            *(name for name in model.output_optional_names if name is not None),
             *(const.name for const in model.constants),
         }
         temp_buffers = self._temp_buffers(model, reserved_names=reserved_names)
@@ -2389,9 +2678,14 @@ class CEmitter:
         constants_section = self._emit_constant_declarations(inline_constants)
         if constants_section:
             sections.extend((constants_section.rstrip(), ""))
-        large_constants_section = self._emit_constant_storage_definitions(
+        storage_declarations = self._emit_constant_storage_declarations(
             large_constants
         )
+        if storage_declarations:
+            sections.extend((storage_declarations.rstrip(), ""))
+        large_constants_section = self._emit_constant_storage_definitions(
+            large_constants, storage_prefix=""
+        )
         if large_constants_section:
             sections.extend((large_constants_section.rstrip(), ""))
         if scalar_functions:
@@ -2414,6 +2708,7 @@ class CEmitter:
             model,
             testbench_template,
             testbench_inputs=testbench_inputs,
+            testbench_optional_inputs=testbench_optional_inputs,
             dim_order=dim_order,
             dim_values=dim_values,
             weight_data_filename=self._weight_data_filename(model),
@@ -2656,7 +2951,7 @@ class CEmitter:
         except ScalarFunctionError:
             return None
 
-    def _lstm_activation_function_name(
+    def _rnn_activation_function_name(
         self,
         kind: int,
         alpha: float,
@@ -2667,7 +2962,7 @@ class CEmitter:
         spec = _LSTM_ACTIVATION_SPECS.get(kind)
         if spec is None:
             raise CodegenError(
-                f"Unsupported LSTM activation kind for codegen: {kind}"
+                f"Unsupported RNN activation kind for codegen: {kind}"
             )
         function, param_count = spec
         if param_count == 0:
@@ -2681,7 +2976,7 @@ class CEmitter:
         )
         if name is None:
             raise CodegenError(
-                f"Failed to resolve scalar function for LSTM activation kind {kind}"
+                f"Failed to resolve scalar function for RNN activation kind {kind}"
             )
         return name
 
@@ -2695,12 +2990,15 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -2712,6 +3010,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | AdagradOp
         | SoftmaxOp
@@ -2743,11 +3042,13 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | OptionalHasElementOp
         | NonZeroOp
         | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp
         ],
@@ -2787,6 +3088,8 @@ class CEmitter:
                 return model.op_context.dtype(op.data)
             if isinstance(op, ExpandOp):
                 return model.op_context.dtype(op.input0)
+            if hasattr(op, "output") and isinstance(op.output, str):
+                return model.op_context.dtype(op.output)
             return op.dtype
 
         model_dtypes.update(
@@ -2798,7 +3101,10 @@ class CEmitter:
             dtype
             for op in resolved_ops
             if isinstance(op, ArgReduceOp)
-            for dtype in (op.input_dtype, op.output_dtype)
+            for dtype in (
+                model.op_context.dtype(op.input0),
+                model.op_context.dtype(op.output),
+            )
         }
         model_dtypes.update(arg_reduce_dtypes)
         topk_dtypes = {
@@ -2806,9 +3112,9 @@ class CEmitter:
             for op in resolved_ops
             if isinstance(op, TopKOp)
             for dtype in (
-                op.input_dtype,
-                op.output_values_dtype,
-                op.output_indices_dtype,
+                model.op_context.dtype(op.input0),
+                model.op_context.dtype(op.output_values),
+                model.op_context.dtype(op.output_indices),
             )
         }
         model_dtypes.update(topk_dtypes)
@@ -2867,15 +3173,18 @@ class CEmitter:
         includes.add("#include <stdbool.h>")
         if any(
             isinstance(op, UnaryOp)
-            and unary_op_symbol(op.function, dtype=op.dtype) in {"llabs", "abs"}
+            and unary_op_symbol(
+                op.function, dtype=model.op_context.dtype(op.output)
+            )
+            in {"llabs", "abs"}
             for op in resolved_ops
         ):
             includes.add("#include <stdlib.h>")
         if any(isinstance(op, PadOp) for op in resolved_ops):
             includes.add("#include <stddef.h>")
-        if CEmitter._needs_math(resolved_ops):
+        if CEmitter._needs_math(resolved_ops, model.op_context):
             includes.add("#include <math.h>")
-        if CEmitter._needs_limits(resolved_ops):
+        if CEmitter._needs_limits(resolved_ops, model.op_context):
             includes.add("#include <limits.h>")
         if any(
             isinstance(op, (ConcatOp, ReshapeOp, SplitOp, IdentityOp))
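A pattern worth naming before the remaining hunks: shape and dtype fields have moved off the op dataclasses and into the shared `OpContext`, so static helpers such as `_needs_math` and `_needs_limits` now take `op_context` and resolve dtypes per tensor name instead of reading a stored `op.dtype`. A toy before/after (stub type, illustrative only):

class OpContextStub:
    # Minimal stand-in; the real OpContext lives in emx_onnx_cgen/ir/context.py.
    def __init__(self, dtypes):
        self._dtypes = dtypes

    def dtype(self, name):
        return self._dtypes[name]

ctx = OpContextStub({"out0": "f32"})
# 0.3.7 style: op.dtype.is_float
# 0.4.1 style: op_context.dtype(op.output).is_float
assert ctx.dtype("out0") == "f32"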
@@ -2905,6 +3214,20 @@ class CEmitter:
             )
         )
 
+    @staticmethod
+    def _emit_unused_define() -> str:
+        return "\n".join(
+            (
+                "#ifndef EMX_UNUSED",
+                "#if defined(__GNUC__) || defined(__clang__)",
+                "#define EMX_UNUSED __attribute__((unused))",
+                "#else",
+                "#define EMX_UNUSED",
+                "#endif",
+                "#endif",
+            )
+        )
+
     @staticmethod
     def _needs_stdint(
         model_dtypes: set[ScalarType],
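For reference, the tuple join above produces exactly the guarded C preamble below; on GCC and Clang, `EMX_UNUSED` expands to `__attribute__((unused))` (silencing unused-variable warnings in the generated code) and to nothing elsewhere. Reproduced here as the expected string:

EXPECTED_UNUSED_DEFINE = (
    "#ifndef EMX_UNUSED\n"
    "#if defined(__GNUC__) || defined(__clang__)\n"
    "#define EMX_UNUSED __attribute__((unused))\n"
    "#else\n"
    "#define EMX_UNUSED\n"
    "#endif\n"
    "#endif"
)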
@@ -2940,12 +3263,15 @@ class CEmitter:
|
|
|
2940
3263
|
| ClipOp
|
|
2941
3264
|
| CastOp
|
|
2942
3265
|
| QuantizeLinearOp
|
|
3266
|
+
| DequantizeLinearOp
|
|
3267
|
+
| QLinearMulOp
|
|
2943
3268
|
| QLinearMatMulOp
|
|
2944
3269
|
| MatMulOp
|
|
2945
3270
|
| EinsumOp
|
|
2946
3271
|
| GemmOp
|
|
2947
3272
|
| AttentionOp
|
|
2948
3273
|
| ConvOp
|
|
3274
|
+
| ConvIntegerOp
|
|
2949
3275
|
| ConvTransposeOp
|
|
2950
3276
|
| AveragePoolOp
|
|
2951
3277
|
| LpPoolOp
|
|
@@ -2957,6 +3283,7 @@ class CEmitter:
|
|
|
2957
3283
|
| MeanVarianceNormalizationOp
|
|
2958
3284
|
| RMSNormalizationOp
|
|
2959
3285
|
| LrnOp
|
|
3286
|
+
| GruOp
|
|
2960
3287
|
| LstmOp
|
|
2961
3288
|
| AdagradOp
|
|
2962
3289
|
| SoftmaxOp
|
|
@@ -2988,14 +3315,17 @@ class CEmitter:
|
|
|
2988
3315
|
| ConstantOfShapeOp
|
|
2989
3316
|
| ShapeOp
|
|
2990
3317
|
| SizeOp
|
|
3318
|
+
| OptionalHasElementOp
|
|
2991
3319
|
| NonZeroOp
|
|
2992
3320
|
| NonMaxSuppressionOp
|
|
2993
3321
|
| ExpandOp
|
|
2994
3322
|
| CumSumOp
|
|
2995
3323
|
| RangeOp
|
|
3324
|
+
| HammingWindowOp
|
|
2996
3325
|
| OneHotOp
|
|
2997
3326
|
| SplitOp
|
|
2998
3327
|
],
|
|
3328
|
+
op_context: OpContext,
|
|
2999
3329
|
) -> bool:
|
|
3000
3330
|
math_ops = {
|
|
3001
3331
|
"atanhf",
|
|
@@ -3014,13 +3344,18 @@ class CEmitter:
|
|
|
3014
3344
|
|
|
3015
3345
|
def is_binary_math_op(op: BinaryOp) -> bool:
|
|
3016
3346
|
op_spec = binary_op_symbol(
|
|
3017
|
-
op.function,
|
|
3347
|
+
op.function,
|
|
3348
|
+
dtype=op_context.dtype(op.input0),
|
|
3349
|
+
validate_attrs=False,
|
|
3018
3350
|
)
|
|
3019
3351
|
return op_spec is not None and op_spec.operator in binary_math_ops
|
|
3020
3352
|
|
|
3021
3353
|
if any(
|
|
3022
3354
|
isinstance(op, UnaryOp)
|
|
3023
|
-
and unary_op_symbol(
|
|
3355
|
+
and unary_op_symbol(
|
|
3356
|
+
op.function, dtype=op_context.dtype(op.output)
|
|
3357
|
+
)
|
|
3358
|
+
in math_ops
|
|
3024
3359
|
for op in resolved_ops
|
|
3025
3360
|
):
|
|
3026
3361
|
return True
|
|
@@ -3038,7 +3373,7 @@ class CEmitter:
|
|
|
3038
3373
|
return True
|
|
3039
3374
|
if any(
|
|
3040
3375
|
isinstance(op, ClipOp)
|
|
3041
|
-
and
|
|
3376
|
+
and op_context.dtype(op.output).is_float
|
|
3042
3377
|
and (op.input_min is None or op.input_max is None)
|
|
3043
3378
|
for op in resolved_ops
|
|
3044
3379
|
):
|
|
@@ -3061,6 +3396,7 @@ class CEmitter:
|
|
|
3061
3396
|
MeanVarianceNormalizationOp,
|
|
3062
3397
|
RMSNormalizationOp,
|
|
3063
3398
|
LrnOp,
|
|
3399
|
+
GruOp,
|
|
3064
3400
|
LstmOp,
|
|
3065
3401
|
AdagradOp,
|
|
3066
3402
|
SoftmaxOp,
|
|
@@ -3082,7 +3418,7 @@ class CEmitter:
|
|
|
3082
3418
|
if any(
|
|
3083
3419
|
isinstance(op, ReduceOp)
|
|
3084
3420
|
and op.reduce_kind in {"min", "max"}
|
|
3085
|
-
and
|
|
3421
|
+
and op_context.dtype(op.output).is_float
|
|
3086
3422
|
for op in resolved_ops
|
|
3087
3423
|
):
|
|
3088
3424
|
return True
|
|
@@ -3092,10 +3428,20 @@ class CEmitter:
         ):
             return True
         if any(
-            isinstance(
+            isinstance(
+                op,
+                (
+                    LpPoolOp,
+                    QuantizeLinearOp,
+                    QLinearMulOp,
+                    QLinearMatMulOp,
+                ),
+            )
             for op in resolved_ops
         ):
             return True
+        if any(isinstance(op, HammingWindowOp) for op in resolved_ops):
+            return True
         return False

     @staticmethod
@@ -3106,12 +3452,15 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -3123,6 +3472,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
@@ -3151,19 +3501,23 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | OptionalHasElementOp
         | NonZeroOp
         | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp
         ],
+        op_context: OpContext,
     ) -> bool:
         if any(
             isinstance(op, ReduceOp)
             and op.reduce_kind in {"min", "max"}
-            and op.
+            and op_context.dtype(op.output)
+            in {
                 ScalarType.I64,
                 ScalarType.I32,
                 ScalarType.I16,
@@ -3174,7 +3528,7 @@ class CEmitter:
             return True
         if any(
             isinstance(op, ClipOp)
-            and
+            and op_context.dtype(op.output).is_integer
             and (op.input_min is None or op.input_max is None)
             for op in resolved_ops
         ):
@@ -3187,7 +3541,7 @@ class CEmitter:
         ):
             return True
         if any(
-            isinstance(op, (QuantizeLinearOp, QLinearMatMulOp))
+            isinstance(op, (QuantizeLinearOp, QLinearMulOp, QLinearMatMulOp))
             and op.dtype.is_integer
             for op in resolved_ops
         ):
@@ -3206,12 +3560,15 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -3223,6 +3580,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
@@ -3256,6 +3614,7 @@ class CEmitter:
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp
         ],
@@ -3266,6 +3625,7 @@ class CEmitter:
         output_dim_names: Mapping[int, Mapping[int, str]],
     ) -> str:
         params = []
+        optional_flags = self._optional_input_flag_map(model)
         if dim_order:
             params.extend(self._format_dim_args(dim_order))
         for index, (name, shape, dtype) in enumerate(
@@ -3273,14 +3633,17 @@ class CEmitter:
         ):
             params.append(
                 f"const {dtype.c_type} {name}"
-                f"{self._param_array_suffix(shape, input_dim_names.get(index))}"
+                f"{self._param_array_suffix(shape, input_dim_names.get(index), use_restrict=True)}"
             )
+            optional_flag = optional_flags.get(name)
+            if optional_flag is not None:
+                params.append(f"_Bool {optional_flag}")
         for index, (name, shape, dtype) in enumerate(
             zip(model.output_names, model.output_shapes, model.output_dtypes)
         ):
             params.append(
                 f"{dtype.c_type} {name}"
-                f"{self._param_array_suffix(shape, output_dim_names.get(index))}"
+                f"{self._param_array_suffix(shape, output_dim_names.get(index), use_restrict=True)}"
             )
         signature = ", ".join(params)
         lines = [f"void {model.name}({signature}) {{"]
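Note: with `use_restrict=True` the model entry point gains `restrict`-qualified array parameters, and each optional input gains a trailing `_Bool` presence flag. A sketch of the kind of signature this plausibly produces; the exact placement of `restrict` inside the array suffix is an assumption, and `x`, `x_present`, `y`, `model` are illustrative names only:

```python
# Hypothetical reimplementation of the suffix builder, assuming the C99
# "array parameter with restrict in brackets" form: float x[restrict 1][4].
def array_suffix(shape, use_restrict=False):
    dims = list(shape)
    first = f"[{'restrict ' if use_restrict else ''}{dims[0]}]"
    rest = "".join(f"[{d}]" for d in dims[1:])
    return first + rest

params = [
    "const float x" + array_suffix((1, 4), use_restrict=True),
    "_Bool x_present",  # emitted only when x is an optional model input
    "float y" + array_suffix((1, 4), use_restrict=True),
]
print(f"void model({', '.join(params)});")
# void model(const float x[restrict 1][4], _Bool x_present, float y[restrict 1][4]);
```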
@@ -3297,7 +3660,7 @@ class CEmitter:
         )
         for index, op in enumerate(resolved_ops):
             op_name = self._op_function_name(model, index)
-            call = self._build_op_call(op, dim_order)
+            call = self._build_op_call(op, dim_order, optional_flags)
             lines.append(f"    {op_name}({call});")
         lines.append("}")
         return "\n".join(lines)
@@ -3309,14 +3672,16 @@ class CEmitter:
             element_count *= dim
         return element_count * temp.dtype.np_dtype.itemsize

-    @staticmethod
     def _build_op_call(
+        self,
         op: BinaryOp
         | WhereOp
         | UnaryOp
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
@@ -3324,6 +3689,7 @@ class CEmitter:
         | AttentionOp
         | RotaryEmbeddingOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -3335,6 +3701,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | AdagradOp
         | SoftmaxOp
@@ -3367,16 +3734,21 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | OptionalHasElementOp
         | NonZeroOp
         | NonMaxSuppressionOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
-        | SplitOp
+        | SplitOp
+        | OptionalHasElementOp,
         dim_order: Sequence[str],
+        optional_flags: Mapping[str, str] | None = None,
     ) -> str:
         args: list[str] = []
+        optional_flags = optional_flags or {}
         if dim_order:
             args.extend(dim_order)
         if isinstance(op, BinaryOp):
@@ -3388,6 +3760,21 @@ class CEmitter:
         if isinstance(op, WhereOp):
             args.extend([op.condition, op.input_x, op.input_y, op.output])
             return ", ".join(args)
+        if isinstance(op, QLinearMulOp):
+            args.extend(
+                [
+                    op.input0,
+                    op.input0_scale,
+                    op.input0_zero_point,
+                    op.input1,
+                    op.input1_scale,
+                    op.input1_zero_point,
+                    op.output_scale,
+                    op.output_zero_point,
+                    op.output,
+                ]
+            )
+            return ", ".join(args)
         if isinstance(op, QLinearMatMulOp):
             args.extend(
                 [
@@ -3434,6 +3821,13 @@ class CEmitter:
             call_parts.append(op.output)
             args.extend(call_parts)
             return ", ".join(args)
+        if isinstance(op, DequantizeLinearOp):
+            call_parts = [op.input0, op.scale]
+            if op.zero_point is not None:
+                call_parts.append(op.zero_point)
+            call_parts.append(op.output)
+            args.extend(call_parts)
+            return ", ".join(args)
         if isinstance(op, AttentionOp):
             call_parts = [op.input_q, op.input_k, op.input_v]
             if op.input_attn_mask is not None:
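Note: the QLinearMul argument list above is exactly the operand order of the quantized multiply: both inputs with their scales and zero points, then the output scale, zero point, and buffer. For reference, the element-wise requantization ONNX defines for QLinearMul, sketched in plain Python:

```python
# One int8 element of QLinearMul: dequantize both operands, multiply in real
# arithmetic, then requantize against the output scale and zero point.
def qlinear_mul(a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp):
    real = (a - a_zp) * a_scale * (b - b_zp) * b_scale
    q = round(real / y_scale) + y_zp  # Python round is round-half-to-even
    return max(-128, min(127, q))    # saturate to int8

# 0.5 * 0.25 = 0.125; with y_scale=0.05 that requantizes to round(2.5) = 2.
print(qlinear_mul(50, 0.01, 0, 25, 0.01, 0, 0.05, 0))
```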
@@ -3453,12 +3847,27 @@ class CEmitter:
             call_parts.append(op.output_qk_matmul)
             args.extend(call_parts)
             return ", ".join(args)
+        if isinstance(op, RotaryEmbeddingOp):
+            call_parts = [op.input0, op.cos_cache, op.sin_cache]
+            if op.position_ids is not None:
+                call_parts.append(op.position_ids)
+            call_parts.append(op.output)
+            args.extend(call_parts)
+            return ", ".join(args)
         if isinstance(op, ConvOp):
             if op.bias is None:
                 args.extend([op.input0, op.weights, op.output])
                 return ", ".join(args)
             args.extend([op.input0, op.weights, op.bias, op.output])
             return ", ".join(args)
+        if isinstance(op, ConvIntegerOp):
+            args.extend([op.input0, op.weights])
+            if op.x_zero_point is not None:
+                args.append(op.x_zero_point)
+            if op.w_zero_point is not None:
+                args.append(op.w_zero_point)
+            args.append(op.output)
+            return ", ".join(args)
         if isinstance(op, ConvTransposeOp):
             if op.bias is None:
                 args.extend([op.input0, op.weights, op.output])
@@ -3502,6 +3911,20 @@ class CEmitter:
         if isinstance(op, RMSNormalizationOp):
             args.extend([op.input0, op.scale, op.output])
             return ", ".join(args)
+        if isinstance(op, GruOp):
+            call_parts = [op.input_x, op.input_w, op.input_r]
+            if op.input_b is not None:
+                call_parts.append(op.input_b)
+            if op.input_sequence_lens is not None:
+                call_parts.append(op.input_sequence_lens)
+            if op.input_initial_h is not None:
+                call_parts.append(op.input_initial_h)
+            if op.output_y is not None:
+                call_parts.append(op.output_y)
+            if op.output_y_h is not None:
+                call_parts.append(op.output_y_h)
+            args.extend(call_parts)
+            return ", ".join(args)
         if isinstance(op, LstmOp):
             call_parts = [op.input_x, op.input_w, op.input_r]
             if op.input_b is not None:
@@ -3523,17 +3946,18 @@ class CEmitter:
             args.extend(call_parts)
             return ", ".join(args)
         if isinstance(op, AdagradOp):
-            args.
-
-
-
-
-
-
-
-
-
+            args.append(op.rate)
+            args.append(op.timestep)
+            for index in range(len(op.inputs)):
+                args.extend(
+                    [
+                        op.inputs[index],
+                        op.gradients[index],
+                        op.accumulators[index],
+                        op.outputs[index],
+                        op.accumulator_outputs[index],
+                    ]
+                )
             return ", ".join(args)
         if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
             args.extend([op.input0, op.output])
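Note: the Adagrad call above flattens one (input, gradient, accumulator, output, accumulator output) group per trained tensor after the shared rate and timestep. The update rule those buffers feed, following the ONNX Adagrad definition; decay and regularization default to zero in this sketch:

```python
# Scalar sketch of one ONNX Adagrad step:
#   r = rate / (1 + t * decay_factor)
#   g' = g + norm_coefficient * x
#   h' = h + g'^2
#   x' = x - r * g' / (sqrt(h') + epsilon)
def adagrad_step(x, g, h, rate, t, decay_factor=0.0,
                 norm_coefficient=0.0, epsilon=1e-6):
    r = rate / (1.0 + t * decay_factor)
    g = g + norm_coefficient * x       # regularized gradient
    h_new = h + g * g                  # accumulated squared gradient
    x_new = x - r * g / (h_new ** 0.5 + epsilon)
    return x_new, h_new

print(adagrad_step(x=1.0, g=0.5, h=0.0, rate=0.1, t=0))  # ~ (0.9, 0.25)
```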
@@ -3590,6 +4014,14 @@ class CEmitter:
         if isinstance(op, SizeOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
+        if isinstance(op, OptionalHasElementOp):
+            input_flag = optional_flags.get(op.input0)
+            if input_flag is None:
+                raise CodegenError(
+                    "OptionalHasElement expects an optional input flag."
+                )
+            args.extend([op.input0, input_flag, op.output])
+            return ", ".join(args)
         if isinstance(op, NonZeroOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
@@ -3622,6 +4054,9 @@ class CEmitter:
         if isinstance(op, RangeOp):
             args.extend([op.start, op.limit, op.delta, op.output])
             return ", ".join(args)
+        if isinstance(op, HammingWindowOp):
+            args.extend([op.size, op.output])
+            return ", ".join(args)
         if isinstance(op, OneHotOp):
             args.extend([op.indices, op.depth, op.values, op.output])
             return ", ".join(args)
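Note: HammingWindow lowers to just a size input and an output buffer. The values the generated kernel must produce follow the ONNX definition (alpha = 25/46, beta = 21/46, rather than the textbook 0.54/0.46), sketched here:

```python
# ONNX HammingWindow: w[i] = alpha - beta * cos(2*pi*i / N), where N is
# `size` when periodic=1 and `size - 1` otherwise.
import math

def hamming_window(size: int, periodic: bool = True) -> list[float]:
    n = size if periodic else size - 1
    alpha, beta = 25.0 / 46.0, 21.0 / 46.0
    return [alpha - beta * math.cos(2.0 * math.pi * i / n) for i in range(size)]

print([round(v, 4) for v in hamming_window(4)])
# [0.087, 0.5435, 1.0, 0.5435]
```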
@@ -3732,12 +4167,15 @@ class CEmitter:
     @staticmethod
     def _resolve_op(
         op: BinaryOp
+        | PowOp
         | MultiInputBinaryOp
         | WhereOp
         | UnaryOp
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
@@ -3745,6 +4183,7 @@ class CEmitter:
         | AttentionOp
         | RotaryEmbeddingOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -3756,6 +4195,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | AdagradOp
         | SoftmaxOp
@@ -3792,23 +4232,29 @@ class CEmitter:
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
+        | TfIdfVectorizerOp
         | SplitOp,
         temp_map: dict[str, str],
     ) -> (
         BinaryOp
+        | PowOp
         | MultiInputBinaryOp
         | WhereOp
         | UnaryOp
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
+        | QLinearMulOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -3820,6 +4266,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | AdagradOp
         | SoftmaxOp
@@ -3856,9 +4303,19 @@ class CEmitter:
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp
+        | TfIdfVectorizerOp
     ):
+        if isinstance(op, PowOp):
+            return PowOp(
+                input0=temp_map.get(op.input0, op.input0),
+                input1=temp_map.get(op.input1, op.input1),
+                output=temp_map.get(op.output, op.output),
+                function=op.function,
+                operator_kind=op.operator_kind,
+            )
         if isinstance(op, BinaryOp):
             return BinaryOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -3866,11 +4323,6 @@ class CEmitter:
                 output=temp_map.get(op.output, op.output),
                 function=op.function,
                 operator_kind=op.operator_kind,
-                input0_shape=op.input0_shape,
-                input1_shape=op.input1_shape,
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
             )
         if isinstance(op, MultiInputBinaryOp):
             return MultiInputBinaryOp(
@@ -3888,20 +4340,12 @@ class CEmitter:
                 input_x=temp_map.get(op.input_x, op.input_x),
                 input_y=temp_map.get(op.input_y, op.input_y),
                 output=temp_map.get(op.output, op.output),
-                condition_shape=op.condition_shape,
-                x_shape=op.x_shape,
-                y_shape=op.y_shape,
-                output_shape=op.output_shape,
-                dtype=op.dtype,
             )
         if isinstance(op, UnaryOp):
             return UnaryOp(
                 input0=temp_map.get(op.input0, op.input0),
                 output=temp_map.get(op.output, op.output),
                 function=op.function,
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
                 params=op.params,
             )
         if isinstance(op, ClipOp):
@@ -3914,29 +4358,14 @@ class CEmitter:
                 if op.input_max is not None
                 else None,
                 output=temp_map.get(op.output, op.output),
-
-
-                max_shape=op.max_shape,
-                output_shape=op.output_shape,
-                dtype=op.dtype,
+                min_value=op.min_value,
+                max_value=op.max_value,
             )
         if isinstance(op, MatMulOp):
             return MatMulOp(
                 input0=temp_map.get(op.input0, op.input0),
                 input1=temp_map.get(op.input1, op.input1),
                 output=temp_map.get(op.output, op.output),
-                input0_shape=op.input0_shape,
-                input1_shape=op.input1_shape,
-                output_shape=op.output_shape,
-                batch_shape=op.batch_shape,
-                input0_batch_shape=op.input0_batch_shape,
-                input1_batch_shape=op.input1_batch_shape,
-                m=op.m,
-                n=op.n,
-                k=op.k,
-                left_vector=op.left_vector,
-                right_vector=op.right_vector,
-                dtype=op.dtype,
             )
         if isinstance(op, EinsumOp):
             return EinsumOp(
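Note: `_resolve_op`, extended above, rebuilds each op with every tensor name passed through `temp_map`, and, per the removals in these hunks, no longer copies shape/dtype fields, which are now derived from context instead. A reduced sketch of the renaming pattern with a hypothetical two-field op:

```python
# Each tensor name falls back to itself when it has no temp alias, so ops that
# never touch a stack temporary come back unchanged.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class TinyOp:
    input0: str
    output: str

def resolve(op: TinyOp, temp_map: dict[str, str]) -> TinyOp:
    return replace(
        op,
        input0=temp_map.get(op.input0, op.input0),
        output=temp_map.get(op.output, op.output),
    )

print(resolve(TinyOp("x", "t0"), {"t0": "tmp_0"}))
# TinyOp(input0='x', output='tmp_0')
```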
@@ -3972,6 +4401,56 @@ class CEmitter:
                 input_dtype=op.input_dtype,
                 scale_dtype=op.scale_dtype,
             )
+        if isinstance(op, DequantizeLinearOp):
+            return DequantizeLinearOp(
+                input0=temp_map.get(op.input0, op.input0),
+                scale=temp_map.get(op.scale, op.scale),
+                zero_point=(
+                    temp_map.get(op.zero_point, op.zero_point)
+                    if op.zero_point is not None
+                    else None
+                ),
+                output=temp_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                axis=op.axis,
+                block_size=op.block_size,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+                scale_dtype=op.scale_dtype,
+            )
+        if isinstance(op, QLinearMulOp):
+            return QLinearMulOp(
+                input0=temp_map.get(op.input0, op.input0),
+                input0_scale=temp_map.get(op.input0_scale, op.input0_scale),
+                input0_zero_point=temp_map.get(
+                    op.input0_zero_point, op.input0_zero_point
+                ),
+                input1=temp_map.get(op.input1, op.input1),
+                input1_scale=temp_map.get(op.input1_scale, op.input1_scale),
+                input1_zero_point=temp_map.get(
+                    op.input1_zero_point, op.input1_zero_point
+                ),
+                output_scale=temp_map.get(op.output_scale, op.output_scale),
+                output_zero_point=temp_map.get(
+                    op.output_zero_point, op.output_zero_point
+                ),
+                output=temp_map.get(op.output, op.output),
+                input0_shape=op.input0_shape,
+                input1_shape=op.input1_shape,
+                output_shape=op.output_shape,
+                input0_dtype=op.input0_dtype,
+                input1_dtype=op.input1_dtype,
+                dtype=op.dtype,
+                input0_scale_dtype=op.input0_scale_dtype,
+                input1_scale_dtype=op.input1_scale_dtype,
+                output_scale_dtype=op.output_scale_dtype,
+                input0_scale_shape=op.input0_scale_shape,
+                input1_scale_shape=op.input1_scale_shape,
+                output_scale_shape=op.output_scale_shape,
+                input0_zero_shape=op.input0_zero_shape,
+                input1_zero_shape=op.input1_zero_shape,
+                output_zero_shape=op.output_zero_shape,
+            )
         if isinstance(op, QLinearMatMulOp):
             return QLinearMatMulOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4023,15 +4502,10 @@ class CEmitter:
                 else None
                 ),
                 output=temp_map.get(op.output, op.output),
-                m=op.m,
-                n=op.n,
-                k=op.k,
                 trans_a=op.trans_a,
                 trans_b=op.trans_b,
                 alpha=op.alpha,
                 beta=op.beta,
-                c_shape=op.c_shape,
-                dtype=op.dtype,
             )
         if isinstance(op, AttentionOp):
             return AttentionOp(
@@ -4133,6 +4607,51 @@ class CEmitter:
                 input_rank=op.input_rank,
                 interleaved=op.interleaved,
             )
+        if isinstance(op, GruOp):
+            return GruOp(
+                input_x=temp_map.get(op.input_x, op.input_x),
+                input_w=temp_map.get(op.input_w, op.input_w),
+                input_r=temp_map.get(op.input_r, op.input_r),
+                input_b=(
+                    temp_map.get(op.input_b, op.input_b)
+                    if op.input_b is not None
+                    else None
+                ),
+                input_sequence_lens=(
+                    temp_map.get(op.input_sequence_lens, op.input_sequence_lens)
+                    if op.input_sequence_lens is not None
+                    else None
+                ),
+                input_initial_h=(
+                    temp_map.get(op.input_initial_h, op.input_initial_h)
+                    if op.input_initial_h is not None
+                    else None
+                ),
+                output_y=(
+                    temp_map.get(op.output_y, op.output_y)
+                    if op.output_y is not None
+                    else None
+                ),
+                output_y_h=(
+                    temp_map.get(op.output_y_h, op.output_y_h)
+                    if op.output_y_h is not None
+                    else None
+                ),
+                seq_length=op.seq_length,
+                batch_size=op.batch_size,
+                input_size=op.input_size,
+                hidden_size=op.hidden_size,
+                num_directions=op.num_directions,
+                direction=op.direction,
+                layout=op.layout,
+                linear_before_reset=op.linear_before_reset,
+                clip=op.clip,
+                activation_kinds=op.activation_kinds,
+                activation_alphas=op.activation_alphas,
+                activation_betas=op.activation_betas,
+                dtype=op.dtype,
+                sequence_lens_dtype=op.sequence_lens_dtype,
+            )
         if isinstance(op, LstmOp):
             return LstmOp(
                 input_x=temp_map.get(op.input_x, op.input_x),
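Note: the new GruOp resolve branch above carries direction, layout, `linear_before_reset` and activation metadata through unchanged. The per-step recurrence the generated kernel has to realize is the standard ONNX GRU; a scalar sketch with default activations (sigmoid gates, tanh candidate), where `linear_before_reset` controls whether the recurrence bias sits inside the reset gate:

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gru_cell(x, h_prev, w, r, b, linear_before_reset=False):
    # w = (wz, wr, wh), r = (rz, rr, rh), b = (bwz, bwr, bwh, brz, brr, brh)
    wz, wr, wh = w
    rz, rr, rh = r
    bwz, bwr, bwh, brz, brr, brh = b
    z = sigmoid(wz * x + rz * h_prev + bwz + brz)   # update gate
    rg = sigmoid(wr * x + rr * h_prev + bwr + brr)  # reset gate
    if linear_before_reset:
        h_tilde = math.tanh(wh * x + bwh + rg * (rh * h_prev + brh))
    else:
        h_tilde = math.tanh(wh * x + rh * (rg * h_prev) + bwh + brh)
    return (1.0 - z) * h_tilde + z * h_prev

print(gru_cell(1.0, 0.0, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (0.0,) * 6))
```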
@@ -4239,6 +4758,35 @@ class CEmitter:
                 group=op.group,
                 dtype=op.dtype,
             )
+        if isinstance(op, ConvIntegerOp):
+            return ConvIntegerOp(
+                input0=temp_map.get(op.input0, op.input0),
+                weights=temp_map.get(op.weights, op.weights),
+                x_zero_point=temp_map.get(op.x_zero_point, op.x_zero_point)
+                if op.x_zero_point
+                else None,
+                w_zero_point=temp_map.get(op.w_zero_point, op.w_zero_point)
+                if op.w_zero_point
+                else None,
+                output=temp_map.get(op.output, op.output),
+                batch=op.batch,
+                in_channels=op.in_channels,
+                out_channels=op.out_channels,
+                spatial_rank=op.spatial_rank,
+                in_spatial=op.in_spatial,
+                out_spatial=op.out_spatial,
+                kernel_shape=op.kernel_shape,
+                strides=op.strides,
+                pads=op.pads,
+                dilations=op.dilations,
+                group=op.group,
+                input_dtype=op.input_dtype,
+                weight_dtype=op.weight_dtype,
+                dtype=op.dtype,
+                x_zero_point_shape=op.x_zero_point_shape,
+                w_zero_point_shape=op.w_zero_point_shape,
+                w_zero_point_per_channel=op.w_zero_point_per_channel,
+            )
         if isinstance(op, ConvTransposeOp):
             return ConvTransposeOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4265,16 +4813,26 @@ class CEmitter:
                 output=temp_map.get(op.output, op.output),
                 batch=op.batch,
                 channels=op.channels,
+                spatial_rank=op.spatial_rank,
+                in_d=op.in_d,
                 in_h=op.in_h,
                 in_w=op.in_w,
+                out_d=op.out_d,
                 out_h=op.out_h,
                 out_w=op.out_w,
+                kernel_d=op.kernel_d,
                 kernel_h=op.kernel_h,
                 kernel_w=op.kernel_w,
+                dilation_d=op.dilation_d,
+                dilation_h=op.dilation_h,
+                dilation_w=op.dilation_w,
+                stride_d=op.stride_d,
                 stride_h=op.stride_h,
                 stride_w=op.stride_w,
+                pad_front=op.pad_front,
                 pad_top=op.pad_top,
                 pad_left=op.pad_left,
+                pad_back=op.pad_back,
                 pad_bottom=op.pad_bottom,
                 pad_right=op.pad_right,
                 count_include_pad=op.count_include_pad,
@@ -4292,6 +4850,8 @@ class CEmitter:
                 out_w=op.out_w,
                 kernel_h=op.kernel_h,
                 kernel_w=op.kernel_w,
+                dilation_h=op.dilation_h,
+                dilation_w=op.dilation_w,
                 stride_h=op.stride_h,
                 stride_w=op.stride_w,
                 pad_top=op.pad_top,
@@ -4420,34 +4980,19 @@ class CEmitter:
             return SoftmaxOp(
                 input0=temp_map.get(op.input0, op.input0),
                 output=temp_map.get(op.output, op.output),
-                outer=op.outer,
-                axis_size=op.axis_size,
-                inner=op.inner,
                 axis=op.axis,
-                shape=op.shape,
-                dtype=op.dtype,
             )
         if isinstance(op, LogSoftmaxOp):
             return LogSoftmaxOp(
                 input0=temp_map.get(op.input0, op.input0),
                 output=temp_map.get(op.output, op.output),
-                outer=op.outer,
-                axis_size=op.axis_size,
-                inner=op.inner,
                 axis=op.axis,
-                shape=op.shape,
-                dtype=op.dtype,
             )
         if isinstance(op, HardmaxOp):
             return HardmaxOp(
                 input0=temp_map.get(op.input0, op.input0),
                 output=temp_map.get(op.output, op.output),
-                outer=op.outer,
-                axis_size=op.axis_size,
-                inner=op.inner,
                 axis=op.axis,
-                shape=op.shape,
-                dtype=op.dtype,
             )
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return NegativeLogLikelihoodLossOp(
@@ -4629,6 +5174,11 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, OptionalHasElementOp):
+            return OptionalHasElementOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output=temp_map.get(op.output, op.output),
+            )
         if isinstance(op, NonZeroOp):
             return NonZeroOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4695,6 +5245,15 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, HammingWindowOp):
+            return HammingWindowOp(
+                size=temp_map.get(op.size, op.size),
+                output=temp_map.get(op.output, op.output),
+                output_shape=op.output_shape,
+                periodic=op.periodic,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, OneHotOp):
             return OneHotOp(
                 indices=temp_map.get(op.indices, op.indices),
@@ -4710,6 +5269,23 @@ class CEmitter:
                 indices_dtype=op.indices_dtype,
                 depth_dtype=op.depth_dtype,
             )
+        if isinstance(op, TfIdfVectorizerOp):
+            return TfIdfVectorizerOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output=temp_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                input_dtype=op.input_dtype,
+                output_dtype=op.output_dtype,
+                min_gram_length=op.min_gram_length,
+                max_gram_length=op.max_gram_length,
+                max_skip_count=op.max_skip_count,
+                mode=op.mode,
+                ngram_counts=op.ngram_counts,
+                ngram_indexes=op.ngram_indexes,
+                pool_int64s=op.pool_int64s,
+                weights=op.weights,
+            )
         if isinstance(op, SplitOp):
             return SplitOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4746,9 +5322,6 @@ class CEmitter:
             return IdentityOp(
                 input0=temp_map.get(op.input0, op.input0),
                 output=temp_map.get(op.output, op.output),
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
             )
         if isinstance(op, EyeLikeOp):
             return EyeLikeOp(
@@ -4934,54 +5507,50 @@ class CEmitter:
             return ReduceOp(
                 input0=temp_map.get(op.input0, op.input0),
                 output=temp_map.get(op.output, op.output),
-                input_shape=op.input_shape,
-                output_shape=op.output_shape,
                 axes=op.axes,
                 axes_input=temp_map.get(op.axes_input, op.axes_input)
                 if op.axes_input
                 else None,
-                axes_input_shape=op.axes_input_shape,
-                axes_input_dtype=op.axes_input_dtype,
                 keepdims=op.keepdims,
                 noop_with_empty_axes=op.noop_with_empty_axes,
                 reduce_kind=op.reduce_kind,
                 reduce_count=op.reduce_count,
-                dtype=op.dtype,
             )
         if isinstance(op, ArgReduceOp):
             return ArgReduceOp(
                 input0=temp_map.get(op.input0, op.input0),
                 output=temp_map.get(op.output, op.output),
-                input_shape=op.input_shape,
-                output_shape=op.output_shape,
                 axis=op.axis,
                 keepdims=op.keepdims,
                 select_last_index=op.select_last_index,
                 reduce_kind=op.reduce_kind,
-                input_dtype=op.input_dtype,
-                output_dtype=op.output_dtype,
             )
         if isinstance(op, TopKOp):
             return TopKOp(
                 input0=temp_map.get(op.input0, op.input0),
+                k_input=temp_map.get(op.k_input, op.k_input),
                 output_values=temp_map.get(op.output_values, op.output_values),
                 output_indices=temp_map.get(op.output_indices, op.output_indices),
-                input_shape=op.input_shape,
-                output_shape=op.output_shape,
                 axis=op.axis,
                 k=op.k,
                 largest=op.largest,
                 sorted=op.sorted,
+            )
+        if isinstance(op, BernoulliOp):
+            return BernoulliOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output=temp_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
                 input_dtype=op.input_dtype,
-
-
+                dtype=op.dtype,
+                seed=op.seed,
             )
         return UnaryOp(
             input0=temp_map.get(op.input0, op.input0),
             output=temp_map.get(op.output, op.output),
             function=op.function,
-
-            dtype=op.dtype,
+            params=op.params,
         )

     def render_op(self, op: OpBase, ctx: EmitContext) -> str:
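Note: the new TfIdfVectorizerOp above stores its n-gram pool flat: `pool_int64s` segmented by `ngram_counts`, with `ngram_indexes` mapping each pool entry to an output slot. A deliberately reduced sketch for 1-grams in "TF" mode, ignoring skips, higher-order grams, and weights:

```python
# Each matching pool entry bumps the output slot that ngram_indexes assigns
# to it; "TF" mode means raw counts, with "IDF"/"TFIDF" applying weights.
def tf_1grams(tokens, pool, indexes, output_size):
    out = [0.0] * output_size
    for tok in tokens:
        for pos, gram in enumerate(pool):
            if gram == tok:
                out[indexes[pos]] += 1.0
    return out

print(tf_1grams([1, 2, 2, 5], pool=[1, 2, 3], indexes=[0, 1, 2], output_size=3))
# [1.0, 2.0, 0.0]
```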
@@ -5007,6 +5576,8 @@ class CEmitter:
             clip_template=templates["clip"],
             cast_template=templates["cast"],
             quantize_linear_template=templates["quantize_linear"],
+            dequantize_linear_template=templates["dequantize_linear"],
+            qlinear_mul_template=templates["qlinear_mul"],
             qlinear_matmul_template=templates["qlinear_matmul"],
             matmul_template=templates["matmul"],
             einsum_template=templates["einsum"],
@@ -5014,6 +5585,7 @@ class CEmitter:
             attention_template=templates["attention"],
             rotary_embedding_template=templates["rotary_embedding"],
             conv_template=templates["conv"],
+            conv_integer_template=templates["conv_integer"],
             conv_transpose_template=templates["conv_transpose"],
             avg_pool_template=templates["avg_pool"],
             lp_pool_template=templates["lp_pool"],
@@ -5025,6 +5597,7 @@ class CEmitter:
             mean_variance_norm_template=templates["mean_variance_norm"],
             rms_norm_template=templates["rms_norm"],
             lrn_template=templates["lrn"],
+            gru_template=templates["gru"],
             lstm_template=templates["lstm"],
             adagrad_template=templates["adagrad"],
             softmax_template=templates["softmax"],
@@ -5043,6 +5616,7 @@ class CEmitter:
             transpose_template=templates["transpose"],
             reshape_template=templates["reshape"],
             identity_template=templates["identity"],
+            bernoulli_template=templates["bernoulli"],
             eye_like_template=templates["eye_like"],
             trilu_template=templates["trilu"],
             tile_template=templates["tile"],
@@ -5060,12 +5634,15 @@ class CEmitter:
             constant_of_shape_template=templates["constant_of_shape"],
             shape_template=templates["shape"],
             size_template=templates["size"],
+            optional_has_element_template=templates["optional_has_element"],
             nonzero_template=templates["nonzero"],
             nonmax_suppression_template=templates["nonmax_suppression"],
             expand_template=templates["expand"],
             cumsum_template=templates["cumsum"],
             range_template=templates["range"],
+            hamming_window_template=templates["hamming_window"],
             one_hot_template=templates["one_hot"],
+            tfidf_vectorizer_template=templates["tfidf_vectorizer"],
             split_template=templates["split"],
             scalar_registry=state.scalar_registry,
             dim_args=state.dim_args,
@@ -5091,6 +5668,8 @@ class CEmitter:
         clip_template,
         cast_template,
         quantize_linear_template,
+        dequantize_linear_template,
+        qlinear_mul_template,
         qlinear_matmul_template,
         matmul_template,
         einsum_template,
@@ -5098,6 +5677,7 @@ class CEmitter:
         attention_template,
         rotary_embedding_template,
         conv_template,
+        conv_integer_template,
         conv_transpose_template,
         avg_pool_template,
         lp_pool_template,
@@ -5109,6 +5689,7 @@ class CEmitter:
         mean_variance_norm_template,
         rms_norm_template,
         lrn_template,
+        gru_template,
         lstm_template,
         adagrad_template,
         softmax_template,
@@ -5125,6 +5706,7 @@ class CEmitter:
         transpose_template,
         reshape_template,
         identity_template,
+        bernoulli_template,
         eye_like_template,
         trilu_template,
         tile_template,
@@ -5142,12 +5724,15 @@ class CEmitter:
         constant_of_shape_template,
         shape_template,
         size_template,
+        optional_has_element_template,
         nonzero_template,
         nonmax_suppression_template,
         expand_template,
         cumsum_template,
         range_template,
+        hamming_window_template,
         one_hot_template,
+        tfidf_vectorizer_template,
         split_template,
         scalar_registry: ScalarFunctionRegistry | None = None,
         dim_args: str = "",
@@ -5169,6 +5754,11 @@ class CEmitter:
             input1_shape = self._ctx_shape(op.input1)
             output_shape = self._ctx_shape(op.output)
             input_dtype = self._ctx_dtype(op.input0)
+            input1_dtype = (
+                self._ctx_dtype(op.input1)
+                if isinstance(op, PowOp)
+                else input_dtype
+            )
             output_dtype = self._ctx_dtype(op.output)
             params = self._shared_param_map(
                 [
@@ -5207,11 +5797,12 @@ class CEmitter:
                 input1_shape, _dim_names_for(op.input1)
             )
             input_c_type = input_dtype.c_type
+            input1_c_type = input1_dtype.c_type
             output_c_type = output_dtype.c_type
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], input_c_type, input0_suffix, True),
-                    (params["input1"],
+                    (params["input1"], input1_c_type, input1_suffix, True),
                     (params["output"], output_c_type, output_suffix, False),
                 ]
             )
@@ -5234,12 +5825,20 @@ class CEmitter:
                 output_shape,
                 loop_vars,
             )
-
-
-
-
-
-
+            prelu_axis = None
+            if op.function == ScalarFunction.PRELU:
+                derived_axis = self._maybe_derived(op, "prelu_slope_axis")
+                if isinstance(derived_axis, int):
+                    prelu_axis = derived_axis
+            if prelu_axis is None:
+                right_expr = CEmitter._broadcast_index_expr(
+                    params["input1"],
+                    input1_shape,
+                    output_shape,
+                    loop_vars,
+                )
+            else:
+                right_expr = f"{params['input1']}[{loop_vars[prelu_axis]}]"
             operator_expr = None
             operator = op_spec.operator
             operator_kind = op.operator_kind
@@ -5263,7 +5862,7 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, MultiInputBinaryOp):
-
+            output_shape_raw = self._ctx_shape(op.output)
             input_dtype = self._ctx_dtype(op.inputs[0])
             output_dtype = self._ctx_dtype(op.output)
             params = self._shared_param_map(
@@ -5292,27 +5891,47 @@ class CEmitter:
                 f"{op.function.value}"
             )
             output_dim_names = _dim_names_for(op.output)
-            shape = CEmitter._shape_dim_exprs(
-
-
-
+            shape = CEmitter._shape_dim_exprs(
+                output_shape_raw, output_dim_names
+            )
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
+            output_array_suffix = self._param_array_suffix(
+                output_shape_raw, output_dim_names
             )
             input_c_type = input_dtype.c_type
             output_c_type = output_dtype.c_type
             input_names = [
                 params[f"input{idx}"] for idx in range(len(op.inputs))
             ]
+            input_shapes = [self._ctx_shape(name) for name in op.inputs]
+            input_dim_names = [
+                _dim_names_for(name) for name in op.inputs
+            ]
+            input_array_suffixes = [
+                self._param_array_suffix(shape, dim_names)
+                for shape, dim_names in zip(input_shapes, input_dim_names)
+            ]
             param_decls = self._build_param_decls(
                 [
-                    *(
-
+                    *(
+                        (name, input_c_type, array_suffix, True)
+                        for name, array_suffix in zip(
+                            input_names, input_array_suffixes
+                        )
+                    ),
+                    (
+                        params["output"],
+                        output_c_type,
+                        output_array_suffix,
+                        False,
+                    ),
                 ]
             )
             common = {
                 "model_name": model.name,
                 "op_name": op_name,
                 "element_count": CEmitter._element_count_expr(shape),
-                "array_suffix":
+                "array_suffix": output_array_suffix,
                 "shape": shape,
                 "loop_vars": loop_vars,
                 "input_c_type": input_c_type,
@@ -5322,8 +5941,10 @@ class CEmitter:
                 "params": param_decls,
             }
             input_exprs = [
-
-
+                CEmitter._broadcast_index_expr(
+                    name, shape, output_shape_raw, loop_vars
+                )
+                for name, shape in zip(input_names, input_shapes)
             ]
             output_expr = f"{params['output']}" + "".join(
                 f"[{var}]" for var in loop_vars
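Note: the PReLU branch above swaps the generic broadcast index expression for a single-axis lookup when a derived `prelu_slope_axis` is available, so a per-channel slope tensor is read as `input1[loop_vars[axis]]` in the generated C. The equivalent computation in NumPy terms, with illustrative shapes:

```python
import numpy as np

def prelu_channel(x: np.ndarray, slope: np.ndarray, axis: int) -> np.ndarray:
    # Reshape the 1-D slope so it only varies along `axis`, matching the
    # single-loop-variable index the emitter produces for this case.
    shape = [1] * x.ndim
    shape[axis] = slope.shape[0]
    s = slope.reshape(shape)
    return np.where(x >= 0, x, x * s)

x = np.array([[1.0, -2.0], [-3.0, 4.0]])
print(prelu_channel(x, np.array([0.1, 0.2]), axis=1))
# [[ 1.  -0.4]
#  [-0.3  4. ]]
```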
@@ -5452,37 +6073,51 @@ class CEmitter:
                     ("output", op.output),
                 ]
             )
-            output_shape = CEmitter._codegen_shape(op.
+            output_shape = CEmitter._codegen_shape(self._ctx_shape(op.output))
             output_loop_vars = CEmitter._loop_vars(output_shape)
             output_index_expr = f"{params['output']}" + "".join(
                 f"[{var}]" for var in output_loop_vars
             )
-
+            batch_shape = self._derived(op, "batch_shape")
+            batch_rank = len(batch_shape)
             batch_vars = output_loop_vars[:batch_rank]
-
+            left_vector = bool(self._derived(op, "left_vector"))
+            right_vector = bool(self._derived(op, "right_vector"))
+            if left_vector and right_vector:
                 row_var = None
                 col_var = None
-            elif
+            elif left_vector:
                 row_var = None
                 col_var = output_loop_vars[-1]
-            elif
+            elif right_vector:
                 row_var = output_loop_vars[-1]
                 col_var = None
             else:
                 row_var = output_loop_vars[-2]
                 col_var = output_loop_vars[-1]
+            input0_shape = self._ctx_shape(op.input0)
+            input1_shape = self._ctx_shape(op.input1)
+            input0_batch_shape = self._derived(op, "input0_batch_shape")
+            input1_batch_shape = self._derived(op, "input1_batch_shape")
             input0_index_expr, input1_index_expr = CEmitter._matmul_index_exprs(
-                op,
                 batch_vars,
                 row_var,
                 col_var,
                 batch_rank,
                 input0=params["input0"],
                 input1=params["input1"],
-
-
-
-
+                left_vector=left_vector,
+                right_vector=right_vector,
+                input0_shape=input0_shape,
+                input1_shape=input1_shape,
+                input0_batch_shape=input0_batch_shape,
+                input1_batch_shape=input1_batch_shape,
+            )
+            input0_suffix = self._param_array_suffix(input0_shape)
+            input1_suffix = self._param_array_suffix(input1_shape)
+            output_suffix = self._param_array_suffix(self._ctx_shape(op.output))
+            acc_dtype = self._accumulation_dtype(self._ctx_dtype(op.output))
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], c_type, input0_suffix, True),
@@ -5490,6 +6125,9 @@ class CEmitter:
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
+            m = int(self._derived(op, "m"))
+            n = int(self._derived(op, "n"))
+            k = int(self._derived(op, "k"))
             rendered = matmul_template.render(
                 model_name=model.name,
                 op_name=op_name,
@@ -5498,8 +6136,8 @@ class CEmitter:
                 output=params["output"],
                 params=param_decls,
                 c_type=c_type,
-                acc_type=c_type,
-                zero_literal=
+                acc_type=acc_dtype.c_type,
+                zero_literal=acc_zero_literal,
                 input0_suffix=input0_suffix,
                 input1_suffix=input1_suffix,
                 output_suffix=output_suffix,
@@ -5508,9 +6146,9 @@ class CEmitter:
                 output_index_expr=output_index_expr,
                 input0_index_expr=input0_index_expr,
                 input1_index_expr=input1_index_expr,
-                m=
-                n=
-                k=
+                m=m,
+                n=n,
+                k=k,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, EinsumOp):
@@ -5561,6 +6199,8 @@ class CEmitter:
                     ),
                 ]
             )
+            acc_dtype = self._accumulation_dtype(self._ctx_dtype(op.output))
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
             input_loop_vars: tuple[str, ...] = ()
             input_loop_bounds: tuple[str | int, ...] = ()
             reduce_loop_var = "k"
@@ -5633,8 +6273,8 @@ class CEmitter:
                 output_loop_vars=output_loop_vars,
                 output_loop_bounds=output_shape,
                 output_expr=output_expr,
-                acc_type=
-                zero_literal=
+                acc_type=acc_dtype.c_type,
+                zero_literal=acc_zero_literal,
                 input_loop_vars=input_loop_vars,
                 input_loop_bounds=input_loop_bounds,
                 reduce_loop_var=reduce_loop_var,
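Note: the MatMul and Einsum render paths above now take `acc_type` from `self._accumulation_dtype(...)` rather than reusing the output C type. Judging from the LayerNormalization hunk later in this diff, which replaces an inline F16/F32-to-F32 mapping with the same helper, the intent appears to be widening reductions; that reading is an inference, sketched as:

```python
# Assumed behavior of _accumulation_dtype based on the inline logic it
# replaces: half (and single) precision data accumulates in F32, integer
# and double types keep their own accumulator type.
def accumulation_dtype(dtype: str) -> str:
    return "F32" if dtype in {"F16", "F32"} else dtype

assert accumulation_dtype("F16") == "F32"  # half data, float accumulator
assert accumulation_dtype("I8") == "I8"
```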
@@ -5653,14 +6293,20 @@ class CEmitter:
                     ("output", op.output),
                 ]
             )
-
-
+            m = int(self._derived(op, "m"))
+            n = int(self._derived(op, "n"))
+            k = int(self._derived(op, "k"))
+            trans_a = bool(self._derived(op, "trans_a"))
+            trans_b = bool(self._derived(op, "trans_b"))
+            c_shape = self._derived(op, "c_shape")
+            input_a_shape = (k, m) if trans_a else (m, k)
+            input_b_shape = (n, k) if trans_b else (k, n)
             input_a_suffix = self._param_array_suffix(input_a_shape)
             input_b_suffix = self._param_array_suffix(input_b_shape)
-            output_suffix = self._param_array_suffix((
+            output_suffix = self._param_array_suffix((m, n))
             c_suffix = (
-                self._param_array_suffix(
-                if
+                self._param_array_suffix(c_shape)
+                if c_shape is not None
                 else ""
             )
             param_decls = self._build_param_decls(
@@ -5678,24 +6324,31 @@ class CEmitter:
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
-
-
-
+            dtype = self._ctx_dtype(op.output)
+            alpha_literal = CEmitter._format_literal(
+                dtype, self._derived(op, "alpha")
+            )
+            beta_literal = CEmitter._format_literal(
+                dtype, self._derived(op, "beta")
+            )
+            acc_dtype = self._accumulation_dtype(dtype)
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
+            if c_shape is None:
                 c_rank = 0
                 c_dim0 = 0
                 c_dim1 = 0
-            elif len(
+            elif len(c_shape) == 0:
                 c_rank = 0
                 c_dim0 = 0
                 c_dim1 = 0
-            elif len(
+            elif len(c_shape) == 1:
                 c_rank = 1
                 c_dim0 = 1
-                c_dim1 =
+                c_dim1 = c_shape[0]
             else:
                 c_rank = 2
-                c_dim0 =
-                c_dim1 =
+                c_dim0 = c_shape[0]
+                c_dim1 = c_shape[1]
             rendered = gemm_template.render(
                 model_name=model.name,
                 op_name=op_name,
@@ -5704,21 +6357,21 @@ class CEmitter:
                 input_c=params["input_c"],
                 output=params["output"],
                 params=param_decls,
-                c_type=c_type,
-                acc_type=c_type,
-                zero_literal=
+                c_type=dtype.c_type,
+                acc_type=acc_dtype.c_type,
+                zero_literal=acc_zero_literal,
                 alpha_literal=alpha_literal,
                 beta_literal=beta_literal,
-                trans_a=int(
-                trans_b=int(
-                m=
-                n=
-                k=
+                trans_a=int(trans_a),
+                trans_b=int(trans_b),
+                m=m,
+                n=n,
+                k=k,
                 input_a_suffix=input_a_suffix,
                 input_b_suffix=input_b_suffix,
                 output_suffix=output_suffix,
                 c_suffix=(
-                    c_suffix if
+                    c_suffix if c_shape is not None else None
                 ),
                 c_rank=c_rank,
                 c_dim0=c_dim0,
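Note: the Gemm path above now derives `m`, `n`, `k`, the transpose flags, and the C operand shape on demand instead of reading them off the op. The computation those parameters feed is the ONNX Gemm definition, Y = alpha * op(A) @ op(B) + beta * C, where op() applies the optional transpose:

```python
import numpy as np

def gemm(a, b, c=None, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    a = a.T if trans_a else a   # op(A): (m, k) after the optional transpose
    b = b.T if trans_b else b   # op(B): (k, n)
    y = alpha * (a @ b)
    if c is not None:
        y = y + beta * c        # C broadcasts over the (m, n) result
    return y

a = np.ones((2, 3))
b = np.ones((3, 4))
print(gemm(a, b, c=np.zeros((1, 4)), alpha=2.0).shape)  # (2, 4)
```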
@@ -6034,6 +6687,9 @@ class CEmitter:
                     ("output", op.output),
                 ]
             )
+            acc_dtype = self._accumulation_dtype(op.dtype)
+            acc_type = acc_dtype.c_type
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
             input_shape = (op.batch, op.in_channels, *op.in_spatial)
             weight_shape = (
                 op.out_channels,
@@ -6077,6 +6733,8 @@ class CEmitter:
                 output=params["output"],
                 params=param_decls,
                 c_type=c_type,
+                acc_type=acc_type,
+                acc_zero_literal=acc_zero_literal,
                 zero_literal=zero_literal,
                 input_suffix=input_suffix,
                 weight_suffix=weight_suffix,
@@ -6100,6 +6758,129 @@ class CEmitter:
                 in_indices=in_indices,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, ConvIntegerOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("weights", op.weights),
+                    ("x_zero_point", op.x_zero_point),
+                    ("w_zero_point", op.w_zero_point),
+                    ("output", op.output),
+                ]
+            )
+            acc_dtype = op.dtype
+            acc_type = acc_dtype.c_type
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
+            input_shape = (op.batch, op.in_channels, *op.in_spatial)
+            weight_shape = (
+                op.out_channels,
+                op.in_channels // op.group,
+                *op.kernel_shape,
+            )
+            output_shape = (op.batch, op.out_channels, *op.out_spatial)
+            out_indices = tuple(f"od{dim}" for dim in range(op.spatial_rank))
+            kernel_indices = tuple(
+                f"kd{dim}" for dim in range(op.spatial_rank)
+            )
+            in_indices = tuple(f"id{dim}" for dim in range(op.spatial_rank))
+            pad_begin = op.pads[: op.spatial_rank]
+            group_in_channels = op.in_channels // op.group
+            group_out_channels = op.out_channels // op.group
+            input_suffix = self._param_array_suffix(input_shape)
+            weight_suffix = self._param_array_suffix(weight_shape)
+            output_suffix = self._param_array_suffix(output_shape)
+            x_zero_suffix = (
+                self._param_array_suffix(op.x_zero_point_shape)
+                if op.x_zero_point_shape is not None
+                else ""
+            )
+            w_zero_suffix = (
+                self._param_array_suffix(op.w_zero_point_shape)
+                if op.w_zero_point_shape is not None
+                else ""
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (
+                        params["input0"],
+                        op.input_dtype.c_type,
+                        input_suffix,
+                        True,
+                    ),
+                    (
+                        params["weights"],
+                        op.weight_dtype.c_type,
+                        weight_suffix,
+                        True,
+                    ),
+                    (
+                        params["x_zero_point"],
+                        op.input_dtype.c_type,
+                        x_zero_suffix,
+                        True,
+                    )
+                    if params["x_zero_point"]
+                    else (None, "", "", True),
+                    (
+                        params["w_zero_point"],
+                        op.weight_dtype.c_type,
+                        w_zero_suffix,
+                        True,
+                    )
+                    if params["w_zero_point"]
+                    else (None, "", "", True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            x_zero_expr = (
+                f"{params['x_zero_point']}[0]"
+                if params["x_zero_point"]
+                else "0"
+            )
+            if params["w_zero_point"]:
+                if op.w_zero_point_per_channel:
+                    w_zero_expr = f"{params['w_zero_point']}[oc_global]"
+                else:
+                    w_zero_expr = f"{params['w_zero_point']}[0]"
+            else:
+                w_zero_expr = "0"
+            rendered = conv_integer_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                weights=params["weights"],
+                x_zero_point=params["x_zero_point"],
+                w_zero_point=params["w_zero_point"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                acc_type=acc_type,
+                acc_zero_literal=acc_zero_literal,
+                input_suffix=input_suffix,
+                weight_suffix=weight_suffix,
+                x_zero_suffix=x_zero_suffix,
+                w_zero_suffix=w_zero_suffix,
+                output_suffix=output_suffix,
+                batch=op.batch,
+                in_channels=op.in_channels,
+                out_channels=op.out_channels,
+                spatial_rank=op.spatial_rank,
+                in_spatial=op.in_spatial,
+                out_spatial=op.out_spatial,
+                kernel_shape=op.kernel_shape,
+                strides=op.strides,
+                pads_begin=pad_begin,
+                dilations=op.dilations,
+                group=op.group,
+                group_in_channels=group_in_channels,
+                group_out_channels=group_out_channels,
+                out_indices=out_indices,
+                kernel_indices=kernel_indices,
+                in_indices=in_indices,
+                x_zero_expr=x_zero_expr,
+                w_zero_expr=w_zero_expr,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, ConvTransposeOp):
             params = self._shared_param_map(
                 [
@@ -6179,8 +6960,27 @@ class CEmitter:
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-
-
+            if op.spatial_rank == 3:
+                input_shape = (
+                    op.batch,
+                    op.channels,
+                    op.in_d,
+                    op.in_h,
+                    op.in_w,
+                )
+                output_shape = (
+                    op.batch,
+                    op.channels,
+                    op.out_d,
+                    op.out_h,
+                    op.out_w,
+                )
+            elif op.spatial_rank == 1:
+                input_shape = (op.batch, op.channels, op.in_w)
+                output_shape = (op.batch, op.channels, op.out_w)
+            else:
+                input_shape = (op.batch, op.channels, op.in_h, op.in_w)
+                output_shape = (op.batch, op.channels, op.out_h, op.out_w)
             input_suffix = self._param_array_suffix(input_shape)
             output_suffix = self._param_array_suffix(output_shape)
             param_decls = self._build_param_decls(
@@ -6201,16 +7001,26 @@ class CEmitter:
                 output_suffix=output_suffix,
                 batch=op.batch,
                 channels=op.channels,
+                spatial_rank=op.spatial_rank,
+                in_d=op.in_d,
                 in_h=op.in_h,
                 in_w=op.in_w,
+                out_d=op.out_d,
                 out_h=op.out_h,
                 out_w=op.out_w,
+                kernel_d=op.kernel_d,
                 kernel_h=op.kernel_h,
                 kernel_w=op.kernel_w,
+                dilation_d=op.dilation_d,
+                dilation_h=op.dilation_h,
+                dilation_w=op.dilation_w,
+                stride_d=op.stride_d,
                 stride_h=op.stride_h,
                 stride_w=op.stride_w,
+                pad_front=op.pad_front,
                 pad_top=op.pad_top,
                 pad_left=op.pad_left,
+                pad_back=op.pad_back,
                 pad_bottom=op.pad_bottom,
                 pad_right=op.pad_right,
                 count_include_pad=int(op.count_include_pad),
@@ -6247,6 +7057,8 @@ class CEmitter:
|
|
|
6247
7057
|
out_w=op.out_w,
|
|
6248
7058
|
kernel_h=op.kernel_h,
|
|
6249
7059
|
kernel_w=op.kernel_w,
|
|
7060
|
+
dilation_h=op.dilation_h,
|
|
7061
|
+
dilation_w=op.dilation_w,
|
|
6250
7062
|
stride_h=op.stride_h,
|
|
6251
7063
|
stride_w=op.stride_w,
|
|
6252
7064
|
pad_top=op.pad_top,
|
|
@@ -6431,11 +7243,7 @@ class CEmitter:
|
|
|
6431
7243
|
).rstrip()
|
|
6432
7244
|
return with_node_comment(rendered)
|
|
6433
7245
|
if isinstance(op, LayerNormalizationOp):
|
|
6434
|
-
acc_dtype = (
|
|
6435
|
-
ScalarType.F32
|
|
6436
|
-
if op.dtype in {ScalarType.F16, ScalarType.F32}
|
|
6437
|
-
else op.dtype
|
|
6438
|
-
)
|
|
7246
|
+
acc_dtype = self._accumulation_dtype(op.dtype)
|
|
6439
7247
|
acc_type = acc_dtype.c_type
|
|
6440
7248
|
acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
|
|
6441
7249
|
acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
|
|
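Several branches in this diff (LayerNormalization here, the loss ops further down) replace an inline accumulator-dtype expression with a shared self._accumulation_dtype helper whose body is not visible in these hunks. The removed inline versions disagreed on the widened type (F32 here, F64 in the loss ops), so the sketch below only illustrates the shape of such a policy under the assumption that low-precision floats are widened; the real mapping lives in c_emitter.py:

    from enum import Enum

    class ScalarType(Enum):  # stub standing in for the package's enum
        F16 = "f16"
        F32 = "f32"
        F64 = "f64"
        I32 = "i32"

    def accumulation_dtype(dtype: ScalarType) -> ScalarType:
        # Widen half/single floats for accumulation; leave others alone.
        if dtype in {ScalarType.F16, ScalarType.F32}:
            return ScalarType.F64
        return dtype

    assert accumulation_dtype(ScalarType.F16) is ScalarType.F64
    assert accumulation_dtype(ScalarType.I32) is ScalarType.I32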
@@ -6443,7 +7251,7 @@ class CEmitter:
                 op.epsilon, acc_dtype
             )
             acc_sqrt_fn = CEmitter._math_fn(acc_dtype, "sqrtf", "sqrt")
-            use_kahan =
+            use_kahan = False
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -6678,7 +7486,7 @@ class CEmitter:
                 pow_fn=CEmitter._math_fn(op.dtype, "powf", "pow"),
             ).rstrip()
             return with_node_comment(rendered)
-        if isinstance(op, LstmOp):
+        if isinstance(op, GruOp):
             params = self._shared_param_map(
                 [
                     ("input_x", op.input_x),
@@ -6687,11 +7495,8 @@ class CEmitter:
                     ("input_b", op.input_b),
                     ("input_sequence_lens", op.input_sequence_lens),
                     ("input_initial_h", op.input_initial_h),
-                    ("input_initial_c", op.input_initial_c),
-                    ("input_p", op.input_p),
                     ("output_y", op.output_y),
                     ("output_y_h", op.output_y_h),
-                    ("output_y_c", op.output_y_c),
                 ]
             )
             input_x_shape = (
@@ -6699,14 +7504,16 @@ class CEmitter:
                 if op.layout == 0
                 else (op.batch_size, op.seq_length, op.input_size)
             )
-            w_shape = (op.num_directions, 4 * op.hidden_size, op.input_size)
-            r_shape = (op.num_directions, 4 * op.hidden_size, op.hidden_size)
+            w_shape = (op.num_directions, 3 * op.hidden_size, op.input_size)
+            r_shape = (op.num_directions, 3 * op.hidden_size, op.hidden_size)
             b_shape = (
-                (op.num_directions, 8 * op.hidden_size)
+                (op.num_directions, 6 * op.hidden_size)
                 if op.input_b is not None
                 else None
             )
-            seq_shape = (op.batch_size,) if op.input_sequence_lens is not None else None
+            seq_shape = (
+                (op.batch_size,) if op.input_sequence_lens is not None else None
+            )
             state_shape = (
                 (op.num_directions, op.batch_size, op.hidden_size)
                 if op.layout == 0
@@ -6717,16 +7524,6 @@ class CEmitter:
                 if op.input_initial_h is not None or op.output_y_h is not None
                 else None
             )
-            c_shape = (
-                state_shape
-                if op.input_initial_c is not None or op.output_y_c is not None
-                else None
-            )
-            p_shape = (
-                (op.num_directions, 3 * op.hidden_size)
-                if op.input_p is not None
-                else None
-            )
             y_shape = (
                 (op.seq_length, op.num_directions, op.batch_size, op.hidden_size)
                 if op.layout == 0
@@ -6776,22 +7573,6 @@ class CEmitter:
                     )
                     if params["input_initial_h"]
                     else (None, "", "", True),
-                    (
-                        params["input_initial_c"],
-                        c_type,
-                        self._param_array_suffix(c_shape),
-                        True,
-                    )
-                    if params["input_initial_c"]
-                    else (None, "", "", True),
-                    (
-                        params["input_p"],
-                        c_type,
-                        self._param_array_suffix(p_shape),
-                        True,
-                    )
-                    if params["input_p"]
-                    else (None, "", "", True),
                     (
                         params["output_y"],
                         c_type,
@@ -6808,22 +7589,14 @@ class CEmitter:
                     )
                     if params["output_y_h"]
                     else (None, "", "", False),
-                    (
-                        params["output_y_c"],
-                        c_type,
-                        self._param_array_suffix(c_shape),
-                        False,
-                    )
-                    if params["output_y_c"]
-                    else (None, "", "", False),
                 ]
             )
             if scalar_registry is None:
                 raise CodegenError(
-                    "Scalar function registry is required for LSTM codegen."
+                    "Scalar function registry is required for GRU codegen."
                 )
             activation_functions = tuple(
-                self.
+                self._rnn_activation_function_name(
                     kind,
                     alpha,
                     beta,
@@ -6836,7 +7609,7 @@ class CEmitter:
                     op.activation_betas,
                 )
             )
-            rendered = lstm_template.render(
+            rendered = gru_template.render(
                 model_name=model.name,
                 op_name=op_name,
                 input_x=params["input_x"],
@@ -6845,11 +7618,8 @@ class CEmitter:
                 input_b=params["input_b"],
                 input_sequence_lens=params["input_sequence_lens"],
                 input_initial_h=params["input_initial_h"],
-                input_initial_c=params["input_initial_c"],
-                input_p=params["input_p"],
                 output_y=params["output_y"],
                 output_y_h=params["output_y_h"],
-                output_y_c=params["output_y_c"],
                 params=param_decls,
                 c_type=c_type,
                 seq_c_type=(op.sequence_lens_dtype or ScalarType.I64).c_type,
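The GRU branch above sizes W as (num_directions, 3 * hidden_size, input_size) and B as (num_directions, 6 * hidden_size) because ONNX stacks the z, r, h gates in W/R and packs both bias halves (Wb, Rb) in B. For reference, the recurrence the gru_op template has to realize, with linear_before_reset selecting the candidate form:

    z_t = \sigma(W_z x_t + R_z h_{t-1} + W_{bz} + R_{bz})
    r_t = \sigma(W_r x_t + R_r h_{t-1} + W_{br} + R_{br})
    \tilde h_t = \tanh\big(W_h x_t + r_t \odot (R_h h_{t-1} + R_{bh}) + W_{bh}\big) \quad (\text{linear\_before\_reset} = 1)
    \tilde h_t = \tanh\big(W_h x_t + R_h (r_t \odot h_{t-1}) + R_{bh} + W_{bh}\big) \quad (\text{linear\_before\_reset} = 0)
    h_t = (1 - z_t) \odot \tilde h_t + z_t \odot h_{t-1}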
@@ -6868,38 +7638,232 @@ class CEmitter:
                 num_directions=op.num_directions,
                 layout=op.layout,
                 direction=op.direction,
-                input_forget=op.input_forget,
+                linear_before_reset=op.linear_before_reset,
                 activation_functions=activation_functions,
             ).rstrip()
             return with_node_comment(rendered)
-        if isinstance(op, AdagradOp):
+        if isinstance(op, LstmOp):
             params = self._shared_param_map(
                 [
-                    ("
-                    ("
-
-
-
-                    ),
-
-
-
-                    ),
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    ("input_x", op.input_x),
+                    ("input_w", op.input_w),
+                    ("input_r", op.input_r),
+                    ("input_b", op.input_b),
+                    ("input_sequence_lens", op.input_sequence_lens),
+                    ("input_initial_h", op.input_initial_h),
+                    ("input_initial_c", op.input_initial_c),
+                    ("input_p", op.input_p),
+                    ("output_y", op.output_y),
+                    ("output_y_h", op.output_y_h),
+                    ("output_y_c", op.output_y_c),
+                ]
+            )
+            input_x_shape = (
+                (op.seq_length, op.batch_size, op.input_size)
+                if op.layout == 0
+                else (op.batch_size, op.seq_length, op.input_size)
+            )
+            w_shape = (op.num_directions, 4 * op.hidden_size, op.input_size)
+            r_shape = (op.num_directions, 4 * op.hidden_size, op.hidden_size)
+            b_shape = (
+                (op.num_directions, 8 * op.hidden_size)
+                if op.input_b is not None
+                else None
+            )
+            seq_shape = (op.batch_size,) if op.input_sequence_lens is not None else None
+            state_shape = (
+                (op.num_directions, op.batch_size, op.hidden_size)
+                if op.layout == 0
+                else (op.batch_size, op.num_directions, op.hidden_size)
+            )
+            h_shape = (
+                state_shape
+                if op.input_initial_h is not None or op.output_y_h is not None
+                else None
+            )
+            c_shape = (
+                state_shape
+                if op.input_initial_c is not None or op.output_y_c is not None
+                else None
+            )
+            p_shape = (
+                (op.num_directions, 3 * op.hidden_size)
+                if op.input_p is not None
+                else None
+            )
+            y_shape = (
+                (op.seq_length, op.num_directions, op.batch_size, op.hidden_size)
+                if op.layout == 0
+                else (op.batch_size, op.seq_length, op.num_directions, op.hidden_size)
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (
+                        params["input_x"],
+                        c_type,
+                        self._param_array_suffix(input_x_shape),
+                        True,
+                    ),
+                    (
+                        params["input_w"],
+                        c_type,
+                        self._param_array_suffix(w_shape),
+                        True,
+                    ),
+                    (
+                        params["input_r"],
+                        c_type,
+                        self._param_array_suffix(r_shape),
+                        True,
+                    ),
+                    (
+                        params["input_b"],
+                        c_type,
+                        self._param_array_suffix(b_shape),
+                        True,
+                    )
+                    if params["input_b"]
+                    else (None, "", "", True),
+                    (
+                        params["input_sequence_lens"],
+                        (op.sequence_lens_dtype or ScalarType.I64).c_type,
+                        self._param_array_suffix(seq_shape),
+                        True,
+                    )
+                    if params["input_sequence_lens"]
+                    else (None, "", "", True),
+                    (
+                        params["input_initial_h"],
+                        c_type,
+                        self._param_array_suffix(h_shape),
+                        True,
+                    )
+                    if params["input_initial_h"]
+                    else (None, "", "", True),
+                    (
+                        params["input_initial_c"],
+                        c_type,
+                        self._param_array_suffix(c_shape),
+                        True,
+                    )
+                    if params["input_initial_c"]
+                    else (None, "", "", True),
+                    (
+                        params["input_p"],
+                        c_type,
+                        self._param_array_suffix(p_shape),
+                        True,
+                    )
+                    if params["input_p"]
+                    else (None, "", "", True),
+                    (
+                        params["output_y"],
+                        c_type,
+                        self._param_array_suffix(y_shape),
+                        False,
+                    )
+                    if params["output_y"]
+                    else (None, "", "", False),
+                    (
+                        params["output_y_h"],
+                        c_type,
+                        self._param_array_suffix(h_shape),
+                        False,
+                    )
+                    if params["output_y_h"]
+                    else (None, "", "", False),
+                    (
+                        params["output_y_c"],
+                        c_type,
+                        self._param_array_suffix(c_shape),
+                        False,
+                    )
+                    if params["output_y_c"]
+                    else (None, "", "", False),
+                ]
+            )
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for LSTM codegen."
+                )
+            activation_functions = tuple(
+                self._rnn_activation_function_name(
+                    kind,
+                    alpha,
+                    beta,
+                    op.dtype,
+                    scalar_registry,
+                )
+                for kind, alpha, beta in zip(
+                    op.activation_kinds,
+                    op.activation_alphas,
+                    op.activation_betas,
+                )
+            )
+            rendered = lstm_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input_x=params["input_x"],
+                input_w=params["input_w"],
+                input_r=params["input_r"],
+                input_b=params["input_b"],
+                input_sequence_lens=params["input_sequence_lens"],
+                input_initial_h=params["input_initial_h"],
+                input_initial_c=params["input_initial_c"],
+                input_p=params["input_p"],
+                output_y=params["output_y"],
+                output_y_h=params["output_y_h"],
+                output_y_c=params["output_y_c"],
+                params=param_decls,
+                c_type=c_type,
+                seq_c_type=(op.sequence_lens_dtype or ScalarType.I64).c_type,
+                zero_literal=zero_literal,
+                one_literal=CEmitter._format_literal(op.dtype, 1),
+                clip_literal=(
+                    CEmitter._format_floating(op.clip, op.dtype)
+                    if op.clip is not None
+                    else CEmitter._format_literal(op.dtype, 0)
+                ),
+                use_clip=int(op.clip is not None and op.clip > 0),
+                seq_length=op.seq_length,
+                batch_size=op.batch_size,
+                input_size=op.input_size,
+                hidden_size=op.hidden_size,
+                num_directions=op.num_directions,
+                layout=op.layout,
+                direction=op.direction,
+                input_forget=op.input_forget,
+                activation_functions=activation_functions,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, AdagradOp):
+            params = self._shared_param_map(
+                [
+                    ("rate", op.rate),
+                    ("timestep", op.timestep),
+                    *(
+                        (f"input{idx}", name)
+                        for idx, name in enumerate(op.inputs)
+                    ),
+                    *(
+                        (f"grad{idx}", name)
+                        for idx, name in enumerate(op.gradients)
+                    ),
+                    *(
+                        (f"acc{idx}", name)
+                        for idx, name in enumerate(op.accumulators)
+                    ),
+                    *(
+                        (f"output{idx}", name)
+                        for idx, name in enumerate(op.outputs)
+                    ),
+                    *(
+                        (f"acc_output{idx}", name)
+                        for idx, name in enumerate(op.accumulator_outputs)
+                    ),
+                ]
+            )
+            rate_suffix = self._param_array_suffix(
                 op.rate_shape, _dim_names_for(op.rate)
             )
             timestep_suffix = self._param_array_suffix(
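The new LstmOp branch follows the same pattern with four gates, hence the sizing above: W/R stack i, o, f, c (4 * hidden_size), B packs both bias halves (8 * hidden_size), and the peephole tensor P covers i, o, f only (3 * hidden_size). The recurrence the lstm_op template must realize:

    i_t = \sigma(W_i x_t + R_i h_{t-1} + P_i \odot c_{t-1} + W_{bi} + R_{bi})
    f_t = \sigma(W_f x_t + R_f h_{t-1} + P_f \odot c_{t-1} + W_{bf} + R_{bf})
    g_t = \tanh(W_g x_t + R_g h_{t-1} + W_{bg} + R_{bg})
    c_t = f_t \odot c_{t-1} + i_t \odot g_t
    o_t = \sigma(W_o x_t + R_o h_{t-1} + P_o \odot c_t + W_{bo} + R_{bo})
    h_t = o_t \odot \tanh(c_t)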
@@ -7112,11 +8076,7 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, NegativeLogLikelihoodLossOp):
-            acc_dtype = (
-                ScalarType.F64
-                if op.dtype in {ScalarType.F16, ScalarType.F32}
-                else op.dtype
-            )
+            acc_dtype = self._accumulation_dtype(op.dtype)
             acc_type = acc_dtype.c_type
             acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
             acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
@@ -7173,11 +8133,7 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, SoftmaxCrossEntropyLossOp):
-            acc_dtype = (
-                ScalarType.F64
-                if op.dtype in {ScalarType.F16, ScalarType.F32}
-                else op.dtype
-            )
+            acc_dtype = self._accumulation_dtype(op.dtype)
             if scalar_registry is None:
                 raise CodegenError(
                     "Scalar function registry is required for SoftmaxCrossEntropyLoss."
@@ -7873,7 +8829,58 @@ class CEmitter:
                 loop_vars=loop_vars,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, BernoulliOp):
+            output_dim_names = _dim_names_for(op.output)
+            shape = CEmitter._shape_dim_exprs(op.output_shape, output_dim_names)
+            loop_vars = CEmitter._loop_vars(op.output_shape)
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
+            input_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.input0)
+            )
+            params = self._shared_param_map(
+                [("input0", op.input0), ("output", op.output)]
+            )
+            output_dtype = op.dtype
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["output"], output_dtype.c_type, output_suffix, False),
+                ]
+            )
+            one_literal = (
+                "true"
+                if output_dtype == ScalarType.BOOL
+                else f"({output_dtype.c_type})1"
+            )
+            zero_literal = (
+                "false"
+                if output_dtype == ScalarType.BOOL
+                else output_dtype.zero_literal
+            )
+            rendered = bernoulli_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                input_index_expr="".join(
+                    f"[{var}]" for var in loop_vars
+                ),
+                output_index_expr="".join(
+                    f"[{var}]" for var in loop_vars
+                ),
+                shape=shape,
+                loop_vars=loop_vars,
+                seed=op.seed if op.seed is not None else 0,
+                one_literal=one_literal,
+                zero_literal=zero_literal,
+                dim_args=dim_args,
+                params=param_decls,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, EyeLikeOp):
+            input_c_type = op.input_dtype.c_type
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
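The BernoulliOp branch passes seed=op.seed if op.seed is not None else 0 to the template, so generation stays deterministic even when the model omits the seed attribute. A Python reference of the sampling contract; the RNG choice here is illustrative, not what the emitted C uses:

    import random

    def bernoulli_reference(probs, seed=0):
        # Input holds probabilities in [0, 1]; output holds 0/1 draws.
        # one_literal/zero_literal above map these onto bool or numeric C.
        rng = random.Random(seed)
        return [1 if rng.random() < p else 0 for p in probs]

    print(bernoulli_reference([0.0, 1.0, 0.5], seed=0))  # p=0 never fires, p=1 always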
@@ -7887,7 +8894,7 @@ class CEmitter:
             batch_size = CEmitter._element_count(batch_dims or (1,))
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"],
+                    (params["input0"], input_c_type, input_suffix, True),
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
@@ -8499,8 +9506,6 @@ class CEmitter:
             update_expr = None
             init_literal = None
             final_expr = "acc"
-            use_kahan = False
-            kahan_value_expr = None
             fabs_fn = CEmitter._math_fn(output_dtype, "fabsf", "fabs")
             exp_fn = CEmitter._math_fn(output_dtype, "expf", "exp")
             log_fn = CEmitter._math_fn(output_dtype, "logf", "log")
@@ -8546,24 +9551,6 @@ class CEmitter:
                 raise CodegenError(
                     f"Unsupported reduce kind {op.reduce_kind}"
                 )
-            if output_dtype in {ScalarType.F16, ScalarType.F32} and op.reduce_kind in {
-                "sum",
-                "mean",
-                "logsum",
-                "logsumexp",
-                "l1",
-                "l2",
-                "sumsquare",
-            }:
-                use_kahan = True
-                if op.reduce_kind == "logsumexp":
-                    kahan_value_expr = f"{exp_fn}({value_expr})"
-                elif op.reduce_kind == "l1":
-                    kahan_value_expr = f"{fabs_fn}({value_expr})"
-                elif op.reduce_kind in {"l2", "sumsquare"}:
-                    kahan_value_expr = f"{value_expr} * {value_expr}"
-                else:
-                    kahan_value_expr = value_expr
             input_suffix = self._param_array_suffix(input_shape)
             output_suffix = self._param_array_suffix(output_shape_raw)
             param_decls = self._build_param_decls(
@@ -8590,8 +9577,8 @@ class CEmitter:
                 zero_literal=zero_literal,
                 update_expr=update_expr,
                 final_expr=final_expr,
-                use_kahan=use_kahan,
-                kahan_value_expr=kahan_value_expr,
+                use_kahan=False,
+                kahan_value_expr=None,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ArgReduceOp):
@@ -8736,9 +9723,9 @@ class CEmitter:
                 output_values=params["output_values"],
                 output_indices=params["output_indices"],
                 params=param_decls,
-                input_c_type=
-                output_values_c_type=
-                output_indices_c_type=
+                input_c_type=input_dtype.c_type,
+                output_values_c_type=output_values_dtype.c_type,
+                output_indices_c_type=output_indices_dtype.c_type,
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
                 output_shape=output_shape,
@@ -8746,7 +9733,7 @@ class CEmitter:
                 outer_loop_vars=outer_loop_vars,
                 reduce_var=reduce_var,
                 k_var=k_var,
-                axis_dim=
+                axis_dim=input_shape[op.axis],
                 k=op.k,
                 input_index_expr=input_index_expr,
                 output_index_expr=output_index_expr,
@@ -8762,11 +9749,15 @@ class CEmitter:
                     ("output", op.output),
                 ]
             )
-
+            input_shape_raw = self._ctx_shape(op.input0)
+            output_shape_raw = self._ctx_shape(op.output)
+            output_shape = CEmitter._codegen_shape(output_shape_raw)
             output_loop_vars = CEmitter._loop_vars(output_shape)
-            input_shape = CEmitter._codegen_shape(
+            input_shape = CEmitter._codegen_shape(input_shape_raw)
             input_loop_vars = CEmitter._loop_vars(input_shape)
-            axes_shape =
+            axes_shape = (
+                self._ctx_shape(op.axes_input) if op.axes_input is not None else ()
+            )
             axes_count = 1
             for dim in axes_shape:
                 if dim == 0:
@@ -8774,8 +9765,8 @@ class CEmitter:
                     break
                 axes_count *= dim
             axes_c_type = (
-                op.
-                if op.
+                self._ctx_dtype(op.axes_input).c_type
+                if op.axes_input is not None
                 else ScalarType.I64.c_type
             )
             input_indices = "".join(f"[{var}]" for var in input_loop_vars)
@@ -8789,10 +9780,11 @@ class CEmitter:
             update_expr = None
             init_literal = None
             post_expr = None
-
-
-
-
+            reduce_dtype = self._ctx_dtype(op.output)
+            fabs_fn = CEmitter._math_fn(reduce_dtype, "fabsf", "fabs")
+            exp_fn = CEmitter._math_fn(reduce_dtype, "expf", "exp")
+            log_fn = CEmitter._math_fn(reduce_dtype, "logf", "log")
+            sqrt_fn = CEmitter._math_fn(reduce_dtype, "sqrtf", "sqrt")
             if op.reduce_kind == "sum":
                 init_literal = zero_literal
                 update_expr = f"*out_ptr += {value_expr};"
@@ -8807,7 +9799,7 @@ class CEmitter:
                 init_literal = max_literal
                 update_expr = f"if ({value_expr} < *out_ptr) *out_ptr = {value_expr};"
             elif op.reduce_kind == "prod":
-                init_literal = CEmitter._format_literal(
+                init_literal = CEmitter._format_literal(reduce_dtype, 1)
                 update_expr = f"*out_ptr *= {value_expr};"
             elif op.reduce_kind == "l1":
                 init_literal = zero_literal
@@ -8831,11 +9823,11 @@ class CEmitter:
                 raise CodegenError(
                     f"Unsupported reduce kind {op.reduce_kind}"
                 )
-            input_suffix = self._param_array_suffix(
-            output_suffix = self._param_array_suffix(
+            input_suffix = self._param_array_suffix(input_shape_raw)
+            output_suffix = self._param_array_suffix(output_shape_raw)
             axes_suffix = (
-                self._param_array_suffix(
-                if
+                self._param_array_suffix(axes_shape)
+                if axes_shape
                 else ""
             )
             params = self._build_param_decls(
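The reduce hunks above delete the compensated-summation path: use_kahan is now hard-wired to False and kahan_value_expr to None, so F16/F32 sum-family reductions accumulate naively. For the record, the technique the removed branch selected was plain Kahan summation, which the templates presumably turned into the equivalent C loop; a Python sketch:

    import math

    def kahan_sum(values):
        acc = 0.0
        c = 0.0  # running compensation for lost low-order bits
        for v in values:
            y = v - c
            t = acc + y
            c = (t - acc) - y
            acc = t
        return acc

    vals = [1.0] + [1e-16] * 10
    print(sum(vals))        # 1.0 -- each tiny term is absorbed by rounding
    print(kahan_sum(vals))  # ~1.000000000000001, the compensation survives
    print(math.fsum(vals))  # exactly rounded reference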
@@ -8963,6 +9955,44 @@ class CEmitter:
                 value=CEmitter._format_literal(op.dtype, op.value),
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, OptionalHasElementOp):
+            params = self._shared_param_map(
+                [("input0", op.input0), ("output", op.output)]
+            )
+            input_shape = self._ctx_shape(op.input0)
+            output_shape = self._ctx_shape(op.output)
+            input_dim_names = _dim_names_for(op.input0)
+            output_dim_names = _dim_names_for(op.output)
+            input_suffix = self._param_array_suffix(input_shape, input_dim_names)
+            output_suffix = self._param_array_suffix(output_shape, output_dim_names)
+            input_dtype = self._ctx_dtype(op.input0)
+            output_dtype = self._ctx_dtype(op.output)
+            optional_flags = self._optional_input_flag_map(model)
+            input_flag = optional_flags.get(op.input0)
+            if input_flag is None:
+                raise CodegenError(
+                    "OptionalHasElement expects an optional input flag."
+                )
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], input_dtype.c_type, input_suffix, True),
+                    (input_flag, "_Bool", "", True),
+                    (params["output"], output_dtype.c_type, output_suffix, False),
+                ]
+            )
+            rendered = optional_has_element_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                input_present=input_flag,
+                output=params["output"],
+                params=param_decls,
+                input_c_type=input_dtype.c_type,
+                output_c_type=output_dtype.c_type,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, NonZeroOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
@@ -9247,6 +10277,38 @@ class CEmitter:
                 length=op.length,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, HammingWindowOp):
+            params = self._shared_param_map(
+                [
+                    ("size", op.size),
+                    ("output", op.output),
+                ]
+            )
+            scalar_suffix = self._param_array_suffix(())
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (
+                        params["size"],
+                        op.input_dtype.c_type,
+                        scalar_suffix,
+                        True,
+                    ),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = hamming_window_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                size=params["size"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                output_suffix=output_suffix,
+                length=op.output_shape[0],
+                periodic_literal="1" if op.periodic else "0",
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, OneHotOp):
             params = self._shared_param_map(
                 [
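periodic_literal toggles the denominator used in the cosine. Assuming the template follows the ONNX HammingWindow definition, whose coefficients are 25/46 and 21/46 rather than the textbook 0.54/0.46 rounding, the generated values are:

    w[n] = \frac{25}{46} - \frac{21}{46}\cos\!\left(\frac{2\pi n}{D}\right),
    \qquad D = \begin{cases} N & \text{periodic} \\ N - 1 & \text{symmetric} \end{cases}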
@@ -9315,6 +10377,85 @@ class CEmitter:
                 c_type=c_type,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, TfIdfVectorizerOp):
+            params = self._shared_param_map(
+                [("input0", op.input0), ("output", op.output)]
+            )
+            input_dim_names = _dim_names_for(op.input0)
+            output_dim_names = _dim_names_for(op.output)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, input_dim_names
+            )
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (
+                        params["input0"],
+                        op.input_dtype.c_type,
+                        input_suffix,
+                        True,
+                    ),
+                    (
+                        params["output"],
+                        op.output_dtype.c_type,
+                        output_suffix,
+                        False,
+                    ),
+                ]
+            )
+            output_dim = op.output_shape[-1] if op.output_shape else 0
+            mode_id = {"TF": 0, "IDF": 1, "TFIDF": 2}[op.mode]
+            pool_values = [
+                CEmitter._format_literal(ScalarType.I64, value)
+                for value in op.pool_int64s
+            ]
+            ngram_counts_values = [
+                CEmitter._format_literal(ScalarType.I64, value)
+                for value in op.ngram_counts
+            ]
+            ngram_indexes_values = [
+                CEmitter._format_literal(ScalarType.I64, value)
+                for value in op.ngram_indexes
+            ]
+            weights_values = (
+                [
+                    CEmitter._format_literal(op.output_dtype, value)
+                    for value in op.weights
+                ]
+                if op.weights is not None
+                else None
+            )
+            rendered = tfidf_vectorizer_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                params=param_decls,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                input_rank=len(op.input_shape),
+                output_dim=output_dim,
+                min_gram_length=op.min_gram_length,
+                max_gram_length=op.max_gram_length,
+                max_skip_count=op.max_skip_count,
+                mode_id=mode_id,
+                ngram_counts_len=len(op.ngram_counts),
+                pool_size=len(op.pool_int64s),
+                ngram_index_len=len(op.ngram_indexes),
+                pool_values=pool_values,
+                ngram_counts_values=ngram_counts_values,
+                ngram_indexes_values=ngram_indexes_values,
+                weights_values=weights_values,
+                zero_literal=op.output_dtype.zero_literal,
+                one_literal=CEmitter._format_literal(op.output_dtype, 1.0),
+                c_type=op.output_dtype.c_type,
+                input_c_type=op.input_dtype.c_type,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, SplitOp):
             output_params = [
                 (f"output_{index}", name)
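mode_id encodes the TfIdfVectorizer weighting mode ({"TF": 0, "IDF": 1, "TFIDF": 2} above). Per the ONNX operator spec, paraphrased here since the template body is not visible in this diff, the final write-out per output slot is:

    def finalize(counts, weights, mode):
        # counts: raw n-gram occurrence counts per output slot;
        # weights: op.weights, defaulting to all ones when absent.
        if mode == "TF":       # mode_id 0: raw counts
            return [float(c) for c in counts]
        if mode == "IDF":      # mode_id 1: weight where the n-gram occurred
            return [w if c > 0 else 0.0 for c, w in zip(counts, weights)]
        return [c * w for c, w in zip(counts, weights)]   # TFIDF, mode_id 2

    print(finalize([2, 0, 1], [1.0, 1.0, 0.5], "TFIDF"))  # [2.0, 0.0, 0.5]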
@@ -9429,30 +10570,302 @@ class CEmitter:
             scale_suffix = self._param_array_suffix(
                 scale_shape, _dim_names_for(op.scale)
             )
-            zero_point_suffix = self._param_array_suffix(
-                scale_shape, _dim_names_for(op.zero_point or "")
+            zero_point_suffix = self._param_array_suffix(
+                scale_shape, _dim_names_for(op.zero_point or "")
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
+                    (
+                        params["zero_point"],
+                        op.dtype.c_type,
+                        zero_point_suffix,
+                        True,
+                    )
+                    if params["zero_point"]
+                    else (None, "", "", True),
+                    (params["output"], op.dtype.c_type, input_suffix, False),
+                ]
+            )
+            compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
+            compute_dtype = (
+                ScalarType.F64
+                if compute_type == "double"
+                else ScalarType.F32
+            )
+            max_fn = self._scalar_function_name(
+                ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
+            )
+            min_fn = self._scalar_function_name(
+                ScalarFunction.MINIMUM, compute_dtype, scalar_registry
+            )
+            if max_fn is None or min_fn is None:
+                raise CodegenError(
+                    "Failed to resolve scalar min/max functions for QuantizeLinear."
+                )
+            round_fn = CEmitter._math_fn(
+                op.input_dtype, "nearbyintf", "nearbyint"
+            )
+            scale_index = "0" if op.axis is None else loop_vars[op.axis]
+            input_expr = f"{params['input0']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            output_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            scale_expr = f"{params['scale']}[{scale_index}]"
+            if params["zero_point"]:
+                zero_expr = f"{params['zero_point']}[{scale_index}]"
+            else:
+                zero_expr = "0"
+            rendered = quantize_linear_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                scale=params["scale"],
+                zero_point=params["zero_point"],
+                output=params["output"],
+                params=param_decls,
+                compute_type=compute_type,
+                input_c_type=op.input_dtype.c_type,
+                output_c_type=op.dtype.c_type,
+                shape=shape,
+                loop_vars=loop_vars,
+                input_expr=input_expr,
+                scale_expr=scale_expr,
+                zero_expr=zero_expr,
+                output_expr=output_expr,
+                round_fn=round_fn,
+                min_literal=op.dtype.min_literal,
+                max_literal=op.dtype.max_literal,
+                min_fn=min_fn,
+                max_fn=max_fn,
+                dim_args=dim_args,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, DequantizeLinearOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("scale", op.scale),
+                    ("zero_point", op.zero_point),
+                    ("output", op.output),
+                ]
+            )
+            output_dim_names = _dim_names_for(op.output)
+            shape = CEmitter._shape_dim_exprs(op.input_shape, output_dim_names)
+            loop_vars = CEmitter._loop_vars(op.input_shape)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.input0)
+            )
+            if op.axis is None:
+                scale_shape = ()
+            elif op.block_size:
+                scale_shape_list = list(op.input_shape)
+                scale_shape_list[op.axis] = (
+                    op.input_shape[op.axis] // op.block_size
+                )
+                scale_shape = tuple(scale_shape_list)
+            else:
+                scale_shape = (op.input_shape[op.axis],)
+            scale_suffix = self._param_array_suffix(
+                scale_shape, _dim_names_for(op.scale)
+            )
+            zero_point_suffix = self._param_array_suffix(
+                scale_shape, _dim_names_for(op.zero_point or "")
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
+                    (
+                        params["zero_point"],
+                        op.input_dtype.c_type,
+                        zero_point_suffix,
+                        True,
+                    )
+                    if params["zero_point"]
+                    else (None, "", "", True),
+                    (params["output"], op.dtype.c_type, input_suffix, False),
+                ]
+            )
+            compute_type = "double" if op.dtype == ScalarType.F64 else "float"
+            input_expr = f"{params['input0']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            output_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            if op.axis is None:
+                scale_expr = f"{params['scale']}[0]"
+            elif op.block_size:
+                scale_indices = list(loop_vars)
+                scale_indices[op.axis] = (
+                    f"({loop_vars[op.axis]}) / {op.block_size}"
+                )
+                scale_expr = f"{params['scale']}" + "".join(
+                    f"[{index}]" for index in scale_indices
+                )
+            else:
+                scale_index = loop_vars[op.axis]
+                scale_expr = f"{params['scale']}[{scale_index}]"
+            if params["zero_point"]:
+                if op.axis is None:
+                    zero_expr = f"{params['zero_point']}[0]"
+                elif op.block_size:
+                    scale_indices = list(loop_vars)
+                    scale_indices[op.axis] = (
+                        f"({loop_vars[op.axis]}) / {op.block_size}"
+                    )
+                    zero_expr = f"{params['zero_point']}" + "".join(
+                        f"[{index}]" for index in scale_indices
+                    )
+                else:
+                    zero_expr = f"{params['zero_point']}[{scale_index}]"
+            else:
+                zero_expr = "0"
+            rendered = dequantize_linear_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                scale=params["scale"],
+                zero_point=params["zero_point"],
+                output=params["output"],
+                params=param_decls,
+                compute_type=compute_type,
+                input_c_type=op.input_dtype.c_type,
+                output_c_type=op.dtype.c_type,
+                shape=shape,
+                loop_vars=loop_vars,
+                input_expr=input_expr,
+                scale_expr=scale_expr,
+                zero_expr=zero_expr,
+                output_expr=output_expr,
+                dim_args=dim_args,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, QLinearMulOp):
+            if scalar_registry is None:
+                raise CodegenError(
+                    "Scalar function registry is required for QLinearMul."
+                )
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("input0_scale", op.input0_scale),
+                    ("input0_zero_point", op.input0_zero_point),
+                    ("input1", op.input1),
+                    ("input1_scale", op.input1_scale),
+                    ("input1_zero_point", op.input1_zero_point),
+                    ("output_scale", op.output_scale),
+                    ("output_zero_point", op.output_zero_point),
+                    ("output", op.output),
+                ]
+            )
+            output_shape = CEmitter._codegen_shape(op.output_shape)
+            output_loop_vars = CEmitter._loop_vars(op.output_shape)
+            output_index_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in output_loop_vars
+            )
+            input0_index_expr = CEmitter._broadcast_index_expr(
+                params["input0"],
+                op.input0_shape,
+                op.output_shape,
+                output_loop_vars,
+            )
+            input1_index_expr = CEmitter._broadcast_index_expr(
+                params["input1"],
+                op.input1_shape,
+                op.output_shape,
+                output_loop_vars,
+            )
+            input0_suffix = self._param_array_suffix(op.input0_shape)
+            input1_suffix = self._param_array_suffix(op.input1_shape)
+            input0_scale_suffix = self._param_array_suffix(
+                op.input0_scale_shape
+            )
+            input1_scale_suffix = self._param_array_suffix(
+                op.input1_scale_shape
             )
+            output_scale_suffix = self._param_array_suffix(
+                op.output_scale_shape
+            )
+            input0_zero_suffix = self._param_array_suffix(op.input0_zero_shape)
+            input1_zero_suffix = self._param_array_suffix(op.input1_zero_shape)
+            output_zero_suffix = self._param_array_suffix(op.output_zero_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
-                    (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
                     (
-                        params["
+                        params["input0"],
+                        op.input0_dtype.c_type,
+                        input0_suffix,
+                        True,
+                    ),
+                    (
+                        params["input0_scale"],
+                        op.input0_scale_dtype.c_type,
+                        input0_scale_suffix,
+                        True,
+                    ),
+                    (
+                        params["input0_zero_point"],
+                        op.input0_dtype.c_type,
+                        input0_zero_suffix,
+                        True,
+                    ),
+                    (
+                        params["input1"],
+                        op.input1_dtype.c_type,
+                        input1_suffix,
+                        True,
+                    ),
+                    (
+                        params["input1_scale"],
+                        op.input1_scale_dtype.c_type,
+                        input1_scale_suffix,
+                        True,
+                    ),
+                    (
+                        params["input1_zero_point"],
+                        op.input1_dtype.c_type,
+                        input1_zero_suffix,
+                        True,
+                    ),
+                    (
+                        params["output_scale"],
+                        op.output_scale_dtype.c_type,
+                        output_scale_suffix,
+                        True,
+                    ),
+                    (
+                        params["output_zero_point"],
                         op.dtype.c_type,
-
+                        output_zero_suffix,
                         True,
-                    )
-
-
-
+                    ),
+                    (
+                        params["output"],
+                        op.dtype.c_type,
+                        output_suffix,
+                        False,
+                    ),
                 ]
             )
-            compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
             compute_dtype = (
                 ScalarType.F64
-                if
+                if ScalarType.F64
+                in {
+                    op.input0_scale_dtype,
+                    op.input1_scale_dtype,
+                    op.output_scale_dtype,
+                }
                 else ScalarType.F32
             )
+            compute_type = (
+                "double" if compute_dtype == ScalarType.F64 else "float"
+            )
             max_fn = self._scalar_function_name(
                 ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
             )
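The three quantization kernels above correspond to the standard affine-quantization identities; round_fn is nearbyint, i.e. round-half-to-even, and the clamp bounds come from op.dtype's min/max literals:

    \text{QuantizeLinear:}\quad q = \operatorname{clamp}\!\big(\operatorname{round}(x / s) + z,\; q_{\min},\; q_{\max}\big)
    \text{DequantizeLinear:}\quad x = (q - z)\, s
    \text{QLinearMul:}\quad q_o = \operatorname{clamp}\!\Big(\operatorname{round}\big(\tfrac{s_0 s_1}{s_o}(q_0 - z_0)(q_1 - z_1)\big) + z_o,\; q_{\min},\; q_{\max}\Big)

In the blocked DequantizeLinear case, scale_expr divides the axis loop variable by block_size, so each run of block_size consecutive elements along op.axis shares one (scale, zero_point) pair.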
@@ -9461,40 +10874,38 @@ class CEmitter:
             )
             if max_fn is None or min_fn is None:
                 raise CodegenError(
-                    "Failed to resolve scalar min/max functions for QuantizeLinear."
+                    "Failed to resolve scalar min/max functions for QLinearMul."
                 )
             round_fn = CEmitter._math_fn(
-                op.input_dtype, "nearbyintf", "nearbyint"
-            )
-            scale_index = "0" if op.axis is None else loop_vars[op.axis]
-            input_expr = f"{params['input0']}" + "".join(
-                f"[{var}]" for var in loop_vars
-            )
-            output_expr = f"{params['output']}" + "".join(
-                f"[{var}]" for var in loop_vars
+                compute_dtype, "nearbyintf", "nearbyint"
             )
-            scale_expr = f"{params['scale']}[{scale_index}]"
-            if params["zero_point"]:
-                zero_expr = f"{params['zero_point']}[{scale_index}]"
-            else:
-                zero_expr = "0"
-            rendered = quantize_linear_template.render(
+            scale_index = "0"
+            rendered = qlinear_mul_template.render(
                 model_name=model.name,
                 op_name=op_name,
                 input0=params["input0"],
-                scale=params["scale"],
-                zero_point=params["zero_point"],
+                input1=params["input1"],
+                input0_scale=params["input0_scale"],
+                input0_zero_point=params["input0_zero_point"],
+                input1_scale=params["input1_scale"],
+                input1_zero_point=params["input1_zero_point"],
+                output_scale=params["output_scale"],
+                output_zero_point=params["output_zero_point"],
                 output=params["output"],
                 params=param_decls,
                 compute_type=compute_type,
-                input_c_type=op.input_dtype.c_type,
                 output_c_type=op.dtype.c_type,
-                shape=shape,
-                loop_vars=loop_vars,
-                input_expr=input_expr,
-                scale_expr=scale_expr,
-                zero_expr=zero_expr,
-                output_expr=output_expr,
+                input0_index_expr=input0_index_expr,
+                input1_index_expr=input1_index_expr,
+                input0_scale_expr=f"{params['input0_scale']}[{scale_index}]",
+                input1_scale_expr=f"{params['input1_scale']}[{scale_index}]",
+                output_scale_expr=f"{params['output_scale']}[{scale_index}]",
+                input0_zero_expr=f"{params['input0_zero_point']}[{scale_index}]",
+                input1_zero_expr=f"{params['input1_zero_point']}[{scale_index}]",
+                output_zero_expr=f"{params['output_zero_point']}[{scale_index}]",
+                output_loop_vars=output_loop_vars,
+                output_loop_bounds=output_shape,
+                output_index_expr=output_index_expr,
                 round_fn=round_fn,
                 min_literal=op.dtype.min_literal,
                 max_literal=op.dtype.max_literal,
@@ -9504,10 +10915,6 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, QLinearMatMulOp):
-            if scalar_registry is None:
-                raise CodegenError(
-                    "Scalar function registry is required for QLinearMatMul."
-                )
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -9541,13 +10948,18 @@ class CEmitter:
             row_var = output_loop_vars[-2]
             col_var = output_loop_vars[-1]
             input0_index_expr, input1_index_expr = CEmitter._matmul_index_exprs(
-                op,
                 batch_vars,
                 row_var,
                 col_var,
                 batch_rank,
                 input0=params["input0"],
                 input1=params["input1"],
+                left_vector=op.left_vector,
+                right_vector=op.right_vector,
+                input0_shape=op.input0_shape,
+                input1_shape=op.input1_shape,
+                input0_batch_shape=op.input0_batch_shape,
+                input1_batch_shape=op.input1_batch_shape,
             )
             input0_suffix = self._param_array_suffix(op.input0_shape)
             input1_suffix = self._param_array_suffix(op.input1_shape)
@@ -9622,32 +11034,28 @@ class CEmitter:
                 ),
                 ]
             )
-
-
-
-
-
-
-
-
-
-
+            if ScalarType.F64 in {
+                op.input0_scale_dtype,
+                op.input1_scale_dtype,
+                op.output_scale_dtype,
+            }:
+                scale_dtype = ScalarType.F64
+            elif ScalarType.F32 in {
+                op.input0_scale_dtype,
+                op.input1_scale_dtype,
+                op.output_scale_dtype,
+            }:
+                scale_dtype = ScalarType.F32
+            else:
+                scale_dtype = ScalarType.F16
+            compute_dtype = ScalarType.F64
             compute_type = (
                 "double" if compute_dtype == ScalarType.F64 else "float"
             )
-            max_fn = self._scalar_function_name(
-                ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
-            )
-            min_fn = self._scalar_function_name(
-                ScalarFunction.MINIMUM, compute_dtype, scalar_registry
-            )
-            if max_fn is None or min_fn is None:
-                raise CodegenError(
-                    "Failed to resolve scalar min/max functions for QLinearMatMul."
-                )
             round_fn = CEmitter._math_fn(
                 compute_dtype, "nearbyintf", "nearbyint"
             )
+            mod_fn = CEmitter._math_fn(compute_dtype, "fmodf", "fmod")
             scale_index = "0"
             rendered = qlinear_matmul_template.render(
                 model_name=model.name,
@@ -9662,6 +11070,8 @@ class CEmitter:
                 output_zero_point=params["output_zero_point"],
                 output=params["output"],
                 params=param_decls,
+                scale_type=scale_dtype.c_type,
+                scale_is_float16=scale_dtype == ScalarType.F16,
                 compute_type=compute_type,
                 output_c_type=op.dtype.c_type,
                 input0_index_expr=input0_index_expr,
@@ -9677,10 +11087,8 @@ class CEmitter:
                 output_index_expr=output_index_expr,
                 k=op.k,
                 round_fn=round_fn,
-
-
-                min_fn=min_fn,
-                max_fn=max_fn,
+                mod_fn=mod_fn,
+                output_is_signed=op.dtype.is_signed,
                 dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
@@ -9740,7 +11148,11 @@ class CEmitter:
                     loop_vars,
                 )
                 if op.input_min is not None
-                else output_dtype.min_literal
+                else (
+                    CEmitter._format_literal(output_dtype, op.min_value)
+                    if op.min_value is not None
+                    else output_dtype.min_literal
+                )
             )
             max_expr = (
                 CEmitter._broadcast_index_expr(
@@ -9750,7 +11162,11 @@ class CEmitter:
                     loop_vars,
                 )
                 if op.input_max is not None
-                else output_dtype.max_literal
+                else (
+                    CEmitter._format_literal(output_dtype, op.max_value)
+                    if op.max_value is not None
+                    else output_dtype.max_literal
+                )
             )
             input_suffix = self._param_array_suffix(
                 input_shape, _dim_names_for(op.input0)
@@ -9896,11 +11312,14 @@ class CEmitter:
|
|
|
9896
11312
|
| ClipOp
|
|
9897
11313
|
| CastOp
|
|
9898
11314
|
| QuantizeLinearOp
|
|
11315
|
+
| QLinearMulOp
|
|
11316
|
+
| QLinearMatMulOp
|
|
9899
11317
|
| MatMulOp
|
|
9900
11318
|
| EinsumOp
|
|
9901
11319
|
| GemmOp
|
|
9902
11320
|
| AttentionOp
|
|
9903
11321
|
| ConvOp
|
|
11322
|
+
| ConvIntegerOp
|
|
9904
11323
|
| ConvTransposeOp
|
|
9905
11324
|
| AveragePoolOp
|
|
9906
11325
|
| LpPoolOp
|
|
@@ -9912,6 +11331,7 @@ class CEmitter:
|
|
|
9912
11331
|
| MeanVarianceNormalizationOp
|
|
9913
11332
|
| RMSNormalizationOp
|
|
9914
11333
|
| LrnOp
|
|
11334
|
+
| GruOp
|
|
9915
11335
|
| LstmOp
|
|
9916
11336
|
| SoftmaxOp
|
|
9917
11337
|
| LogSoftmaxOp
|
|
@@ -9942,9 +11362,11 @@ class CEmitter:
|
|
|
9942
11362
|
| ConstantOfShapeOp
|
|
9943
11363
|
| ShapeOp
|
|
9944
11364
|
| SizeOp
|
|
11365
|
+
| OptionalHasElementOp
|
|
9945
11366
|
| ExpandOp
|
|
9946
11367
|
| CumSumOp
|
|
9947
11368
|
| RangeOp
|
|
11369
|
+
| HammingWindowOp
|
|
9948
11370
|
| OneHotOp
|
|
9949
11371
|
| SplitOp,
|
|
9950
11372
|
) -> str:
|
|
@@ -9963,11 +11385,13 @@ class CEmitter:
|
|
|
9963
11385
|
| ClipOp
|
|
9964
11386
|
| CastOp
|
|
9965
11387
|
| QuantizeLinearOp
|
|
11388
|
+
| DequantizeLinearOp
|
|
9966
11389
|
| MatMulOp
|
|
9967
11390
|
| EinsumOp
|
|
9968
11391
|
| GemmOp
|
|
9969
11392
|
| AttentionOp
|
|
9970
11393
|
| ConvOp
|
|
11394
|
+
| ConvIntegerOp
|
|
9971
11395
|
| ConvTransposeOp
|
|
9972
11396
|
| AveragePoolOp
|
|
9973
11397
|
| LpPoolOp
|
|
@@ -9979,6 +11403,7 @@ class CEmitter:
|
|
|
9979
11403
|
| MeanVarianceNormalizationOp
|
|
9980
11404
|
| RMSNormalizationOp
|
|
9981
11405
|
| LrnOp
|
|
11406
|
+
| GruOp
|
|
9982
11407
|
| LstmOp
|
|
9983
11408
|
| SoftmaxOp
|
|
9984
11409
|
| LogSoftmaxOp
|
|
@@ -10009,9 +11434,11 @@ class CEmitter:
|
|
|
10009
11434
|
| ConstantOfShapeOp
|
|
10010
11435
|
| ShapeOp
|
|
10011
11436
|
| SizeOp
|
|
11437
|
+
| OptionalHasElementOp
|
|
10012
11438
|
| ExpandOp
|
|
10013
11439
|
| CumSumOp
|
|
10014
11440
|
| RangeOp
|
|
11441
|
+
| HammingWindowOp
|
|
10015
11442
|
| OneHotOp
|
|
10016
11443
|
| SplitOp,
|
|
10017
11444
|
) -> tuple[tuple[str, tuple[int, ...]], ...]:
|
|
@@ -10069,6 +11496,8 @@ class CEmitter:
|
|
|
10069
11496
|
return ((op.input0, self._ctx_shape(op.input0)),)
|
|
10070
11497
|
if isinstance(op, NonZeroOp):
|
|
10071
11498
|
return ((op.input0, op.input_shape),)
|
|
11499
|
+
if isinstance(op, OptionalHasElementOp):
|
|
11500
|
+
return ((op.input0, self._ctx_shape(op.input0)),)
|
|
10072
11501
|
if isinstance(op, NonMaxSuppressionOp):
|
|
10073
11502
|
inputs = [
|
|
10074
11503
|
(op.boxes, op.boxes_shape),
|
|
@@ -10104,6 +11533,20 @@ class CEmitter:
|
|
|
10104
11533
|
if op.zero_point is not None:
|
|
10105
11534
|
inputs.append((op.zero_point, scale_shape))
|
|
10106
11535
|
return tuple(inputs)
|
|
11536
|
+
if isinstance(op, DequantizeLinearOp):
|
|
11537
|
+
if op.axis is None:
|
|
11538
|
+
scale_shape = ()
|
|
11539
|
+
elif op.block_size:
|
|
11540
|
+
input_shape = self._ctx_shape(op.input0)
|
|
11541
|
+
scale_shape_list = list(input_shape)
|
|
11542
|
+
scale_shape_list[op.axis] = input_shape[op.axis] // op.block_size
|
|
11543
|
+
scale_shape = tuple(scale_shape_list)
|
|
11544
|
+
else:
|
|
11545
|
+
scale_shape = (self._ctx_shape(op.input0)[op.axis],)
|
|
11546
|
+
inputs = [(op.input0, self._ctx_shape(op.input0)), (op.scale, scale_shape)]
|
|
11547
|
+
if op.zero_point is not None:
|
|
11548
|
+
inputs.append((op.zero_point, scale_shape))
|
|
11549
|
+
return tuple(inputs)
|
|
10107
11550
|
if isinstance(op, IdentityOp):
|
|
10108
11551
|
return ((op.input0, self._ctx_shape(op.input0)),)
|
|
10109
11552
|
if isinstance(op, EyeLikeOp):
|
|
```diff
@@ -10138,6 +11581,8 @@ class CEmitter:
             return ((op.input0, op.input_shape),)
         if isinstance(op, RangeOp):
             return ((op.start, ()), (op.limit, ()), (op.delta, ()))
+        if isinstance(op, HammingWindowOp):
+            return ((op.size, ()),)
         if isinstance(op, OneHotOp):
             return (
                 (op.indices, op.indices_shape),
@@ -10147,7 +11592,10 @@ class CEmitter:
         if isinstance(op, SplitOp):
             return ((op.input0, op.input_shape),)
         if isinstance(op, TopKOp):
-            return (
+            return (
+                (op.input0, self._ctx_shape(op.input0)),
+                (op.k_input, self._ctx_shape(op.k_input)),
+            )
         if isinstance(op, (TransposeOp, ReshapeOp, ReduceOp, ArgReduceOp)):
             return ((op.input0, self._ctx_shape(op.input0)),)
         return ()
```
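The `TopKOp` branch now reports two runtime inputs, the data tensor and `k_input`, replacing the old one-line return (truncated in this view). This matches ONNX TopK from opset 10 onward, where `k` is a one-element int64 tensor rather than a compile-time attribute. A minimal NumPy reference for that contract (hypothetical helper, largest-first over the last axis only):

```python
import numpy as np

def topk_last_axis(x: np.ndarray, k: np.ndarray):
    kk = int(k.reshape(-1)[0])  # k arrives as a one-element int64 tensor
    idx = np.argsort(-x, axis=-1)[..., :kk]  # indices of the largest values
    return np.take_along_axis(x, idx, axis=-1), idx

values, indices = topk_last_axis(
    np.array([[1.0, 3.0, 2.0]]), np.array([2], dtype=np.int64)
)
assert values.tolist() == [[3.0, 2.0]]
assert indices.tolist() == [[1, 2]]
```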
```diff
@@ -10162,6 +11610,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
         | MatMulOp
         | EinsumOp
         | GemmOp
@@ -10178,6 +11627,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
@@ -10210,6 +11660,7 @@ class CEmitter:
         | NonMaxSuppressionOp
         | ExpandOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp
         ],
@@ -10234,11 +11685,13 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -10250,6 +11703,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
@@ -10284,9 +11738,18 @@ class CEmitter:
         | NonMaxSuppressionOp
         | ExpandOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
         | SplitOp,
     ) -> tuple[tuple[str, tuple[int, ...], ScalarType], ...]:
+        if isinstance(op, OptionalHasElementOp):
+            return (
+                (
+                    op.output,
+                    self._op_output_shape(op),
+                    self._op_output_dtype(op),
+                ),
+            )
         if isinstance(
             op,
             (
```
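`OptionalHasElementOp` gets a dedicated early return here, delegating shape and dtype to the `_op_output_shape` / `_op_output_dtype` helpers; per the ONNX spec the op yields a single rank-0 boolean. A one-line reference of the semantics (hypothetical stand-in, not the emitter's code):

```python
import numpy as np

def optional_has_element(value):
    # True when the optional input is populated; the output is a scalar bool.
    return np.array(value is not None, dtype=bool)

assert optional_has_element(np.zeros(3)).shape == ()
assert bool(optional_has_element(None)) is False
```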
```diff
@@ -10341,6 +11804,39 @@ class CEmitter:
                 )
             )
             return tuple(outputs)
+        if isinstance(op, GruOp):
+            outputs: list[tuple[str, tuple[int, ...], ScalarType]] = []
+            if op.output_y is not None:
+                if op.layout == 0:
+                    y_shape = (
+                        op.seq_length,
+                        op.num_directions,
+                        op.batch_size,
+                        op.hidden_size,
+                    )
+                else:
+                    y_shape = (
+                        op.batch_size,
+                        op.seq_length,
+                        op.num_directions,
+                        op.hidden_size,
+                    )
+                outputs.append((op.output_y, y_shape, op.dtype))
+            if op.output_y_h is not None:
+                if op.layout == 0:
+                    state_shape = (
+                        op.num_directions,
+                        op.batch_size,
+                        op.hidden_size,
+                    )
+                else:
+                    state_shape = (
+                        op.batch_size,
+                        op.num_directions,
+                        op.hidden_size,
+                    )
+                outputs.append((op.output_y_h, state_shape, op.dtype))
+            return tuple(outputs)
         if isinstance(op, LstmOp):
             outputs: list[tuple[str, tuple[int, ...], ScalarType]] = []
             if op.output_y is not None:
```
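The new `GruOp` branch reproduces the ONNX GRU output shapes for both layouts: with `layout=0`, `Y` is `[seq_length, num_directions, batch_size, hidden_size]`; with `layout=1`, the batch axis leads. `Y_h` drops the sequence axis either way. A compact restatement (hypothetical helper):

```python
def gru_output_shapes(seq_length, num_directions, batch_size, hidden_size, layout=0):
    if layout == 0:
        y = (seq_length, num_directions, batch_size, hidden_size)
        y_h = (num_directions, batch_size, hidden_size)
    else:
        y = (batch_size, seq_length, num_directions, hidden_size)
        y_h = (batch_size, num_directions, hidden_size)
    return y, y_h

assert gru_output_shapes(7, 2, 3, 16, layout=0) == ((7, 2, 3, 16), (2, 3, 16))
assert gru_output_shapes(7, 2, 3, 16, layout=1) == ((3, 7, 2, 16), (3, 2, 16))
```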
```diff
@@ -10456,12 +11952,14 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
         | QLinearMatMulOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvIntegerOp
         | AveragePoolOp
         | BatchNormOp
         | LpNormalizationOp
@@ -10471,6 +11969,7 @@ class CEmitter:
         | MeanVarianceNormalizationOp
         | RMSNormalizationOp
         | LrnOp
+        | GruOp
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
@@ -10485,6 +11984,7 @@ class CEmitter:
         | TransposeOp
         | ReshapeOp
         | IdentityOp
+        | BernoulliOp
         | EyeLikeOp
         | TriluOp
         | TileOp
@@ -10502,7 +12002,10 @@ class CEmitter:
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
+        | TfIdfVectorizerOp
+        | RotaryEmbeddingOp
         | SplitOp
         | PadOp,
     ) -> tuple[int, ...]:
@@ -10518,21 +12021,29 @@ class CEmitter:
             return self._ctx_shape(op.output)
         if isinstance(op, QuantizeLinearOp):
             return op.input_shape
+        if isinstance(op, DequantizeLinearOp):
+            return op.input_shape
         if isinstance(op, CastOp):
             return self._ctx_shape(op.output)
+        if isinstance(op, QLinearMulOp):
+            return op.output_shape
         if isinstance(op, QLinearMatMulOp):
             return op.output_shape
         if isinstance(op, MatMulOp):
-            return op.
+            return self._ctx_shape(op.output)
         if isinstance(op, EinsumOp):
             return op.output_shape
         if isinstance(op, GemmOp):
-            return (op.
+            return self._ctx_shape(op.output)
         if isinstance(op, ConvOp):
             return (op.batch, op.out_channels, *op.out_spatial)
+        if isinstance(op, ConvIntegerOp):
+            return (op.batch, op.out_channels, *op.out_spatial)
         if isinstance(op, ConvTransposeOp):
             return (op.batch, op.out_channels, *op.out_spatial)
         if isinstance(op, AveragePoolOp):
+            if op.spatial_rank == 3:
+                return (op.batch, op.channels, op.out_d, op.out_h, op.out_w)
             return (op.batch, op.channels, op.out_h, op.out_w)
         if isinstance(op, LpPoolOp):
             return (op.batch, op.channels, op.out_h, op.out_w)
```
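The `AveragePoolOp` branch now covers 3-D pooling, returning a 5-D `(N, C, D, H, W)` shape when `spatial_rank == 3` and falling through to the existing 4-D case otherwise. The arithmetic behind `op.out_d` / `op.out_h` / `op.out_w` is the usual floor-mode pooling formula; a sketch with assumed kernel and stride values:

```python
def pool_out_extent(in_extent, kernel, stride, pad_begin=0, pad_end=0):
    # Floor-mode pooling extent, the ONNX default (ceil_mode=0).
    return (in_extent + pad_begin + pad_end - kernel) // stride + 1

# N=1, C=3, D=8, H=16, W=16 with kernel 2 and stride 2 in every spatial dim:
out_shape = (1, 3, pool_out_extent(8, 2, 2),
             pool_out_extent(16, 2, 2), pool_out_extent(16, 2, 2))
assert out_shape == (1, 3, 4, 8, 8)
```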
```diff
@@ -10582,6 +12093,8 @@ class CEmitter:
             return self._ctx_shape(op.output)
         if isinstance(op, IdentityOp):
             return self._ctx_shape(op.output)
+        if isinstance(op, BernoulliOp):
+            return op.output_shape
         if isinstance(op, EyeLikeOp):
             return op.output_shape
         if isinstance(op, TriluOp):
@@ -10612,6 +12125,8 @@ class CEmitter:
             return op.output_shape
         if isinstance(op, SizeOp):
             return op.output_shape
+        if isinstance(op, OptionalHasElementOp):
+            return self._ctx_shape(op.output)
         if isinstance(op, NonZeroOp):
             return op.output_shape
         if isinstance(op, NonMaxSuppressionOp):
@@ -10622,8 +12137,14 @@ class CEmitter:
             return op.input_shape
         if isinstance(op, RangeOp):
             return op.output_shape
+        if isinstance(op, HammingWindowOp):
+            return op.output_shape
         if isinstance(op, OneHotOp):
             return op.output_shape
+        if isinstance(op, TfIdfVectorizerOp):
+            return op.output_shape
+        if isinstance(op, RotaryEmbeddingOp):
+            return op.input_shape
         if op.output_rank == 3:
             return (op.batch, op.q_seq, op.q_heads * op.v_head_size)
         return (op.batch, op.q_heads, op.q_seq, op.v_head_size)
@@ -10637,11 +12158,13 @@ class CEmitter:
         | ClipOp
         | CastOp
         | QuantizeLinearOp
+        | DequantizeLinearOp
         | MatMulOp
         | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvIntegerOp
         | ConvTransposeOp
         | AveragePoolOp
         | LpPoolOp
@@ -10666,6 +12189,7 @@ class CEmitter:
         | TransposeOp
         | ReshapeOp
         | IdentityOp
+        | BernoulliOp
         | EyeLikeOp
         | TriluOp
         | TileOp
@@ -10681,7 +12205,9 @@ class CEmitter:
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | HammingWindowOp
         | OneHotOp
+        | TfIdfVectorizerOp
         | SplitOp
         | PadOp,
     ) -> ScalarType:
@@ -10689,8 +12215,12 @@ class CEmitter:
             return self._ctx_dtype(op.output)
         if isinstance(op, TopKOp):
             return self._ctx_dtype(op.output_values)
+        if isinstance(op, OptionalHasElementOp):
+            return self._ctx_dtype(op.output)
         if isinstance(op, NonMaxSuppressionOp):
             return op.output_dtype
+        if isinstance(op, TfIdfVectorizerOp):
+            return op.output_dtype
         if isinstance(
             op,
             (
@@ -10703,6 +12233,8 @@ class CEmitter:
                 SoftmaxOp,
                 LogSoftmaxOp,
                 HardmaxOp,
+                MatMulOp,
+                GemmOp,
                 GatherOp,
                 TransposeOp,
                 ReshapeOp,
@@ -10729,10 +12261,12 @@ class CEmitter:
         self,
         shape: tuple[int, ...],
         dim_names: Mapping[int, str] | None = None,
+        *,
+        use_restrict: bool = False,
     ) -> str:
         shape = CEmitter._codegen_shape(shape)
         dim_names = dim_names or {}
-        if not self._restrict_arrays:
+        if not (self._restrict_arrays and use_restrict):
             return "".join(
                 f"[{dim_names.get(index, dim)}]"
                 for index, dim in enumerate(shape)
```
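The array-suffix renderer gains a keyword-only `use_restrict`, making `restrict`-qualified array dimensions opt-in per call site on top of the global `self._restrict_arrays` switch: both must be true before the qualified form is emitted. A standalone rendition of the same gate, with an assumed `restrict` spelling (C99 array-parameter syntax, not necessarily the emitter's exact output):

```python
def array_suffix(shape, *, restrict_arrays, use_restrict):
    if not (restrict_arrays and use_restrict):
        return "".join(f"[{dim}]" for dim in shape)
    # C99 allows qualifying the outermost array parameter: int a[restrict 4][8]
    first, *rest = shape
    return f"[restrict {first}]" + "".join(f"[{dim}]" for dim in rest)

assert array_suffix((4, 8), restrict_arrays=True, use_restrict=False) == "[4][8]"
assert array_suffix((4, 8), restrict_arrays=True, use_restrict=True) == "[restrict 4][8]"
```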
```diff
@@ -10755,6 +12289,16 @@ class CEmitter:
             return ""
         return ", ".join(f"int {dim_name}" for dim_name in dim_order) + ", "

+    @staticmethod
+    def _optional_input_flag_map(model: LoweredModel) -> dict[str, str]:
+        return {
+            name: flag
+            for name, flag in zip(
+                model.input_names, model.input_optional_names
+            )
+            if flag is not None
+        }
+
     def _build_variable_dim_names(
         self,
         model: LoweredModel,
```
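`_optional_input_flag_map` pairs each graph input with its generated presence-flag name and skips required inputs, whose flag slot is `None`. With a stand-in for `LoweredModel` (the real class carries these same fields):

```python
from types import SimpleNamespace

# Hypothetical stand-in model with one required and two optional inputs.
model = SimpleNamespace(
    input_names=("x", "mask", "bias"),
    input_optional_names=(None, "mask_present", "bias_present"),
)
flag_map = {
    name: flag
    for name, flag in zip(model.input_names, model.input_optional_names)
    if flag is not None
}
assert flag_map == {"mask": "mask_present", "bias": "bias_present"}
```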
```diff
@@ -10772,6 +12316,12 @@ class CEmitter:
         dim_vars: dict[tuple[str, int, int], str] = {}
         dim_values: dict[str, int] = {}
         reserved_names = set(model.input_names) | set(model.output_names)
+        reserved_names.update(
+            name for name in model.input_optional_names if name is not None
+        )
+        reserved_names.update(
+            name for name in model.output_optional_names if name is not None
+        )
         used_names = set(reserved_names)
         dim_aliases: dict[str, str] = {}

@@ -10926,14 +12476,19 @@ class CEmitter:

     @staticmethod
     def _matmul_index_exprs(
-        op: MatMulOp,
         batch_vars: tuple[str, ...],
         row_var: str | None,
         col_var: str | None,
         batch_rank: int,
         *,
-        input0: str
-        input1: str
+        input0: str,
+        input1: str,
+        left_vector: bool,
+        right_vector: bool,
+        input0_shape: tuple[int, ...],
+        input1_shape: tuple[int, ...],
+        input0_batch_shape: tuple[int, ...],
+        input1_batch_shape: tuple[int, ...],
     ) -> tuple[str, str]:
         def batch_indices(
             batch_shape: tuple[int, ...], actual_rank: int
@@ -10948,28 +12503,28 @@ class CEmitter:
                 indices.append("0" if dim == 1 else var)
             return indices

-        if
+        if left_vector:
             input0_indices = ["k"]
         else:
-            input0_batch_rank = len(
+            input0_batch_rank = len(input0_shape) - 2
             input0_indices = batch_indices(
-
+                input0_batch_shape, input0_batch_rank
             )
             input0_indices.append(row_var if row_var is not None else "0")
             input0_indices.append("k")
-        if
+        if right_vector:
             input1_indices = ["k"]
         else:
-            input1_batch_rank = len(
+            input1_batch_rank = len(input1_shape) - 2
             input1_indices = batch_indices(
-
+                input1_batch_shape, input1_batch_rank
             )
             input1_indices.append("k")
             input1_indices.append(col_var if col_var is not None else "0")
-        input0_index_expr = f"{input0
+        input0_index_expr = f"{input0}" + "".join(
             f"[{index}]" for index in input0_indices
         )
-        input1_index_expr = f"{input1
+        input1_index_expr = f"{input1}" + "".join(
             f"[{index}]" for index in input1_indices
         )
         return input0_index_expr, input1_index_expr
```
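`_matmul_index_exprs` no longer receives the op object; callers pass the vector-ness flags and input shapes explicitly, which also makes the broadcast rule visible: a batch dimension of extent 1 indexes with the literal `"0"`, so the generated C always reads the single broadcast slice. What the helper's string-building tail produces for a `(1, M, K) @ (B, K, N)` case, rebuilt standalone with assumed loop-variable names:

```python
def index_expr(name, indices):
    # Same construction as the helper's tail: name + "[i0][i1]..."
    return name + "".join(f"[{index}]" for index in indices)

# input0 has a broadcast batch dim (extent 1) -> "0"; input1 uses the loop var.
input0_indices = ["0", "i", "k"]
input1_indices = ["b0", "k", "j"]
assert index_expr("input0", input0_indices) == "input0[0][i][k]"
assert index_expr("input1", input1_indices) == "input1[b0][k][j]"
```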
```diff
@@ -10980,6 +12535,7 @@ class CEmitter:
         testbench_template,
         *,
         testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
+        testbench_optional_inputs: Mapping[str, bool] | None = None,
         dim_order: Sequence[str],
         dim_values: Mapping[str, int],
         weight_data_filename: str,
@@ -10988,13 +12544,34 @@ class CEmitter:
             self._element_count(shape) for shape in model.input_shapes
         )
         testbench_inputs = testbench_inputs or {}
+        testbench_optional_inputs = testbench_optional_inputs or {}
+        rng_requires_u64 = False
+        rng_requires_float = False
+        rng_requires_double = False
+        rng_requires_i64 = False
         inputs = []
-        for name, shape, count, dtype in zip(
-            model.input_names,
+        for name, shape, count, dtype, optional_flag in zip(
+            model.input_names,
+            model.input_shapes,
+            input_counts,
+            model.input_dtypes,
+            model.input_optional_names,
         ):
+            json_name = self._ctx_name(name)
             codegen_shape = self._codegen_shape(shape)
             loop_shape = (1,) if not shape else shape
             loop_vars = self._loop_vars(loop_shape)
+            constant_values = testbench_inputs.get(name)
+            if constant_values is None:
+                rng_requires_u64 = True
+                if dtype in {ScalarType.F16, ScalarType.F32}:
+                    rng_requires_float = True
+                elif dtype == ScalarType.F64:
+                    rng_requires_double = True
+                elif dtype == ScalarType.BOOL:
+                    pass
+                else:
+                    rng_requires_i64 = True
             if dtype in {ScalarType.F16, ScalarType.F32}:
                 random_expr = "rng_next_float()"
             elif dtype == ScalarType.F64:
```
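The testbench generator now records which RNG helpers the emitted C will actually call, one `rng_requires_*` flag per family, so the template can render only the generators that are needed (the template side is not shown in this hunk). Booleans draw from the u64 core directly and set no extra flag. The classification in isolation, with a stand-in `ScalarType`:

```python
from enum import Enum, auto

class ScalarType(Enum):  # stand-in for the package's ScalarType
    F16 = auto()
    F32 = auto()
    F64 = auto()
    BOOL = auto()
    I32 = auto()

def rng_requirements(random_input_dtypes):
    # Mirrors the flag computation above for inputs without constant data.
    req = {"u64": False, "float": False, "double": False, "i64": False}
    for dtype in random_input_dtypes:
        req["u64"] = True  # every generated value draws from the u64 core
        if dtype in {ScalarType.F16, ScalarType.F32}:
            req["float"] = True
        elif dtype is ScalarType.F64:
            req["double"] = True
        elif dtype is ScalarType.BOOL:
            pass  # bool masks the u64 stream directly
        else:
            req["i64"] = True
    return req

assert rng_requirements([ScalarType.F32, ScalarType.I32]) == {
    "u64": True, "float": True, "double": False, "i64": True,
}
```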
```diff
@@ -11003,7 +12580,6 @@ class CEmitter:
                 random_expr = "((rng_next_u64() & 1ull) != 0)"
             else:
                 random_expr = f"({dtype.c_type})rng_next_i64()"
-            constant_values = testbench_inputs.get(name)
             constant_name = None
             constant_lines = None
             if constant_values is not None:
@@ -11015,6 +12591,11 @@ class CEmitter:
                 ]
             else:
                 constant_lines = [self._format_value(0, dtype)]
+            optional_present = (
+                testbench_optional_inputs.get(name, True)
+                if optional_flag is not None
+                else None
+            )
             inputs.append(
                 {
                     "name": name,
@@ -11035,12 +12616,16 @@ class CEmitter:
                     "print_cast": self._print_cast(dtype),
                     "constant_name": constant_name,
                     "constant_lines": constant_lines,
+                    "json_name": json_name,
+                    "optional_flag_name": optional_flag,
+                    "optional_present": optional_present,
                 }
             )
         outputs = []
         for name, shape, dtype in zip(
             model.output_names, model.output_shapes, model.output_dtypes
         ):
+            json_name = self._ctx_name(name)
             codegen_shape = self._codegen_shape(shape)
             loop_shape = (1,) if not shape else shape
             output_loop_vars = self._loop_vars(loop_shape)
@@ -11061,10 +12646,15 @@ class CEmitter:
                     "c_type": dtype.c_type,
                     "print_format": self._print_format(dtype),
                     "print_cast": self._print_cast(dtype),
+                    "json_name": json_name,
                 }
             )
         rendered = testbench_template.render(
             model_name=model.name,
+            rng_requires_u64=rng_requires_u64,
+            rng_requires_float=rng_requires_float,
+            rng_requires_double=rng_requires_double,
+            rng_requires_i64=rng_requires_i64,
             dim_args=[
                 {"name": dim_name, "value": dim_values[dim_name]}
                 for dim_name in dim_order
@@ -11097,13 +12687,30 @@ class CEmitter:
     ) -> tuple[tuple[ConstTensor, ...], tuple[ConstTensor, ...]]:
         if self._large_weight_threshold <= 0:
             return constants, ()
+        sorted_constants = sorted(
+            enumerate(constants),
+            key=lambda item: (
+                self._element_count(item[1].shape)
+                * item[1].dtype.np_dtype.itemsize,
+                item[0],
+            ),
+        )
+        inline_set: set[ConstTensor] = set()
+        total_bytes = 0
+        for _, const in sorted_constants:
+            const_bytes = (
+                self._element_count(const.shape) * const.dtype.np_dtype.itemsize
+            )
+            if total_bytes + const_bytes <= self._large_weight_threshold:
+                inline_set.add(const)
+                total_bytes += const_bytes
         inline: list[ConstTensor] = []
         large: list[ConstTensor] = []
         for const in constants:
-            if
-                large.append(const)
-            else:
+            if const in inline_set:
                 inline.append(const)
+            else:
+                large.append(const)
         return tuple(inline), tuple(large)

     @staticmethod
```
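Inline-versus-external weight selection changes from a single per-constant test (its condition is truncated in this view) to a greedy budget fill: constants are visited smallest-first, with the original index as tie-breaker, and inlined while the cumulative byte count stays within `_large_weight_threshold`; the final split preserves declaration order. The same policy on plain `(name, nbytes)` pairs:

```python
def split_by_budget(consts, budget):
    chosen: set[int] = set()
    total = 0
    # Smallest first, original index as tie-breaker, as in the code above.
    for idx, (_, nbytes) in sorted(enumerate(consts), key=lambda t: (t[1][1], t[0])):
        if total + nbytes <= budget:
            chosen.add(idx)
            total += nbytes
    inline = [c for i, c in enumerate(consts) if i in chosen]
    large = [c for i, c in enumerate(consts) if i not in chosen]
    return inline, large

consts = [("w0", 400), ("w1", 100), ("w2", 300)]
assert split_by_budget(consts, 450) == ([("w1", 100), ("w2", 300)], [("w0", 400)])
```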
```diff
@@ -11113,12 +12720,16 @@ class CEmitter:
     def _emit_weight_loader(
         self, model: LoweredModel, large_constants: tuple[ConstTensor, ...]
     ) -> str:
-        lines = [
+        lines = []
         if not large_constants:
+            lines.append(f"_Bool {model.name}_load(const char *path) {{")
            lines.append(" (void)path;")
             lines.append(" return 1;")
             lines.append("}")
             return _format_c_indentation("\n".join(lines))
+        lines.append(f"static _Bool {model.name}_load_file(FILE *file);")
+        lines.append("")
+        lines.append(f"_Bool {model.name}_load(const char *path) {{")
         lines.append(" FILE *file = fopen(path, \"rb\");")
         lines.append(" if (!file) {")
         lines.append(" return 0;")
```
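With large constants present, the loader now forward-declares a `{model}_load_file` helper and builds `{model}_load` on top of it; without them it still emits a `{model}_load` that accepts and ignores the path, so callers keep a single ABI either way. The no-weights output for a model named `net` (name assumed for illustration):

```python
# Assembled exactly as in the empty branch above, before
# _format_c_indentation normalizes the indentation:
lines = []
lines.append("_Bool net_load(const char *path) {")
lines.append(" (void)path;")
lines.append(" return 1;")
lines.append("}")
assert "\n".join(lines) == (
    "_Bool net_load(const char *path) {\n (void)path;\n return 1;\n}"
)
```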
```diff
@@ -11171,7 +12782,7 @@ class CEmitter:
                 for value in const.data
             ]
             lines.append(
-                f"{storage_prefix} {c_type} {const.name}{array_suffix} = {{"
+                f"{storage_prefix} EMX_UNUSED {c_type} {const.name}{array_suffix} = {{"
             )
             if values:
                 if (
@@ -11199,12 +12810,23 @@ class CEmitter:
             return ""
         lines = []
         for index, const in enumerate(constants, start=1):
-            lines.append(self._emit_constant_comment(const, index))
             c_type = const.dtype.c_type
             array_suffix = self._array_suffix(const.shape)
             lines.append(f"extern const {c_type} {const.name}{array_suffix};")
         return "\n".join(lines)

+    def _emit_constant_storage_declarations(
+        self, constants: tuple[ConstTensor, ...]
+    ) -> str:
+        if not constants:
+            return ""
+        lines = []
+        for index, const in enumerate(constants, start=1):
+            c_type = const.dtype.c_type
+            array_suffix = self._array_suffix(const.shape)
+            lines.append(f"extern {c_type} {const.name}{array_suffix};")
+        return "\n".join(lines)
+
     def _emit_constant_storage_definitions(
         self,
         constants: tuple[ConstTensor, ...],
@@ -11214,11 +12836,12 @@ class CEmitter:
         if not constants:
             return ""
         lines: list[str] = []
+        prefix = f"{storage_prefix} " if storage_prefix else ""
         for index, const in enumerate(constants, start=1):
             lines.append(self._emit_constant_comment(const, index))
             c_type = const.dtype.c_type
             array_suffix = self._array_suffix(const.shape)
-            lines.append(f"{
+            lines.append(f"{prefix}{c_type} {const.name}{array_suffix};")
             lines.append("")
         if lines and not lines[-1]:
             lines.pop()
```
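`_emit_constant_storage_declarations` is the mutable counterpart of the existing `const` extern emitter, intended for weights whose storage is populated at load time, and the definition emitter now tolerates an empty `storage_prefix`. How a declaration/definition pair lines up for an assumed `float` tensor:

```python
c_type, name, array_suffix = "float", "net_w0", "[2][3]"  # assumed example tensor

declaration = f"extern {c_type} {name}{array_suffix};"  # header side, mutable
storage_prefix = ""  # may be empty now
prefix = f"{storage_prefix} " if storage_prefix else ""
definition = f"{prefix}{c_type} {name}{array_suffix};"  # translation-unit side

assert declaration == "extern float net_w0[2][3];"
assert definition == "float net_w0[2][3];"
```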