emx-onnx-cgen: 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +340 -59
- emx_onnx_cgen/codegen/c_emitter.py +2369 -111
- emx_onnx_cgen/compiler.py +188 -5
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/lowering/common.py +379 -2
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/gather_elements.py +1 -3
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +6 -5
- emx_onnx_cgen/lowering/logsoftmax.py +5 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/matmul.py +6 -7
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/reduce.py +5 -6
- emx_onnx_cgen/lowering/reshape.py +223 -51
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/softmax.py +5 -1
- emx_onnx_cgen/lowering/squeeze.py +5 -5
- emx_onnx_cgen/lowering/topk.py +116 -0
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +5 -5
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +460 -42
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +61 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
- shared/scalar_functions.py +49 -17
- shared/ulp.py +48 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,17 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from enum import Enum
 import itertools
+import math
+from math import prod
 from pathlib import Path
 import re
+import struct
 from typing import Mapping, Sequence
 
 from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
+import numpy as np
 
 from ..errors import CodegenError
 from ..ops import (
@@ -24,6 +29,38 @@ from shared.scalar_types import ScalarFunctionError, ScalarType
 
 
 def _format_c_indentation(source: str, *, indent: str = " ") -> str:
+    def strip_string_literals(line: str) -> str:
+        sanitized: list[str] = []
+        in_string = False
+        in_char = False
+        escape = False
+        for char in line:
+            if escape:
+                escape = False
+                if not (in_string or in_char):
+                    sanitized.append(char)
+                continue
+            if in_string:
+                if char == "\\":
+                    escape = True
+                elif char == '"':
+                    in_string = False
+                continue
+            if in_char:
+                if char == "\\":
+                    escape = True
+                elif char == "'":
+                    in_char = False
+                continue
+            if char == '"':
+                in_string = True
+                continue
+            if char == "'":
+                in_char = True
+                continue
+            sanitized.append(char)
+        return "".join(sanitized)
+
     formatted_lines: list[str] = []
     indent_level = 0
     for line in source.splitlines():
@@ -34,8 +71,9 @@ def _format_c_indentation(source: str, *, indent: str = " ") -> str:
         if stripped.startswith("}"):
             indent_level = max(indent_level - 1, 0)
         formatted_lines.append(f"{indent * indent_level}{stripped}")
-
-
+        sanitized = strip_string_literals(stripped)
+        open_count = sanitized.count("{")
+        close_count = sanitized.count("}")
         if stripped.startswith("}"):
             close_count = max(close_count - 1, 0)
         indent_level += open_count - close_count
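The hunk above makes the C indent tracker count braces only after string and character literals have been blanked out, so a brace that appears inside a literal no longer shifts the indent level. A minimal standalone sketch of that behaviour (the helper name and sample lines are illustrative, not taken from the package):

    def braces_outside_literals(line: str) -> tuple[int, int]:
        # Count '{' and '}' that sit outside "..." and '...' literals.
        in_str = in_chr = escape = False
        opens = closes = 0
        for ch in line:
            if escape:
                escape = False
                continue
            if in_str or in_chr:
                if ch == "\\":
                    escape = True
                elif (in_str and ch == '"') or (in_chr and ch == "'"):
                    in_str = in_chr = False
                continue
            if ch == '"':
                in_str = True
            elif ch == "'":
                in_chr = True
            elif ch == "{":
                opens += 1
            elif ch == "}":
                closes += 1
        return opens, closes

    assert braces_outside_literals('printf("{");') == (0, 0)
    assert braces_outside_literals("if (x) {") == (1, 0)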
@@ -119,6 +157,8 @@ class BinaryOp:
     output: str
     function: ScalarFunction
     operator_kind: OperatorKind
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
     shape: tuple[int, ...]
     dtype: ScalarType
     input_dtype: ScalarType
@@ -211,6 +251,26 @@ class MatMulOp:
     dtype: ScalarType
 
 
+class EinsumKind(str, Enum):
+    REDUCE_ALL = "reduce_all"
+    SUM_J = "sum_j"
+    TRANSPOSE = "transpose"
+    DOT = "dot"
+    BATCH_MATMUL = "batch_matmul"
+    BATCH_DIAGONAL = "batch_diagonal"
+
+
+@dataclass(frozen=True)
+class EinsumOp:
+    inputs: tuple[str, ...]
+    output: str
+    kind: EinsumKind
+    input_shapes: tuple[tuple[int, ...], ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class GemmOp:
     input_a: str
@@ -305,6 +365,27 @@ class ConvOp:
         return self.out_spatial[1]
 
 
+@dataclass(frozen=True)
+class ConvTransposeOp:
+    input0: str
+    weights: str
+    bias: str | None
+    output: str
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    output_padding: tuple[int, ...]
+    group: int
+    dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class AveragePoolOp:
     input0: str
@@ -327,6 +408,41 @@ class AveragePoolOp:
     dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class LpPoolOp:
+    input0: str
+    output: str
+    batch: int
+    channels: int
+    in_h: int
+    in_w: int
+    out_h: int
+    out_w: int
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_top: int
+    pad_left: int
+    pad_bottom: int
+    pad_right: int
+    p: int
+    dtype: ScalarType
+
+
+@dataclass(frozen=True)
+class QuantizeLinearOp:
+    input0: str
+    scale: str
+    zero_point: str | None
+    output: str
+    input_shape: tuple[int, ...]
+    axis: int | None
+    dtype: ScalarType
+    input_dtype: ScalarType
+    scale_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class SoftmaxOp:
     input0: str
@@ -351,6 +467,18 @@ class LogSoftmaxOp:
     dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class HardmaxOp:
+    input0: str
+    output: str
+    outer: int
+    axis_size: int
+    inner: int
+    axis: int
+    shape: tuple[int, ...]
+    dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class NegativeLogLikelihoodLossOp:
     input0: str
@@ -595,6 +723,34 @@ class GatherOp:
     indices_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class GatherNDOp:
+    data: str
+    indices: str
+    output: str
+    batch_dims: int
+    data_shape: tuple[int, ...]
+    indices_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+    indices_dtype: ScalarType
+
+
+@dataclass(frozen=True)
+class ScatterNDOp:
+    data: str
+    indices: str
+    updates: str
+    output: str
+    data_shape: tuple[int, ...]
+    indices_shape: tuple[int, ...]
+    updates_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    reduction: str
+    dtype: ScalarType
+    indices_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class TransposeOp:
     input0: str
@@ -635,6 +791,21 @@ class EyeLikeOp:
     input_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class TriluOp:
+    input0: str
+    output: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    upper: bool
+    k_value: int
+    k_input: str | None
+    k_input_shape: tuple[int, ...] | None
+    k_input_dtype: ScalarType | None
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class TileOp:
     input0: str
@@ -800,6 +971,22 @@ class ArgReduceOp:
     output_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class TopKOp:
+    input0: str
+    output_values: str
+    output_indices: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    axis: int
+    k: int
+    largest: bool
+    sorted: bool
+    input_dtype: ScalarType
+    output_values_dtype: ScalarType
+    output_indices_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class ConstantOfShapeOp:
     input0: str
@@ -833,6 +1020,16 @@ class SizeOp:
     input_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class NonZeroOp:
+    input0: str
+    output: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class ExpandOp:
     input0: str
@@ -871,6 +1068,22 @@ class RangeOp:
     input_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class OneHotOp:
+    indices: str
+    depth: str
+    values: str
+    output: str
+    axis: int
+    indices_shape: tuple[int, ...]
+    values_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    depth_dim: int
+    dtype: ScalarType
+    indices_dtype: ScalarType
+    depth_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class SplitOp:
     input0: str
@@ -937,11 +1150,15 @@ class LoweredModel:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -953,16 +1170,20 @@ class LoweredModel:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -972,12 +1193,15 @@ class LoweredModel:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
         ...,
     ]
@@ -986,7 +1210,15 @@ class LoweredModel:
 
 
 class CEmitter:
-    def __init__(
+    def __init__(
+        self,
+        template_dir: Path,
+        *,
+        restrict_arrays: bool = True,
+        truncate_weights_after: int | None = None,
+        large_temp_threshold_bytes: int = 1024,
+        large_weight_threshold: int = 1024,
+    ) -> None:
         self._env = Environment(
             loader=FileSystemLoader(str(template_dir)),
             autoescape=select_autoescape(enabled_extensions=()),
@@ -994,6 +1226,15 @@ class CEmitter:
             lstrip_blocks=True,
         )
         self._restrict_arrays = restrict_arrays
+        if truncate_weights_after is not None and truncate_weights_after < 1:
+            raise CodegenError("truncate_weights_after must be >= 1")
+        self._truncate_weights_after = truncate_weights_after
+        if large_temp_threshold_bytes < 0:
+            raise CodegenError("large_temp_threshold_bytes must be >= 0")
+        self._large_temp_threshold_bytes = large_temp_threshold_bytes
+        if large_weight_threshold < 0:
+            raise CodegenError("large_weight_threshold must be >= 0")
+        self._large_weight_threshold = large_weight_threshold
 
     @staticmethod
     def _sanitize_identifier(name: str) -> str:
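With the constructor hunks above, CEmitter now accepts weight-truncation and size-threshold options and validates them up front. A hedged usage sketch using only the keywords and defaults visible in this diff (the template path is a placeholder):

    from pathlib import Path

    emitter = CEmitter(
        Path("templates"),                # template_dir
        restrict_arrays=True,
        truncate_weights_after=None,      # must be >= 1 when given
        large_temp_threshold_bytes=1024,  # must be >= 0
        large_weight_threshold=1024,      # must be >= 0
    )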
@@ -1006,10 +1247,8 @@ class CEmitter:
 
     def _op_function_name(self, model: LoweredModel, index: int) -> str:
         node_info = model.node_infos[index]
-
-
-        parts.append(node_info.name)
-        base_name = "_".join(parts)
+        suffix = node_info.name or node_info.op_type
+        base_name = f"node{index}_{suffix}".lower()
         return self._sanitize_identifier(base_name)
 
     @staticmethod
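The hunk above simplifies operator-function naming: each node's C function is now named from its index plus the node name, falling back to the op type, then lowercased and sanitized. Roughly (the sample node data is invented for illustration):

    node_infos = [("", "Conv"), ("fc1", "Gemm")]
    for index, (name, op_type) in enumerate(node_infos):
        suffix = name or op_type
        print(f"node{index}_{suffix}".lower())
    # node0_conv
    # node1_gemm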
@@ -1094,7 +1333,9 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
@@ -1110,16 +1351,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -1129,12 +1374,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> tuple[str, ...]:
         if isinstance(op, BinaryOp):
@@ -1155,8 +1403,16 @@ class CEmitter:
             return tuple(names)
         if isinstance(op, CastOp):
             return (op.input0, op.output)
+        if isinstance(op, QuantizeLinearOp):
+            names = [op.input0, op.scale]
+            if op.zero_point is not None:
+                names.append(op.zero_point)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, MatMulOp):
             return (op.input0, op.input1, op.output)
+        if isinstance(op, EinsumOp):
+            return (*op.inputs, op.output)
         if isinstance(op, GemmOp):
             names = [op.input_a, op.input_b]
             if op.input_c is not None:
@@ -1187,8 +1443,16 @@ class CEmitter:
             names.append(op.bias)
             names.append(op.output)
             return tuple(names)
+        if isinstance(op, ConvTransposeOp):
+            names = [op.input0, op.weights]
+            if op.bias is not None:
+                names.append(op.bias)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, AveragePoolOp):
             return (op.input0, op.output)
+        if isinstance(op, LpPoolOp):
+            return (op.input0, op.output)
         if isinstance(op, BatchNormOp):
             return (op.input0, op.scale, op.bias, op.mean, op.variance, op.output)
         if isinstance(op, LpNormalizationOp):
@@ -1230,7 +1494,7 @@ class CEmitter:
             if op.output_y_c is not None:
                 names.append(op.output_y_c)
             return tuple(names)
-        if isinstance(op, (SoftmaxOp, LogSoftmaxOp)):
+        if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
             return (op.input0, op.output)
         if isinstance(op, NegativeLogLikelihoodLossOp):
             names = [op.input0, op.target]
@@ -1255,6 +1519,10 @@ class CEmitter:
             return (op.data, op.indices, op.output)
         if isinstance(op, GatherOp):
             return (op.data, op.indices, op.output)
+        if isinstance(op, GatherNDOp):
+            return (op.data, op.indices, op.output)
+        if isinstance(op, ScatterNDOp):
+            return (op.data, op.indices, op.updates, op.output)
         if isinstance(op, ConcatOp):
             return (*op.inputs, op.output)
         if isinstance(op, ConstantOfShapeOp):
@@ -1263,6 +1531,8 @@ class CEmitter:
             return (op.input0, op.output)
         if isinstance(op, SizeOp):
             return (op.input0, op.output)
+        if isinstance(op, NonZeroOp):
+            return (op.input0, op.output)
         if isinstance(op, ExpandOp):
             return (op.input0, op.output)
         if isinstance(op, CumSumOp):
@@ -1273,6 +1543,8 @@ class CEmitter:
             return tuple(names)
         if isinstance(op, RangeOp):
             return (op.start, op.limit, op.delta, op.output)
+        if isinstance(op, OneHotOp):
+            return (op.indices, op.depth, op.values, op.output)
         if isinstance(op, SplitOp):
             return (op.input0, *op.outputs)
         if isinstance(op, ReshapeOp):
@@ -1281,6 +1553,12 @@ class CEmitter:
             return (op.input0, op.output)
         if isinstance(op, EyeLikeOp):
             return (op.input0, op.output)
+        if isinstance(op, TriluOp):
+            names = [op.input0]
+            if op.k_input is not None:
+                names.append(op.k_input)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, TileOp):
             return (op.input0, op.output)
         if isinstance(op, PadOp):
@@ -1320,6 +1598,8 @@ class CEmitter:
             return tuple(names)
         if isinstance(op, GridSampleOp):
             return (op.input0, op.grid, op.output)
+        if isinstance(op, TopKOp):
+            return (op.input0, op.output_values, op.output_indices)
         if isinstance(op, ReduceOp):
             names = [op.input0]
             if op.axes_input is not None:
@@ -1331,12 +1611,14 @@ class CEmitter:
     def _build_name_map(self, model: LoweredModel) -> dict[str, str]:
         used: set[str] = set()
         name_map: dict[str, str] = {}
+        constant_names = {const.name for const in model.constants}
         names = [model.name]
         names.extend(model.input_names)
         names.extend(model.output_names)
-        names.extend(const.name for const in model.constants)
         for op in model.ops:
-            names.extend(
+            names.extend(
+                name for name in self._op_names(op) if name not in constant_names
+            )
         for name in names:
             if name in name_map:
                 continue
@@ -1344,6 +1626,14 @@ class CEmitter:
             unique = self._ensure_unique_identifier(sanitized, used)
             name_map[name] = unique
             used.add(unique)
+        for index, const in enumerate(model.constants, start=1):
+            if const.name in name_map:
+                continue
+            base_name = self._sanitize_identifier(const.name.lower())
+            weight_name = f"weight{index}_{base_name}"
+            unique = self._ensure_unique_identifier(weight_name, used)
+            name_map[const.name] = unique
+            used.add(unique)
         return name_map
 
     @staticmethod
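Per the hunks above, constants are no longer mixed into the generic name pool; each one gets a deterministic weight{index}_{sanitized_name} identifier instead. A rough illustration (the dot-to-underscore sanitization is an assumption about _sanitize_identifier, not something shown in this diff):

    constants = ["conv1.weight", "conv1.bias"]
    for index, name in enumerate(constants, start=1):
        print(f"weight{index}_" + name.lower().replace(".", "_"))
    # weight1_conv1_weight
    # weight2_conv1_bias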
@@ -1362,11 +1652,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -1378,16 +1672,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -1397,12 +1695,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
         name_map: dict[str, str],
     ) -> (
@@ -1412,11 +1713,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -1428,16 +1733,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -1447,12 +1756,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
     ):
         if isinstance(op, BinaryOp):
@@ -1462,6 +1774,8 @@ class CEmitter:
                 output=name_map.get(op.output, op.output),
                 function=op.function,
                 operator_kind=op.operator_kind,
+                input0_shape=op.input0_shape,
+                input1_shape=op.input1_shape,
                 shape=op.shape,
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
@@ -1518,6 +1832,18 @@ class CEmitter:
                 input_dtype=op.input_dtype,
                 dtype=op.dtype,
             )
+        if isinstance(op, QuantizeLinearOp):
+            return QuantizeLinearOp(
+                input0=name_map.get(op.input0, op.input0),
+                scale=name_map.get(op.scale, op.scale),
+                zero_point=self._map_optional_name(name_map, op.zero_point),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                axis=op.axis,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+                scale_dtype=op.scale_dtype,
+            )
         if isinstance(op, MatMulOp):
             return MatMulOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1536,6 +1862,16 @@ class CEmitter:
                 right_vector=op.right_vector,
                 dtype=op.dtype,
             )
+        if isinstance(op, EinsumOp):
+            return EinsumOp(
+                inputs=tuple(name_map.get(name, name) for name in op.inputs),
+                output=name_map.get(op.output, op.output),
+                kind=op.kind,
+                input_shapes=op.input_shapes,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, GemmOp):
             return GemmOp(
                 input_a=name_map.get(op.input_a, op.input_a),
@@ -1629,6 +1965,26 @@ class CEmitter:
                 group=op.group,
                 dtype=op.dtype,
             )
+        if isinstance(op, ConvTransposeOp):
+            return ConvTransposeOp(
+                input0=name_map.get(op.input0, op.input0),
+                weights=name_map.get(op.weights, op.weights),
+                bias=self._map_optional_name(name_map, op.bias),
+                output=name_map.get(op.output, op.output),
+                batch=op.batch,
+                in_channels=op.in_channels,
+                out_channels=op.out_channels,
+                spatial_rank=op.spatial_rank,
+                in_spatial=op.in_spatial,
+                out_spatial=op.out_spatial,
+                kernel_shape=op.kernel_shape,
+                strides=op.strides,
+                pads=op.pads,
+                dilations=op.dilations,
+                output_padding=op.output_padding,
+                group=op.group,
+                dtype=op.dtype,
+            )
         if isinstance(op, AveragePoolOp):
             return AveragePoolOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1650,6 +2006,27 @@ class CEmitter:
                 count_include_pad=op.count_include_pad,
                 dtype=op.dtype,
             )
+        if isinstance(op, LpPoolOp):
+            return LpPoolOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                batch=op.batch,
+                channels=op.channels,
+                in_h=op.in_h,
+                in_w=op.in_w,
+                out_h=op.out_h,
+                out_w=op.out_w,
+                kernel_h=op.kernel_h,
+                kernel_w=op.kernel_w,
+                stride_h=op.stride_h,
+                stride_w=op.stride_w,
+                pad_top=op.pad_top,
+                pad_left=op.pad_left,
+                pad_bottom=op.pad_bottom,
+                pad_right=op.pad_right,
+                p=op.p,
+                dtype=op.dtype,
+            )
         if isinstance(op, BatchNormOp):
             return BatchNormOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1813,6 +2190,17 @@ class CEmitter:
                 shape=op.shape,
                 dtype=op.dtype,
             )
+        if isinstance(op, HardmaxOp):
+            return HardmaxOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                outer=op.outer,
+                axis_size=op.axis_size,
+                inner=op.inner,
+                axis=op.axis,
+                shape=op.shape,
+                dtype=op.dtype,
+            )
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return NegativeLogLikelihoodLossOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1909,6 +2297,32 @@ class CEmitter:
                 dtype=op.dtype,
                 indices_dtype=op.indices_dtype,
             )
+        if isinstance(op, GatherNDOp):
+            return GatherNDOp(
+                data=name_map.get(op.data, op.data),
+                indices=name_map.get(op.indices, op.indices),
+                output=name_map.get(op.output, op.output),
+                batch_dims=op.batch_dims,
+                data_shape=op.data_shape,
+                indices_shape=op.indices_shape,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+            )
+        if isinstance(op, ScatterNDOp):
+            return ScatterNDOp(
+                data=name_map.get(op.data, op.data),
+                indices=name_map.get(op.indices, op.indices),
+                updates=name_map.get(op.updates, op.updates),
+                output=name_map.get(op.output, op.output),
+                data_shape=op.data_shape,
+                indices_shape=op.indices_shape,
+                updates_shape=op.updates_shape,
+                output_shape=op.output_shape,
+                reduction=op.reduction,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+            )
         if isinstance(op, TransposeOp):
             return TransposeOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1945,6 +2359,20 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, TriluOp):
+            return TriluOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                upper=op.upper,
+                k_value=op.k_value,
+                k_input=self._map_optional_name(name_map, op.k_input),
+                k_input_shape=op.k_input_shape,
+                k_input_dtype=op.k_input_dtype,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, TileOp):
             return TileOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2101,6 +2529,21 @@ class CEmitter:
                 input_dtype=op.input_dtype,
                 output_dtype=op.output_dtype,
             )
+        if isinstance(op, TopKOp):
+            return TopKOp(
+                input0=name_map.get(op.input0, op.input0),
+                output_values=name_map.get(op.output_values, op.output_values),
+                output_indices=name_map.get(op.output_indices, op.output_indices),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                axis=op.axis,
+                k=op.k,
+                largest=op.largest,
+                sorted=op.sorted,
+                input_dtype=op.input_dtype,
+                output_values_dtype=op.output_values_dtype,
+                output_indices_dtype=op.output_indices_dtype,
+            )
         if isinstance(op, ConstantOfShapeOp):
             return ConstantOfShapeOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2131,6 +2574,15 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, NonZeroOp):
+            return NonZeroOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, ExpandOp):
             return ExpandOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2166,6 +2618,21 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, OneHotOp):
+            return OneHotOp(
+                indices=name_map.get(op.indices, op.indices),
+                depth=name_map.get(op.depth, op.depth),
+                values=name_map.get(op.values, op.values),
+                output=name_map.get(op.output, op.output),
+                axis=op.axis,
+                indices_shape=op.indices_shape,
+                values_shape=op.values_shape,
+                output_shape=op.output_shape,
+                depth_dim=op.depth_dim,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+                depth_dtype=op.depth_dtype,
+            )
         if isinstance(op, SplitOp):
             return SplitOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2246,11 +2713,19 @@ class CEmitter:
             "unary": self._env.get_template("unary_op.c.j2"),
             "clip": self._env.get_template("clip_op.c.j2"),
             "cast": self._env.get_template("cast_op.c.j2"),
+            "quantize_linear": self._env.get_template(
+                "quantize_linear_op.c.j2"
+            ),
             "matmul": self._env.get_template("matmul_op.c.j2"),
+            "einsum": self._env.get_template("einsum_op.c.j2"),
             "gemm": self._env.get_template("gemm_op.c.j2"),
             "attention": self._env.get_template("attention_op.c.j2"),
             "conv": self._env.get_template("conv_op.c.j2"),
+            "conv_transpose": self._env.get_template(
+                "conv_transpose_op.c.j2"
+            ),
             "avg_pool": self._env.get_template("average_pool_op.c.j2"),
+            "lp_pool": self._env.get_template("lp_pool_op.c.j2"),
             "batch_norm": self._env.get_template("batch_norm_op.c.j2"),
             "lp_norm": self._env.get_template("lp_normalization_op.c.j2"),
             "instance_norm": self._env.get_template(
@@ -2270,6 +2745,7 @@ class CEmitter:
             "lstm": self._env.get_template("lstm_op.c.j2"),
             "softmax": self._env.get_template("softmax_op.c.j2"),
             "logsoftmax": self._env.get_template("logsoftmax_op.c.j2"),
+            "hardmax": self._env.get_template("hardmax_op.c.j2"),
             "nllloss": self._env.get_template(
                 "negative_log_likelihood_loss_op.c.j2"
             ),
@@ -2280,10 +2756,13 @@ class CEmitter:
             "concat": self._env.get_template("concat_op.c.j2"),
             "gather_elements": self._env.get_template("gather_elements_op.c.j2"),
             "gather": self._env.get_template("gather_op.c.j2"),
+            "gather_nd": self._env.get_template("gather_nd_op.c.j2"),
+            "scatter_nd": self._env.get_template("scatter_nd_op.c.j2"),
             "transpose": self._env.get_template("transpose_op.c.j2"),
             "reshape": self._env.get_template("reshape_op.c.j2"),
             "identity": self._env.get_template("identity_op.c.j2"),
             "eye_like": self._env.get_template("eye_like_op.c.j2"),
+            "trilu": self._env.get_template("trilu_op.c.j2"),
             "tile": self._env.get_template("tile_op.c.j2"),
             "pad": self._env.get_template("pad_op.c.j2"),
             "depth_to_space": self._env.get_template("depth_to_space_op.c.j2"),
@@ -2299,14 +2778,17 @@ class CEmitter:
                 "reduce_op_dynamic.c.j2"
             ),
             "arg_reduce": self._env.get_template("arg_reduce_op.c.j2"),
+            "topk": self._env.get_template("topk_op.c.j2"),
             "constant_of_shape": self._env.get_template(
                 "constant_of_shape_op.c.j2"
             ),
             "shape": self._env.get_template("shape_op.c.j2"),
             "size": self._env.get_template("size_op.c.j2"),
+            "nonzero": self._env.get_template("nonzero_op.c.j2"),
             "expand": self._env.get_template("expand_op.c.j2"),
             "cumsum": self._env.get_template("cumsum_op.c.j2"),
             "range": self._env.get_template("range_op.c.j2"),
+            "one_hot": self._env.get_template("one_hot_op.c.j2"),
             "split": self._env.get_template("split_op.c.j2"),
         }
         if emit_testbench:
@@ -2328,6 +2810,9 @@ class CEmitter:
         testbench_inputs = self._sanitize_testbench_inputs(
             testbench_inputs, name_map
         )
+        inline_constants, large_constants = self._partition_constants(
+            model.constants
+        )
         (
             dim_order,
             input_dim_names,
@@ -2353,11 +2838,15 @@ class CEmitter:
         unary_template = templates["unary"]
         clip_template = templates["clip"]
         cast_template = templates["cast"]
+        quantize_linear_template = templates["quantize_linear"]
         matmul_template = templates["matmul"]
+        einsum_template = templates["einsum"]
         gemm_template = templates["gemm"]
         attention_template = templates["attention"]
         conv_template = templates["conv"]
+        conv_transpose_template = templates["conv_transpose"]
         avg_pool_template = templates["avg_pool"]
+        lp_pool_template = templates["lp_pool"]
         batch_norm_template = templates["batch_norm"]
         lp_norm_template = templates["lp_norm"]
         instance_norm_template = templates["instance_norm"]
@@ -2369,16 +2858,20 @@ class CEmitter:
         lstm_template = templates["lstm"]
         softmax_template = templates["softmax"]
         logsoftmax_template = templates["logsoftmax"]
+        hardmax_template = templates["hardmax"]
         nllloss_template = templates["nllloss"]
         softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
         maxpool_template = templates["maxpool"]
         concat_template = templates["concat"]
         gather_elements_template = templates["gather_elements"]
         gather_template = templates["gather"]
+        gather_nd_template = templates["gather_nd"]
+        scatter_nd_template = templates["scatter_nd"]
         transpose_template = templates["transpose"]
         reshape_template = templates["reshape"]
         identity_template = templates["identity"]
         eye_like_template = templates["eye_like"]
+        trilu_template = templates["trilu"]
         tile_template = templates["tile"]
         pad_template = templates["pad"]
         depth_to_space_template = templates["depth_to_space"]
@@ -2390,12 +2883,15 @@ class CEmitter:
         reduce_template = templates["reduce"]
        reduce_dynamic_template = templates["reduce_dynamic"]
         arg_reduce_template = templates["arg_reduce"]
+        topk_template = templates["topk"]
         constant_of_shape_template = templates["constant_of_shape"]
         shape_template = templates["shape"]
         size_template = templates["size"]
+        nonzero_template = templates["nonzero"]
         expand_template = templates["expand"]
         cumsum_template = templates["cumsum"]
         range_template = templates["range"]
+        one_hot_template = templates["one_hot"]
         split_template = templates["split"]
         testbench_template = templates.get("testbench")
         reserved_names = {
@@ -2427,11 +2923,15 @@ class CEmitter:
             unary_template=unary_template,
             clip_template=clip_template,
             cast_template=cast_template,
+            quantize_linear_template=quantize_linear_template,
             matmul_template=matmul_template,
+            einsum_template=einsum_template,
             gemm_template=gemm_template,
             attention_template=attention_template,
             conv_template=conv_template,
+            conv_transpose_template=conv_transpose_template,
             avg_pool_template=avg_pool_template,
+            lp_pool_template=lp_pool_template,
             batch_norm_template=batch_norm_template,
             lp_norm_template=lp_norm_template,
             instance_norm_template=instance_norm_template,
@@ -2443,16 +2943,20 @@ class CEmitter:
             lstm_template=lstm_template,
             softmax_template=softmax_template,
             logsoftmax_template=logsoftmax_template,
+            hardmax_template=hardmax_template,
             nllloss_template=nllloss_template,
             softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
             maxpool_template=maxpool_template,
             concat_template=concat_template,
             gather_elements_template=gather_elements_template,
             gather_template=gather_template,
+            gather_nd_template=gather_nd_template,
+            scatter_nd_template=scatter_nd_template,
             transpose_template=transpose_template,
             reshape_template=reshape_template,
             identity_template=identity_template,
             eye_like_template=eye_like_template,
+            trilu_template=trilu_template,
             tile_template=tile_template,
             pad_template=pad_template,
             depth_to_space_template=depth_to_space_template,
@@ -2464,12 +2968,15 @@ class CEmitter:
             reduce_template=reduce_template,
             reduce_dynamic_template=reduce_dynamic_template,
             arg_reduce_template=arg_reduce_template,
+            topk_template=topk_template,
             constant_of_shape_template=constant_of_shape_template,
             shape_template=shape_template,
             size_template=size_template,
+            nonzero_template=nonzero_template,
             expand_template=expand_template,
             cumsum_template=cumsum_template,
             range_template=range_template,
+            one_hot_template=one_hot_template,
             split_template=split_template,
             scalar_registry=scalar_registry,
             dim_args=dim_args,
@@ -2495,25 +3002,45 @@ class CEmitter:
         scalar_preamble = [
             line for line in scalar_include_lines if not line.startswith("#include ")
         ]
+        testbench_math_include = set()
+        if emit_testbench and self._testbench_requires_math(
+            model, testbench_inputs
+        ):
+            testbench_math_include.add("#include <math.h>")
         includes = self._collect_includes(
             model,
             resolved_ops,
             emit_testbench=emit_testbench,
-            extra_includes=scalar_includes,
+            extra_includes=scalar_includes | testbench_math_include,
+            needs_weight_loader=bool(large_constants),
         )
-        sections = [
+        sections = [
+            self._emit_header_comment(model.header),
+            "",
+            *includes,
+            "",
+            self._emit_index_type_define(),
+        ]
         if scalar_preamble:
             sections.extend(("", *scalar_preamble))
         sections.append("")
-        constants_section = self._emit_constant_definitions(
+        constants_section = self._emit_constant_definitions(inline_constants)
         if constants_section:
             sections.extend((constants_section.rstrip(), ""))
+        large_constants_section = self._emit_constant_storage_definitions(
+            large_constants
+        )
+        if large_constants_section:
+            sections.extend((large_constants_section.rstrip(), ""))
         if scalar_functions:
             sections.extend(("\n".join(scalar_functions), ""))
+        weight_loader = self._emit_weight_loader(model, large_constants)
         sections.extend(
             (
                 operator_fns.rstrip(),
                 "",
+                weight_loader.rstrip(),
+                "",
                 wrapper_fn,
             )
         )
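The hunk above threads a split between inline and "large" constants through code emission: inline constants keep their in-source definitions, while large ones get storage definitions plus a generated weight loader (hence the extra <stdio.h> include elsewhere in this diff). A conceptual sketch only; the real _partition_constants signature and threshold semantics are not shown here:

    from math import prod

    def partition_constants(constants, large_weight_threshold=1024):
        # Assumption: constants above the element threshold are emitted as
        # storage that the generated weight loader fills at runtime.
        inline, large = [], []
        for const in constants:
            bucket = large if prod(const.shape or (1,)) > large_weight_threshold else inline
            bucket.append(const)
        return tuple(inline), tuple(large)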
@@ -2527,6 +3054,7 @@ class CEmitter:
                 testbench_inputs=testbench_inputs,
                 dim_order=dim_order,
                 dim_values=dim_values,
+                weight_data_filename=self._weight_data_filename(model),
             ),
         )
         )
@@ -2549,6 +3077,9 @@ class CEmitter:
         testbench_inputs = self._sanitize_testbench_inputs(
             testbench_inputs, name_map
         )
+        inline_constants, large_constants = self._partition_constants(
+            model.constants
+        )
         (
             dim_order,
             input_dim_names,
@@ -2574,11 +3105,15 @@ class CEmitter:
         unary_template = templates["unary"]
         clip_template = templates["clip"]
         cast_template = templates["cast"]
+        quantize_linear_template = templates["quantize_linear"]
         matmul_template = templates["matmul"]
+        einsum_template = templates["einsum"]
         gemm_template = templates["gemm"]
         attention_template = templates["attention"]
         conv_template = templates["conv"]
+        conv_transpose_template = templates["conv_transpose"]
         avg_pool_template = templates["avg_pool"]
+        lp_pool_template = templates["lp_pool"]
         batch_norm_template = templates["batch_norm"]
         lp_norm_template = templates["lp_norm"]
         instance_norm_template = templates["instance_norm"]
@@ -2590,16 +3125,20 @@ class CEmitter:
         lstm_template = templates["lstm"]
         softmax_template = templates["softmax"]
         logsoftmax_template = templates["logsoftmax"]
+        hardmax_template = templates["hardmax"]
         nllloss_template = templates["nllloss"]
         softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
         maxpool_template = templates["maxpool"]
         concat_template = templates["concat"]
         gather_elements_template = templates["gather_elements"]
         gather_template = templates["gather"]
+        gather_nd_template = templates["gather_nd"]
+        scatter_nd_template = templates["scatter_nd"]
         transpose_template = templates["transpose"]
         reshape_template = templates["reshape"]
         identity_template = templates["identity"]
         eye_like_template = templates["eye_like"]
+        trilu_template = templates["trilu"]
         tile_template = templates["tile"]
         pad_template = templates["pad"]
         depth_to_space_template = templates["depth_to_space"]
@@ -2611,12 +3150,15 @@ class CEmitter:
         reduce_template = templates["reduce"]
         reduce_dynamic_template = templates["reduce_dynamic"]
         arg_reduce_template = templates["arg_reduce"]
+        topk_template = templates["topk"]
         constant_of_shape_template = templates["constant_of_shape"]
         shape_template = templates["shape"]
         size_template = templates["size"]
+        nonzero_template = templates["nonzero"]
         expand_template = templates["expand"]
         cumsum_template = templates["cumsum"]
         range_template = templates["range"]
+        one_hot_template = templates["one_hot"]
         split_template = templates["split"]
         testbench_template = templates.get("testbench")
         reserved_names = {
@@ -2648,11 +3190,15 @@ class CEmitter:
             unary_template=unary_template,
             clip_template=clip_template,
             cast_template=cast_template,
+            quantize_linear_template=quantize_linear_template,
             matmul_template=matmul_template,
+            einsum_template=einsum_template,
             gemm_template=gemm_template,
             attention_template=attention_template,
             conv_template=conv_template,
+            conv_transpose_template=conv_transpose_template,
             avg_pool_template=avg_pool_template,
+            lp_pool_template=lp_pool_template,
             batch_norm_template=batch_norm_template,
             lp_norm_template=lp_norm_template,
             instance_norm_template=instance_norm_template,
@@ -2664,16 +3210,20 @@ class CEmitter:
             lstm_template=lstm_template,
             softmax_template=softmax_template,
             logsoftmax_template=logsoftmax_template,
+            hardmax_template=hardmax_template,
             nllloss_template=nllloss_template,
             softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
             maxpool_template=maxpool_template,
             concat_template=concat_template,
             gather_elements_template=gather_elements_template,
             gather_template=gather_template,
+            gather_nd_template=gather_nd_template,
+            scatter_nd_template=scatter_nd_template,
             transpose_template=transpose_template,
             reshape_template=reshape_template,
             identity_template=identity_template,
             eye_like_template=eye_like_template,
+            trilu_template=trilu_template,
             tile_template=tile_template,
             pad_template=pad_template,
             depth_to_space_template=depth_to_space_template,
@@ -2685,12 +3235,15 @@ class CEmitter:
             reduce_template=reduce_template,
             reduce_dynamic_template=reduce_dynamic_template,
             arg_reduce_template=arg_reduce_template,
+            topk_template=topk_template,
             constant_of_shape_template=constant_of_shape_template,
             shape_template=shape_template,
             size_template=size_template,
+            nonzero_template=nonzero_template,
             expand_template=expand_template,
             cumsum_template=cumsum_template,
             range_template=range_template,
+            one_hot_template=one_hot_template,
             split_template=split_template,
             scalar_registry=scalar_registry,
             dim_args=dim_args,
@@ -2716,25 +3269,45 @@ class CEmitter:
         scalar_preamble = [
             line for line in scalar_include_lines if not line.startswith("#include ")
         ]
+        testbench_math_include = set()
+        if emit_testbench and self._testbench_requires_math(
+            model, testbench_inputs
+        ):
+            testbench_math_include.add("#include <math.h>")
         includes = self._collect_includes(
             model,
             resolved_ops,
             emit_testbench=emit_testbench,
-            extra_includes=scalar_includes,
+            extra_includes=scalar_includes | testbench_math_include,
+            needs_weight_loader=bool(large_constants),
        )
-        sections = [
+        sections = [
+            self._emit_header_comment(model.header),
+            "",
+            *includes,
+            "",
+            self._emit_index_type_define(),
+        ]
         if scalar_preamble:
             sections.extend(("", *scalar_preamble))
         sections.append("")
-        constants_section = self._emit_constant_declarations(
+        constants_section = self._emit_constant_declarations(inline_constants)
         if constants_section:
             sections.extend((constants_section.rstrip(), ""))
+        large_constants_section = self._emit_constant_storage_definitions(
+            large_constants
+        )
+        if large_constants_section:
+            sections.extend((large_constants_section.rstrip(), ""))
         if scalar_functions:
             sections.extend(("\n".join(scalar_functions), ""))
+        weight_loader = self._emit_weight_loader(model, large_constants)
         sections.extend(
             (
                 operator_fns.rstrip(),
                 "",
+                weight_loader.rstrip(),
+                "",
                 wrapper_fn,
             )
         )
@@ -2748,6 +3321,7 @@ class CEmitter:
                 testbench_inputs=testbench_inputs,
                 dim_order=dim_order,
                 dim_values=dim_values,
+                weight_data_filename=self._weight_data_filename(model),
             ),
         )
         )
@@ -2755,14 +3329,14 @@ class CEmitter:
         main_rendered = "\n".join(sections)
         if not main_rendered.endswith("\n"):
             main_rendered += "\n"
-        data_includes = self._collect_constant_includes(
+        data_includes = self._collect_constant_includes(inline_constants)
         data_sections = [self._emit_header_comment(model.header), ""]
         if data_includes:
             data_sections.extend((*data_includes, ""))
         else:
             data_sections.append("")
         data_constants = self._emit_constant_definitions(
-
+            inline_constants, storage_prefix="const"
         )
         if data_constants:
             data_sections.append(data_constants.rstrip())
@@ -2856,6 +3430,23 @@ class CEmitter:
         comment_lines.append(" */")
         return "\n".join(comment_lines)
 
+    @staticmethod
+    def _emit_constant_comment(constant: ConstTensor, index: int) -> str:
+        shape = constant.shape
+        lines = [
+            f"Weight {index}:",
+            f"Name: {constant.name}",
+            f"Shape: {shape if shape else '[]'}",
+            f"Elements: {CEmitter._element_count(shape)}",
+            f"Dtype: {constant.dtype.onnx_name}",
+        ]
+        comment_lines = ["/*"]
+        comment_lines.extend(
+            f" * {line}" if line else " *" for line in lines
+        )
+        comment_lines.append(" */")
+        return "\n".join(comment_lines)
+
     @staticmethod
     def _collect_constant_includes(constants: tuple[ConstTensor, ...]) -> list[str]:
         if not constants:
@@ -2920,6 +3511,7 @@ class CEmitter:
|
|
|
2920
3511
|
ScalarFunction.FMOD,
|
|
2921
3512
|
ScalarFunction.REMAINDER,
|
|
2922
3513
|
ScalarFunction.LEAKY_RELU,
|
|
3514
|
+
ScalarFunction.MISH,
|
|
2923
3515
|
ScalarFunction.MUL,
|
|
2924
3516
|
ScalarFunction.NEG,
|
|
2925
3517
|
ScalarFunction.LOGICAL_NOT,
|
|
@@ -3005,11 +3597,15 @@ class CEmitter:
 | UnaryOp
 | ClipOp
 | CastOp
+| QuantizeLinearOp
 | MatMulOp
+| EinsumOp
 | GemmOp
 | AttentionOp
 | ConvOp
+| ConvTransposeOp
 | AveragePoolOp
+| LpPoolOp
 | BatchNormOp
 | LpNormalizationOp
 | InstanceNormalizationOp
@@ -3021,16 +3617,20 @@ class CEmitter:
 | LstmOp
 | SoftmaxOp
 | LogSoftmaxOp
+| HardmaxOp
 | NegativeLogLikelihoodLossOp
 | SoftmaxCrossEntropyLossOp
 | MaxPoolOp
 | ConcatOp
 | GatherElementsOp
 | GatherOp
+| GatherNDOp
+| ScatterNDOp
 | TransposeOp
 | ReshapeOp
 | IdentityOp
 | EyeLikeOp
+| TriluOp
 | TileOp
 | DepthToSpaceOp
 | SpaceToDepthOp
@@ -3039,21 +3639,27 @@ class CEmitter:
 | GridSampleOp
 | ReduceOp
 | ArgReduceOp
+| TopKOp
 | ConstantOfShapeOp
 | ShapeOp
 | SizeOp
+| NonZeroOp
 | ExpandOp
 | CumSumOp
 | RangeOp
+| OneHotOp
 | SplitOp
 ],
 *,
 emit_testbench: bool,
 extra_includes: set[str] | None = None,
+needs_weight_loader: bool = False,
 ) -> list[str]:
-includes: set[str] = {"#include <
+includes: set[str] = {"#include <stdint.h>"}
 if emit_testbench:
-includes.
+includes.add("#include <stdio.h>")
+if needs_weight_loader:
+includes.add("#include <stdio.h>")
 if extra_includes:
 includes.update(extra_includes)
 if any(
@@ -3074,7 +3680,9 @@ class CEmitter:
 *constant_of_shape_inputs,
 }
 model_dtypes.update(
-op.dtype
+op.dtype
+for op in resolved_ops
+if not isinstance(op, (ArgReduceOp, TopKOp))
 )
 arg_reduce_dtypes = {
 dtype
@@ -3083,6 +3691,17 @@ class CEmitter:
 for dtype in (op.input_dtype, op.output_dtype)
 }
 model_dtypes.update(arg_reduce_dtypes)
+topk_dtypes = {
+dtype
+for op in resolved_ops
+if isinstance(op, TopKOp)
+for dtype in (
+op.input_dtype,
+op.output_values_dtype,
+op.output_indices_dtype,
+)
+}
+model_dtypes.update(topk_dtypes)
 slice_input_dtypes = {
 dtype
 for op in resolved_ops
@@ -3095,12 +3714,18 @@ class CEmitter:
 )
 if dtype is not None
 }
+trilu_k_dtypes = {
+op.k_input_dtype
+for op in resolved_ops
+if isinstance(op, TriluOp) and op.k_input_dtype is not None
+}
 maxpool_indices_dtypes = {
 op.indices_dtype
 for op in resolved_ops
 if isinstance(op, MaxPoolOp) and op.indices_dtype is not None
 }
 model_dtypes.update(maxpool_indices_dtypes)
+model_dtypes.update(trilu_k_dtypes)
 nll_target_dtypes = {
 op.target_dtype
 for op in resolved_ops
@@ -3124,12 +3749,20 @@ class CEmitter:
 for op in resolved_ops
 ):
 includes.add("#include <stdbool.h>")
+if any(
+isinstance(op, SoftmaxCrossEntropyLossOp)
+and op.ignore_index is not None
+for op in resolved_ops
+):
+includes.add("#include <stdbool.h>")
 if any(
 isinstance(op, UnaryOp)
 and unary_op_symbol(op.function, dtype=op.dtype) in {"llabs", "abs"}
 for op in resolved_ops
 ):
 includes.add("#include <stdlib.h>")
+if any(isinstance(op, PadOp) for op in resolved_ops):
+includes.add("#include <stddef.h>")
 if CEmitter._needs_math(resolved_ops):
 includes.add("#include <math.h>")
 if CEmitter._needs_limits(resolved_ops):
@@ -3140,9 +3773,9 @@ class CEmitter:
 ):
 includes.add("#include <string.h>")
 ordered_includes = (
-"#include <stddef.h>",
-"#include <stdio.h>",
 "#include <stdint.h>",
+"#include <stdio.h>",
+"#include <stddef.h>",
 "#include <stdbool.h>",
 "#include <stdlib.h>",
 "#include <math.h>",
@@ -3152,6 +3785,16 @@ class CEmitter:
 )
 return [include for include in ordered_includes if include in includes]
 
+@staticmethod
+def _emit_index_type_define() -> str:
+return "\n".join(
+(
+"#ifndef idx_t",
+"#define idx_t int32_t",
+"#endif",
+)
+)
+
 @staticmethod
 def _needs_stdint(
 model_dtypes: set[ScalarType],
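The hunks above change how `_collect_includes` assembles headers (candidates are collected into a set and then filtered through a fixed preference order) and add `_emit_index_type_define`, which wraps the `idx_t` macro in an `#ifndef` guard so a build can pre-define a different index type. A minimal sketch of the ordering step, assuming a made-up input set; the tuple mirrors the order visible in this hunk, truncated to the headers shown:

```python
# Sketch of the include-ordering step: headers are gathered as a set and
# emitted in a fixed preference order, so the generated C file is
# deterministic regardless of which ops requested them.
ORDERED_INCLUDES = (
    "#include <stdint.h>",
    "#include <stdio.h>",
    "#include <stddef.h>",
    "#include <stdbool.h>",
    "#include <stdlib.h>",
    "#include <math.h>",
)

def order_includes(collected: set[str]) -> list[str]:
    return [include for include in ORDERED_INCLUDES if include in collected]

# Hypothetical input set, for illustration only:
print(order_includes({"#include <math.h>", "#include <stdint.h>"}))
# ['#include <stdint.h>', '#include <math.h>']
```

Because `idx_t` is only defined when absent, a consumer can pre-define it (for example with a compiler flag) and the emitted default of `int32_t` is skipped.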
@@ -3186,11 +3829,15 @@ class CEmitter:
 | UnaryOp
 | ClipOp
 | CastOp
+| QuantizeLinearOp
 | MatMulOp
+| EinsumOp
 | GemmOp
 | AttentionOp
 | ConvOp
+| ConvTransposeOp
 | AveragePoolOp
+| LpPoolOp
 | BatchNormOp
 | LpNormalizationOp
 | InstanceNormalizationOp
@@ -3202,16 +3849,20 @@ class CEmitter:
 | LstmOp
 | SoftmaxOp
 | LogSoftmaxOp
+| HardmaxOp
 | NegativeLogLikelihoodLossOp
 | SoftmaxCrossEntropyLossOp
 | MaxPoolOp
 | ConcatOp
 | GatherElementsOp
 | GatherOp
+| GatherNDOp
+| ScatterNDOp
 | TransposeOp
 | ReshapeOp
 | IdentityOp
 | EyeLikeOp
+| TriluOp
 | TileOp
 | DepthToSpaceOp
 | SpaceToDepthOp
@@ -3220,12 +3871,15 @@ class CEmitter:
 | GridSampleOp
 | ReduceOp
 | ArgReduceOp
+| TopKOp
 | ConstantOfShapeOp
 | ShapeOp
 | SizeOp
+| NonZeroOp
 | ExpandOp
 | CumSumOp
 | RangeOp
+| OneHotOp
 | SplitOp
 ],
 ) -> bool:
@@ -3322,6 +3976,11 @@ class CEmitter:
 for op in resolved_ops
 ):
 return True
+if any(
+isinstance(op, (LpPoolOp, QuantizeLinearOp))
+for op in resolved_ops
+):
+return True
 return False
 
 @staticmethod
@@ -3331,11 +3990,15 @@ class CEmitter:
 | UnaryOp
 | ClipOp
 | CastOp
+| QuantizeLinearOp
 | MatMulOp
+| EinsumOp
 | GemmOp
 | AttentionOp
 | ConvOp
+| ConvTransposeOp
 | AveragePoolOp
+| LpPoolOp
 | BatchNormOp
 | LpNormalizationOp
 | InstanceNormalizationOp
@@ -3347,16 +4010,19 @@ class CEmitter:
 | LstmOp
 | SoftmaxOp
 | LogSoftmaxOp
+| HardmaxOp
 | NegativeLogLikelihoodLossOp
 | SoftmaxCrossEntropyLossOp
 | MaxPoolOp
 | ConcatOp
 | GatherElementsOp
 | GatherOp
+| GatherNDOp
 | TransposeOp
 | ReshapeOp
 | IdentityOp
 | EyeLikeOp
+| TriluOp
 | TileOp
 | DepthToSpaceOp
 | SpaceToDepthOp
@@ -3365,12 +4031,15 @@ class CEmitter:
 | GridSampleOp
 | ReduceOp
 | ArgReduceOp
+| TopKOp
 | ConstantOfShapeOp
 | ShapeOp
 | SizeOp
+| NonZeroOp
 | ExpandOp
 | CumSumOp
 | RangeOp
+| OneHotOp
 | SplitOp
 ],
 ) -> bool:
@@ -3400,6 +4069,11 @@ class CEmitter:
 for op in resolved_ops
 ):
 return True
+if any(
+isinstance(op, QuantizeLinearOp) and op.dtype.is_integer
+for op in resolved_ops
+):
+return True
 return False
 
 def _emit_model_wrapper(
@@ -3411,11 +4085,15 @@ class CEmitter:
 | UnaryOp
 | ClipOp
 | CastOp
+| QuantizeLinearOp
 | MatMulOp
+| EinsumOp
 | GemmOp
 | AttentionOp
 | ConvOp
+| ConvTransposeOp
 | AveragePoolOp
+| LpPoolOp
 | BatchNormOp
 | LpNormalizationOp
 | InstanceNormalizationOp
@@ -3427,16 +4105,19 @@ class CEmitter:
 | LstmOp
 | SoftmaxOp
 | LogSoftmaxOp
+| HardmaxOp
 | NegativeLogLikelihoodLossOp
 | SoftmaxCrossEntropyLossOp
 | MaxPoolOp
 | ConcatOp
 | GatherElementsOp
 | GatherOp
+| GatherNDOp
 | TransposeOp
 | ReshapeOp
 | IdentityOp
 | EyeLikeOp
+| TriluOp
 | TileOp
 | DepthToSpaceOp
 | SpaceToDepthOp
@@ -3445,12 +4126,15 @@ class CEmitter:
 | GridSampleOp
 | ReduceOp
 | ArgReduceOp
+| TopKOp
 | ConstantOfShapeOp
 | ShapeOp
 | SizeOp
+| NonZeroOp
 | ExpandOp
 | CumSumOp
 | RangeOp
+| OneHotOp
 | SplitOp
 ],
 temp_buffers: tuple[TempBuffer, ...],
@@ -3480,8 +4164,14 @@ class CEmitter:
 lines = [f"void {model.name}({signature}) {{"]
 for temp in temp_buffers:
 c_type = temp.dtype.c_type
+storage = (
+"static "
+if self._temp_buffer_size_bytes(temp)
+> self._large_temp_threshold_bytes
+else ""
+)
 lines.append(
-f" {c_type} {temp.name}{self._array_suffix(temp.shape)};"
+f" {storage}{c_type} {temp.name}{self._array_suffix(temp.shape)};"
 )
 for index, op in enumerate(resolved_ops):
 op_name = self._op_function_name(model, index)
@@ -3490,6 +4180,13 @@ class CEmitter:
 lines.append("}")
 return "\n".join(lines)
 
+@staticmethod
+def _temp_buffer_size_bytes(temp: TempBuffer) -> int:
+element_count = 1
+for dim in temp.shape:
+element_count *= dim
+return element_count * temp.dtype.np_dtype.itemsize
+
 @staticmethod
 def _build_op_call(
 op: BinaryOp
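The wrapper hunks above decide, per temporary buffer, whether to declare it with `static` storage by comparing the new `_temp_buffer_size_bytes` result against `self._large_temp_threshold_bytes`. A standalone sketch of that heuristic with an assumed threshold value; the real threshold constant is defined elsewhere in the emitter and is not part of this diff, and the tensor shape below is hypothetical:

```python
# Sketch of the large-temporary heuristic: element count times dtype item
# size decides whether a scratch buffer is declared `static` (off the
# stack) in the generated wrapper function.
import numpy as np

LARGE_TEMP_THRESHOLD_BYTES = 64 * 1024  # assumed value, for illustration only


def temp_buffer_size_bytes(shape: tuple[int, ...], dtype: np.dtype) -> int:
    element_count = 1
    for dim in shape:
        element_count *= dim
    return element_count * dtype.itemsize


shape = (1, 256, 56, 56)  # hypothetical intermediate tensor
storage = (
    "static "
    if temp_buffer_size_bytes(shape, np.dtype("float32")) > LARGE_TEMP_THRESHOLD_BYTES
    else ""
)
print(f"{storage}float tmp0[1][256][56][56];")
```

Declaring such buffers `static` presumably keeps large scratch tensors off the call stack of the generated wrapper.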
@@ -3497,11 +4194,15 @@ class CEmitter:
|
|
|
3497
4194
|
| UnaryOp
|
|
3498
4195
|
| ClipOp
|
|
3499
4196
|
| CastOp
|
|
4197
|
+
| QuantizeLinearOp
|
|
3500
4198
|
| MatMulOp
|
|
4199
|
+
| EinsumOp
|
|
3501
4200
|
| GemmOp
|
|
3502
4201
|
| AttentionOp
|
|
3503
4202
|
| ConvOp
|
|
4203
|
+
| ConvTransposeOp
|
|
3504
4204
|
| AveragePoolOp
|
|
4205
|
+
| LpPoolOp
|
|
3505
4206
|
| BatchNormOp
|
|
3506
4207
|
| LpNormalizationOp
|
|
3507
4208
|
| InstanceNormalizationOp
|
|
@@ -3513,16 +4214,20 @@ class CEmitter:
|
|
|
3513
4214
|
| LstmOp
|
|
3514
4215
|
| SoftmaxOp
|
|
3515
4216
|
| LogSoftmaxOp
|
|
4217
|
+
| HardmaxOp
|
|
3516
4218
|
| NegativeLogLikelihoodLossOp
|
|
3517
4219
|
| SoftmaxCrossEntropyLossOp
|
|
3518
4220
|
| MaxPoolOp
|
|
3519
4221
|
| ConcatOp
|
|
3520
4222
|
| GatherElementsOp
|
|
3521
4223
|
| GatherOp
|
|
4224
|
+
| GatherNDOp
|
|
4225
|
+
| ScatterNDOp
|
|
3522
4226
|
| TransposeOp
|
|
3523
4227
|
| ReshapeOp
|
|
3524
4228
|
| IdentityOp
|
|
3525
4229
|
| EyeLikeOp
|
|
4230
|
+
| TriluOp
|
|
3526
4231
|
| TileOp
|
|
3527
4232
|
| PadOp
|
|
3528
4233
|
| DepthToSpaceOp
|
|
@@ -3532,12 +4237,15 @@ class CEmitter:
|
|
|
3532
4237
|
| GridSampleOp
|
|
3533
4238
|
| ReduceOp
|
|
3534
4239
|
| ArgReduceOp
|
|
4240
|
+
| TopKOp
|
|
3535
4241
|
| ConstantOfShapeOp
|
|
3536
4242
|
| ShapeOp
|
|
3537
4243
|
| SizeOp
|
|
4244
|
+
| NonZeroOp
|
|
3538
4245
|
| ExpandOp
|
|
3539
4246
|
| CumSumOp
|
|
3540
4247
|
| RangeOp
|
|
4248
|
+
| OneHotOp
|
|
3541
4249
|
| SplitOp,
|
|
3542
4250
|
dim_order: Sequence[str],
|
|
3543
4251
|
) -> str:
|
|
@@ -3556,6 +4264,9 @@ class CEmitter:
|
|
|
3556
4264
|
if isinstance(op, MatMulOp):
|
|
3557
4265
|
args.extend([op.input0, op.input1, op.output])
|
|
3558
4266
|
return ", ".join(args)
|
|
4267
|
+
if isinstance(op, EinsumOp):
|
|
4268
|
+
args.extend([*op.inputs, op.output])
|
|
4269
|
+
return ", ".join(args)
|
|
3559
4270
|
if isinstance(op, GemmOp):
|
|
3560
4271
|
if op.input_c is None:
|
|
3561
4272
|
args.extend([op.input_a, op.input_b, op.output])
|
|
@@ -3574,6 +4285,13 @@ class CEmitter:
|
|
|
3574
4285
|
call_parts.append(op.output)
|
|
3575
4286
|
args.extend(call_parts)
|
|
3576
4287
|
return ", ".join(args)
|
|
4288
|
+
if isinstance(op, QuantizeLinearOp):
|
|
4289
|
+
call_parts = [op.input0, op.scale]
|
|
4290
|
+
if op.zero_point is not None:
|
|
4291
|
+
call_parts.append(op.zero_point)
|
|
4292
|
+
call_parts.append(op.output)
|
|
4293
|
+
args.extend(call_parts)
|
|
4294
|
+
return ", ".join(args)
|
|
3577
4295
|
if isinstance(op, AttentionOp):
|
|
3578
4296
|
call_parts = [op.input_q, op.input_k, op.input_v]
|
|
3579
4297
|
if op.input_attn_mask is not None:
|
|
@@ -3599,9 +4317,18 @@ class CEmitter:
|
|
|
3599
4317
|
return ", ".join(args)
|
|
3600
4318
|
args.extend([op.input0, op.weights, op.bias, op.output])
|
|
3601
4319
|
return ", ".join(args)
|
|
4320
|
+
if isinstance(op, ConvTransposeOp):
|
|
4321
|
+
if op.bias is None:
|
|
4322
|
+
args.extend([op.input0, op.weights, op.output])
|
|
4323
|
+
return ", ".join(args)
|
|
4324
|
+
args.extend([op.input0, op.weights, op.bias, op.output])
|
|
4325
|
+
return ", ".join(args)
|
|
3602
4326
|
if isinstance(op, AveragePoolOp):
|
|
3603
4327
|
args.extend([op.input0, op.output])
|
|
3604
4328
|
return ", ".join(args)
|
|
4329
|
+
if isinstance(op, LpPoolOp):
|
|
4330
|
+
args.extend([op.input0, op.output])
|
|
4331
|
+
return ", ".join(args)
|
|
3605
4332
|
if isinstance(op, BatchNormOp):
|
|
3606
4333
|
args.extend(
|
|
3607
4334
|
[op.input0, op.scale, op.bias, op.mean, op.variance, op.output]
|
|
@@ -3653,7 +4380,7 @@ class CEmitter:
|
|
|
3653
4380
|
call_parts.append(op.output_y_c)
|
|
3654
4381
|
args.extend(call_parts)
|
|
3655
4382
|
return ", ".join(args)
|
|
3656
|
-
if isinstance(op, (SoftmaxOp, LogSoftmaxOp)):
|
|
4383
|
+
if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
|
|
3657
4384
|
args.extend([op.input0, op.output])
|
|
3658
4385
|
return ", ".join(args)
|
|
3659
4386
|
if isinstance(op, NegativeLogLikelihoodLossOp):
|
|
@@ -3684,6 +4411,12 @@ class CEmitter:
|
|
|
3684
4411
|
if isinstance(op, GatherOp):
|
|
3685
4412
|
args.extend([op.data, op.indices, op.output])
|
|
3686
4413
|
return ", ".join(args)
|
|
4414
|
+
if isinstance(op, GatherNDOp):
|
|
4415
|
+
args.extend([op.data, op.indices, op.output])
|
|
4416
|
+
return ", ".join(args)
|
|
4417
|
+
if isinstance(op, ScatterNDOp):
|
|
4418
|
+
args.extend([op.data, op.indices, op.updates, op.output])
|
|
4419
|
+
return ", ".join(args)
|
|
3687
4420
|
if isinstance(op, ConcatOp):
|
|
3688
4421
|
args.extend([*op.inputs, op.output])
|
|
3689
4422
|
return ", ".join(args)
|
|
@@ -3696,9 +4429,18 @@ class CEmitter:
|
|
|
3696
4429
|
if isinstance(op, SizeOp):
|
|
3697
4430
|
args.extend([op.input0, op.output])
|
|
3698
4431
|
return ", ".join(args)
|
|
4432
|
+
if isinstance(op, NonZeroOp):
|
|
4433
|
+
args.extend([op.input0, op.output])
|
|
4434
|
+
return ", ".join(args)
|
|
3699
4435
|
if isinstance(op, ExpandOp):
|
|
3700
4436
|
args.extend([op.input0, op.output])
|
|
3701
4437
|
return ", ".join(args)
|
|
4438
|
+
if isinstance(op, TriluOp):
|
|
4439
|
+
call_parts = [op.input0, op.output]
|
|
4440
|
+
if op.k_input is not None:
|
|
4441
|
+
call_parts.append(op.k_input)
|
|
4442
|
+
args.extend(call_parts)
|
|
4443
|
+
return ", ".join(args)
|
|
3702
4444
|
if isinstance(op, CumSumOp):
|
|
3703
4445
|
args.append(op.input0)
|
|
3704
4446
|
if op.axis_input is not None:
|
|
@@ -3708,6 +4450,9 @@ class CEmitter:
|
|
|
3708
4450
|
if isinstance(op, RangeOp):
|
|
3709
4451
|
args.extend([op.start, op.limit, op.delta, op.output])
|
|
3710
4452
|
return ", ".join(args)
|
|
4453
|
+
if isinstance(op, OneHotOp):
|
|
4454
|
+
args.extend([op.indices, op.depth, op.values, op.output])
|
|
4455
|
+
return ", ".join(args)
|
|
3711
4456
|
if isinstance(op, SplitOp):
|
|
3712
4457
|
args.extend([op.input0, *op.outputs])
|
|
3713
4458
|
return ", ".join(args)
|
|
@@ -3749,6 +4494,12 @@ class CEmitter:
|
|
|
3749
4494
|
call_parts.append(op.output)
|
|
3750
4495
|
args.extend(call_parts)
|
|
3751
4496
|
return ", ".join(args)
|
|
4497
|
+
if isinstance(op, TriluOp):
|
|
4498
|
+
call_parts = [op.input0, op.output]
|
|
4499
|
+
if op.k_input is not None:
|
|
4500
|
+
call_parts.append(op.k_input)
|
|
4501
|
+
args.extend(call_parts)
|
|
4502
|
+
return ", ".join(args)
|
|
3752
4503
|
if isinstance(op, GridSampleOp):
|
|
3753
4504
|
args.extend([op.input0, op.grid, op.output])
|
|
3754
4505
|
return ", ".join(args)
|
|
@@ -3761,6 +4512,9 @@ class CEmitter:
|
|
|
3761
4512
|
if isinstance(op, ArgReduceOp):
|
|
3762
4513
|
args.extend([op.input0, op.output])
|
|
3763
4514
|
return ", ".join(args)
|
|
4515
|
+
if isinstance(op, TopKOp):
|
|
4516
|
+
args.extend([op.input0, op.output_values, op.output_indices])
|
|
4517
|
+
return ", ".join(args)
|
|
3764
4518
|
args.extend([op.input0, op.output])
|
|
3765
4519
|
return ", ".join(args)
|
|
3766
4520
|
|
|
@@ -3792,11 +4546,11 @@ class CEmitter:
|
|
|
3792
4546
|
return {}
|
|
3793
4547
|
if len(intermediates) == 1:
|
|
3794
4548
|
name, shape, dtype = intermediates[0]
|
|
3795
|
-
temp_name = allocate_temp_name("
|
|
4549
|
+
temp_name = allocate_temp_name(f"tmp0_{name}")
|
|
3796
4550
|
return {name: TempBuffer(name=temp_name, shape=shape, dtype=dtype)}
|
|
3797
4551
|
return {
|
|
3798
4552
|
name: TempBuffer(
|
|
3799
|
-
name=allocate_temp_name(f"tmp{index}"),
|
|
4553
|
+
name=allocate_temp_name(f"tmp{index}_{name}"),
|
|
3800
4554
|
shape=shape,
|
|
3801
4555
|
dtype=dtype,
|
|
3802
4556
|
)
|
|
@@ -3811,11 +4565,15 @@ class CEmitter:
|
|
|
3811
4565
|
| UnaryOp
|
|
3812
4566
|
| ClipOp
|
|
3813
4567
|
| CastOp
|
|
4568
|
+
| QuantizeLinearOp
|
|
3814
4569
|
| MatMulOp
|
|
4570
|
+
| EinsumOp
|
|
3815
4571
|
| GemmOp
|
|
3816
4572
|
| AttentionOp
|
|
3817
4573
|
| ConvOp
|
|
4574
|
+
| ConvTransposeOp
|
|
3818
4575
|
| AveragePoolOp
|
|
4576
|
+
| LpPoolOp
|
|
3819
4577
|
| BatchNormOp
|
|
3820
4578
|
| LpNormalizationOp
|
|
3821
4579
|
| InstanceNormalizationOp
|
|
@@ -3827,16 +4585,20 @@ class CEmitter:
|
|
|
3827
4585
|
| LstmOp
|
|
3828
4586
|
| SoftmaxOp
|
|
3829
4587
|
| LogSoftmaxOp
|
|
4588
|
+
| HardmaxOp
|
|
3830
4589
|
| NegativeLogLikelihoodLossOp
|
|
3831
4590
|
| SoftmaxCrossEntropyLossOp
|
|
3832
4591
|
| MaxPoolOp
|
|
3833
4592
|
| ConcatOp
|
|
3834
4593
|
| GatherElementsOp
|
|
3835
4594
|
| GatherOp
|
|
4595
|
+
| GatherNDOp
|
|
4596
|
+
| ScatterNDOp
|
|
3836
4597
|
| TransposeOp
|
|
3837
4598
|
| ReshapeOp
|
|
3838
4599
|
| IdentityOp
|
|
3839
4600
|
| EyeLikeOp
|
|
4601
|
+
| TriluOp
|
|
3840
4602
|
| TileOp
|
|
3841
4603
|
| DepthToSpaceOp
|
|
3842
4604
|
| SpaceToDepthOp
|
|
@@ -3845,12 +4607,15 @@ class CEmitter:
|
|
|
3845
4607
|
| GridSampleOp
|
|
3846
4608
|
| ReduceOp
|
|
3847
4609
|
| ArgReduceOp
|
|
4610
|
+
| TopKOp
|
|
3848
4611
|
| ConstantOfShapeOp
|
|
3849
4612
|
| ShapeOp
|
|
3850
4613
|
| SizeOp
|
|
4614
|
+
| NonZeroOp
|
|
3851
4615
|
| ExpandOp
|
|
3852
4616
|
| CumSumOp
|
|
3853
4617
|
| RangeOp
|
|
4618
|
+
| OneHotOp
|
|
3854
4619
|
| SplitOp,
|
|
3855
4620
|
temp_map: dict[str, str],
|
|
3856
4621
|
) -> (
|
|
@@ -3860,11 +4625,15 @@ class CEmitter:
|
|
|
3860
4625
|
| UnaryOp
|
|
3861
4626
|
| ClipOp
|
|
3862
4627
|
| CastOp
|
|
4628
|
+
| QuantizeLinearOp
|
|
3863
4629
|
| MatMulOp
|
|
4630
|
+
| EinsumOp
|
|
3864
4631
|
| GemmOp
|
|
3865
4632
|
| AttentionOp
|
|
3866
4633
|
| ConvOp
|
|
4634
|
+
| ConvTransposeOp
|
|
3867
4635
|
| AveragePoolOp
|
|
4636
|
+
| LpPoolOp
|
|
3868
4637
|
| BatchNormOp
|
|
3869
4638
|
| LpNormalizationOp
|
|
3870
4639
|
| InstanceNormalizationOp
|
|
@@ -3876,16 +4645,20 @@ class CEmitter:
|
|
|
3876
4645
|
| LstmOp
|
|
3877
4646
|
| SoftmaxOp
|
|
3878
4647
|
| LogSoftmaxOp
|
|
4648
|
+
| HardmaxOp
|
|
3879
4649
|
| NegativeLogLikelihoodLossOp
|
|
3880
4650
|
| SoftmaxCrossEntropyLossOp
|
|
3881
4651
|
| MaxPoolOp
|
|
3882
4652
|
| ConcatOp
|
|
3883
4653
|
| GatherElementsOp
|
|
3884
4654
|
| GatherOp
|
|
4655
|
+
| GatherNDOp
|
|
4656
|
+
| ScatterNDOp
|
|
3885
4657
|
| TransposeOp
|
|
3886
4658
|
| ReshapeOp
|
|
3887
4659
|
| IdentityOp
|
|
3888
4660
|
| EyeLikeOp
|
|
4661
|
+
| TriluOp
|
|
3889
4662
|
| TileOp
|
|
3890
4663
|
| DepthToSpaceOp
|
|
3891
4664
|
| SpaceToDepthOp
|
|
@@ -3894,12 +4667,15 @@ class CEmitter:
|
|
|
3894
4667
|
| GridSampleOp
|
|
3895
4668
|
| ReduceOp
|
|
3896
4669
|
| ArgReduceOp
|
|
4670
|
+
| TopKOp
|
|
3897
4671
|
| ConstantOfShapeOp
|
|
3898
4672
|
| ShapeOp
|
|
3899
4673
|
| SizeOp
|
|
4674
|
+
| NonZeroOp
|
|
3900
4675
|
| ExpandOp
|
|
3901
4676
|
| CumSumOp
|
|
3902
4677
|
| RangeOp
|
|
4678
|
+
| OneHotOp
|
|
3903
4679
|
| SplitOp
|
|
3904
4680
|
):
|
|
3905
4681
|
if isinstance(op, BinaryOp):
|
|
@@ -3909,6 +4685,8 @@ class CEmitter:
|
|
|
3909
4685
|
output=temp_map.get(op.output, op.output),
|
|
3910
4686
|
function=op.function,
|
|
3911
4687
|
operator_kind=op.operator_kind,
|
|
4688
|
+
input0_shape=op.input0_shape,
|
|
4689
|
+
input1_shape=op.input1_shape,
|
|
3912
4690
|
shape=op.shape,
|
|
3913
4691
|
dtype=op.dtype,
|
|
3914
4692
|
input_dtype=op.input_dtype,
|
|
@@ -3979,6 +4757,16 @@ class CEmitter:
|
|
|
3979
4757
|
right_vector=op.right_vector,
|
|
3980
4758
|
dtype=op.dtype,
|
|
3981
4759
|
)
|
|
4760
|
+
if isinstance(op, EinsumOp):
|
|
4761
|
+
return EinsumOp(
|
|
4762
|
+
inputs=tuple(temp_map.get(name, name) for name in op.inputs),
|
|
4763
|
+
output=temp_map.get(op.output, op.output),
|
|
4764
|
+
kind=op.kind,
|
|
4765
|
+
input_shapes=op.input_shapes,
|
|
4766
|
+
output_shape=op.output_shape,
|
|
4767
|
+
dtype=op.dtype,
|
|
4768
|
+
input_dtype=op.input_dtype,
|
|
4769
|
+
)
|
|
3982
4770
|
if isinstance(op, CastOp):
|
|
3983
4771
|
return CastOp(
|
|
3984
4772
|
input0=temp_map.get(op.input0, op.input0),
|
|
@@ -3987,6 +4775,22 @@ class CEmitter:
|
|
|
3987
4775
|
input_dtype=op.input_dtype,
|
|
3988
4776
|
dtype=op.dtype,
|
|
3989
4777
|
)
|
|
4778
|
+
if isinstance(op, QuantizeLinearOp):
|
|
4779
|
+
return QuantizeLinearOp(
|
|
4780
|
+
input0=temp_map.get(op.input0, op.input0),
|
|
4781
|
+
scale=temp_map.get(op.scale, op.scale),
|
|
4782
|
+
zero_point=(
|
|
4783
|
+
temp_map.get(op.zero_point, op.zero_point)
|
|
4784
|
+
if op.zero_point is not None
|
|
4785
|
+
else None
|
|
4786
|
+
),
|
|
4787
|
+
output=temp_map.get(op.output, op.output),
|
|
4788
|
+
input_shape=op.input_shape,
|
|
4789
|
+
axis=op.axis,
|
|
4790
|
+
dtype=op.dtype,
|
|
4791
|
+
input_dtype=op.input_dtype,
|
|
4792
|
+
scale_dtype=op.scale_dtype,
|
|
4793
|
+
)
|
|
3990
4794
|
if isinstance(op, GemmOp):
|
|
3991
4795
|
return GemmOp(
|
|
3992
4796
|
input_a=temp_map.get(op.input_a, op.input_a),
|
|
@@ -4160,6 +4964,26 @@ class CEmitter:
|
|
|
4160
4964
|
group=op.group,
|
|
4161
4965
|
dtype=op.dtype,
|
|
4162
4966
|
)
|
|
4967
|
+
if isinstance(op, ConvTransposeOp):
|
|
4968
|
+
return ConvTransposeOp(
|
|
4969
|
+
input0=temp_map.get(op.input0, op.input0),
|
|
4970
|
+
weights=temp_map.get(op.weights, op.weights),
|
|
4971
|
+
bias=temp_map.get(op.bias, op.bias) if op.bias else None,
|
|
4972
|
+
output=temp_map.get(op.output, op.output),
|
|
4973
|
+
batch=op.batch,
|
|
4974
|
+
in_channels=op.in_channels,
|
|
4975
|
+
out_channels=op.out_channels,
|
|
4976
|
+
spatial_rank=op.spatial_rank,
|
|
4977
|
+
in_spatial=op.in_spatial,
|
|
4978
|
+
out_spatial=op.out_spatial,
|
|
4979
|
+
kernel_shape=op.kernel_shape,
|
|
4980
|
+
strides=op.strides,
|
|
4981
|
+
pads=op.pads,
|
|
4982
|
+
dilations=op.dilations,
|
|
4983
|
+
output_padding=op.output_padding,
|
|
4984
|
+
group=op.group,
|
|
4985
|
+
dtype=op.dtype,
|
|
4986
|
+
)
|
|
4163
4987
|
if isinstance(op, AveragePoolOp):
|
|
4164
4988
|
return AveragePoolOp(
|
|
4165
4989
|
input0=temp_map.get(op.input0, op.input0),
|
|
@@ -4181,6 +5005,27 @@ class CEmitter:
|
|
|
4181
5005
|
count_include_pad=op.count_include_pad,
|
|
4182
5006
|
dtype=op.dtype,
|
|
4183
5007
|
)
|
|
5008
|
+
if isinstance(op, LpPoolOp):
|
|
5009
|
+
return LpPoolOp(
|
|
5010
|
+
input0=temp_map.get(op.input0, op.input0),
|
|
5011
|
+
output=temp_map.get(op.output, op.output),
|
|
5012
|
+
batch=op.batch,
|
|
5013
|
+
channels=op.channels,
|
|
5014
|
+
in_h=op.in_h,
|
|
5015
|
+
in_w=op.in_w,
|
|
5016
|
+
out_h=op.out_h,
|
|
5017
|
+
out_w=op.out_w,
|
|
5018
|
+
kernel_h=op.kernel_h,
|
|
5019
|
+
kernel_w=op.kernel_w,
|
|
5020
|
+
stride_h=op.stride_h,
|
|
5021
|
+
stride_w=op.stride_w,
|
|
5022
|
+
pad_top=op.pad_top,
|
|
5023
|
+
pad_left=op.pad_left,
|
|
5024
|
+
pad_bottom=op.pad_bottom,
|
|
5025
|
+
pad_right=op.pad_right,
|
|
5026
|
+
p=op.p,
|
|
5027
|
+
dtype=op.dtype,
|
|
5028
|
+
)
|
|
4184
5029
|
if isinstance(op, BatchNormOp):
|
|
4185
5030
|
return BatchNormOp(
|
|
4186
5031
|
input0=temp_map.get(op.input0, op.input0),
|
|
@@ -4318,6 +5163,17 @@ class CEmitter:
|
|
|
4318
5163
|
shape=op.shape,
|
|
4319
5164
|
dtype=op.dtype,
|
|
4320
5165
|
)
|
|
5166
|
+
if isinstance(op, HardmaxOp):
|
|
5167
|
+
return HardmaxOp(
|
|
5168
|
+
input0=temp_map.get(op.input0, op.input0),
|
|
5169
|
+
output=temp_map.get(op.output, op.output),
|
|
5170
|
+
outer=op.outer,
|
|
5171
|
+
axis_size=op.axis_size,
|
|
5172
|
+
inner=op.inner,
|
|
5173
|
+
axis=op.axis,
|
|
5174
|
+
shape=op.shape,
|
|
5175
|
+
dtype=op.dtype,
|
|
5176
|
+
)
|
|
4321
5177
|
if isinstance(op, NegativeLogLikelihoodLossOp):
|
|
4322
5178
|
return NegativeLogLikelihoodLossOp(
|
|
4323
5179
|
input0=temp_map.get(op.input0, op.input0),
|
|
@@ -4419,6 +5275,32 @@ class CEmitter:
|
|
|
4419
5275
|
dtype=op.dtype,
|
|
4420
5276
|
indices_dtype=op.indices_dtype,
|
|
4421
5277
|
)
|
|
5278
|
+
if isinstance(op, GatherNDOp):
|
|
5279
|
+
return GatherNDOp(
|
|
5280
|
+
data=temp_map.get(op.data, op.data),
|
|
5281
|
+
indices=temp_map.get(op.indices, op.indices),
|
|
5282
|
+
output=temp_map.get(op.output, op.output),
|
|
5283
|
+
batch_dims=op.batch_dims,
|
|
5284
|
+
data_shape=op.data_shape,
|
|
5285
|
+
indices_shape=op.indices_shape,
|
|
5286
|
+
output_shape=op.output_shape,
|
|
5287
|
+
dtype=op.dtype,
|
|
5288
|
+
indices_dtype=op.indices_dtype,
|
|
5289
|
+
)
|
|
5290
|
+
if isinstance(op, ScatterNDOp):
|
|
5291
|
+
return ScatterNDOp(
|
|
5292
|
+
data=temp_map.get(op.data, op.data),
|
|
5293
|
+
indices=temp_map.get(op.indices, op.indices),
|
|
5294
|
+
updates=temp_map.get(op.updates, op.updates),
|
|
5295
|
+
output=temp_map.get(op.output, op.output),
|
|
5296
|
+
data_shape=op.data_shape,
|
|
5297
|
+
indices_shape=op.indices_shape,
|
|
5298
|
+
updates_shape=op.updates_shape,
|
|
5299
|
+
output_shape=op.output_shape,
|
|
5300
|
+
reduction=op.reduction,
|
|
5301
|
+
dtype=op.dtype,
|
|
5302
|
+
indices_dtype=op.indices_dtype,
|
|
5303
|
+
)
|
|
4422
5304
|
if isinstance(op, ConcatOp):
|
|
4423
5305
|
return ConcatOp(
|
|
4424
5306
|
inputs=tuple(temp_map.get(name, name) for name in op.inputs),
|
|
@@ -4458,6 +5340,15 @@ class CEmitter:
|
|
|
4458
5340
|
dtype=op.dtype,
|
|
4459
5341
|
input_dtype=op.input_dtype,
|
|
4460
5342
|
)
|
|
5343
|
+
if isinstance(op, NonZeroOp):
|
|
5344
|
+
return NonZeroOp(
|
|
5345
|
+
input0=temp_map.get(op.input0, op.input0),
|
|
5346
|
+
output=temp_map.get(op.output, op.output),
|
|
5347
|
+
input_shape=op.input_shape,
|
|
5348
|
+
output_shape=op.output_shape,
|
|
5349
|
+
dtype=op.dtype,
|
|
5350
|
+
input_dtype=op.input_dtype,
|
|
5351
|
+
)
|
|
4461
5352
|
if isinstance(op, ExpandOp):
|
|
4462
5353
|
return ExpandOp(
|
|
4463
5354
|
input0=temp_map.get(op.input0, op.input0),
|
|
@@ -4493,6 +5384,21 @@ class CEmitter:
|
|
|
4493
5384
|
dtype=op.dtype,
|
|
4494
5385
|
input_dtype=op.input_dtype,
|
|
4495
5386
|
)
|
|
5387
|
+
if isinstance(op, OneHotOp):
|
|
5388
|
+
return OneHotOp(
|
|
5389
|
+
indices=temp_map.get(op.indices, op.indices),
|
|
5390
|
+
depth=temp_map.get(op.depth, op.depth),
|
|
5391
|
+
values=temp_map.get(op.values, op.values),
|
|
5392
|
+
output=temp_map.get(op.output, op.output),
|
|
5393
|
+
axis=op.axis,
|
|
5394
|
+
indices_shape=op.indices_shape,
|
|
5395
|
+
values_shape=op.values_shape,
|
|
5396
|
+
output_shape=op.output_shape,
|
|
5397
|
+
depth_dim=op.depth_dim,
|
|
5398
|
+
dtype=op.dtype,
|
|
5399
|
+
indices_dtype=op.indices_dtype,
|
|
5400
|
+
depth_dtype=op.depth_dtype,
|
|
5401
|
+
)
|
|
4496
5402
|
if isinstance(op, SplitOp):
|
|
4497
5403
|
return SplitOp(
|
|
4498
5404
|
input0=temp_map.get(op.input0, op.input0),
|
|
@@ -4542,6 +5448,24 @@ class CEmitter:
|
|
|
4542
5448
|
dtype=op.dtype,
|
|
4543
5449
|
input_dtype=op.input_dtype,
|
|
4544
5450
|
)
|
|
5451
|
+
if isinstance(op, TriluOp):
|
|
5452
|
+
return TriluOp(
|
|
5453
|
+
input0=temp_map.get(op.input0, op.input0),
|
|
5454
|
+
output=temp_map.get(op.output, op.output),
|
|
5455
|
+
input_shape=op.input_shape,
|
|
5456
|
+
output_shape=op.output_shape,
|
|
5457
|
+
upper=op.upper,
|
|
5458
|
+
k_value=op.k_value,
|
|
5459
|
+
k_input=(
|
|
5460
|
+
temp_map.get(op.k_input, op.k_input)
|
|
5461
|
+
if op.k_input is not None
|
|
5462
|
+
else None
|
|
5463
|
+
),
|
|
5464
|
+
k_input_shape=op.k_input_shape,
|
|
5465
|
+
k_input_dtype=op.k_input_dtype,
|
|
5466
|
+
dtype=op.dtype,
|
|
5467
|
+
input_dtype=op.input_dtype,
|
|
5468
|
+
)
|
|
4545
5469
|
if isinstance(op, TileOp):
|
|
4546
5470
|
return TileOp(
|
|
4547
5471
|
input0=temp_map.get(op.input0, op.input0),
|
|
@@ -4726,6 +5650,21 @@ class CEmitter:
|
|
|
4726
5650
|
input_dtype=op.input_dtype,
|
|
4727
5651
|
output_dtype=op.output_dtype,
|
|
4728
5652
|
)
|
|
5653
|
+
if isinstance(op, TopKOp):
|
|
5654
|
+
return TopKOp(
|
|
5655
|
+
input0=temp_map.get(op.input0, op.input0),
|
|
5656
|
+
output_values=temp_map.get(op.output_values, op.output_values),
|
|
5657
|
+
output_indices=temp_map.get(op.output_indices, op.output_indices),
|
|
5658
|
+
input_shape=op.input_shape,
|
|
5659
|
+
output_shape=op.output_shape,
|
|
5660
|
+
axis=op.axis,
|
|
5661
|
+
k=op.k,
|
|
5662
|
+
largest=op.largest,
|
|
5663
|
+
sorted=op.sorted,
|
|
5664
|
+
input_dtype=op.input_dtype,
|
|
5665
|
+
output_values_dtype=op.output_values_dtype,
|
|
5666
|
+
output_indices_dtype=op.output_indices_dtype,
|
|
5667
|
+
)
|
|
4729
5668
|
return UnaryOp(
|
|
4730
5669
|
input0=temp_map.get(op.input0, op.input0),
|
|
4731
5670
|
output=temp_map.get(op.output, op.output),
|
|
@@ -4743,11 +5682,15 @@ class CEmitter:
|
|
|
4743
5682
|
| UnaryOp
|
|
4744
5683
|
| ClipOp
|
|
4745
5684
|
| CastOp
|
|
5685
|
+
| QuantizeLinearOp
|
|
4746
5686
|
| MatMulOp
|
|
5687
|
+
| EinsumOp
|
|
4747
5688
|
| GemmOp
|
|
4748
5689
|
| AttentionOp
|
|
4749
5690
|
| ConvOp
|
|
5691
|
+
| ConvTransposeOp
|
|
4750
5692
|
| AveragePoolOp
|
|
5693
|
+
| LpPoolOp
|
|
4751
5694
|
| BatchNormOp
|
|
4752
5695
|
| LpNormalizationOp
|
|
4753
5696
|
| InstanceNormalizationOp
|
|
@@ -4759,16 +5702,20 @@ class CEmitter:
|
|
|
4759
5702
|
| LstmOp
|
|
4760
5703
|
| SoftmaxOp
|
|
4761
5704
|
| LogSoftmaxOp
|
|
5705
|
+
| HardmaxOp
|
|
4762
5706
|
| NegativeLogLikelihoodLossOp
|
|
4763
5707
|
| SoftmaxCrossEntropyLossOp
|
|
4764
5708
|
| MaxPoolOp
|
|
4765
5709
|
| ConcatOp
|
|
4766
5710
|
| GatherElementsOp
|
|
4767
5711
|
| GatherOp
|
|
5712
|
+
| GatherNDOp
|
|
5713
|
+
| ScatterNDOp
|
|
4768
5714
|
| TransposeOp
|
|
4769
5715
|
| ReshapeOp
|
|
4770
5716
|
| IdentityOp
|
|
4771
5717
|
| EyeLikeOp
|
|
5718
|
+
| TriluOp
|
|
4772
5719
|
| TileOp
|
|
4773
5720
|
| DepthToSpaceOp
|
|
4774
5721
|
| SpaceToDepthOp
|
|
@@ -4777,12 +5724,15 @@ class CEmitter:
|
|
|
4777
5724
|
| GridSampleOp
|
|
4778
5725
|
| ReduceOp
|
|
4779
5726
|
| ArgReduceOp
|
|
5727
|
+
| TopKOp
|
|
4780
5728
|
| ConstantOfShapeOp
|
|
4781
5729
|
| ShapeOp
|
|
4782
5730
|
| SizeOp
|
|
5731
|
+
| NonZeroOp
|
|
4783
5732
|
| ExpandOp
|
|
4784
5733
|
| CumSumOp
|
|
4785
5734
|
| RangeOp
|
|
5735
|
+
| OneHotOp
|
|
4786
5736
|
| SplitOp,
|
|
4787
5737
|
index: int,
|
|
4788
5738
|
*,
|
|
@@ -4798,11 +5748,15 @@ class CEmitter:
|
|
|
4798
5748
|
unary_template,
|
|
4799
5749
|
clip_template,
|
|
4800
5750
|
cast_template,
|
|
5751
|
+
quantize_linear_template,
|
|
4801
5752
|
matmul_template,
|
|
5753
|
+
einsum_template,
|
|
4802
5754
|
gemm_template,
|
|
4803
5755
|
attention_template,
|
|
4804
5756
|
conv_template,
|
|
5757
|
+
conv_transpose_template,
|
|
4805
5758
|
avg_pool_template,
|
|
5759
|
+
lp_pool_template,
|
|
4806
5760
|
batch_norm_template,
|
|
4807
5761
|
lp_norm_template,
|
|
4808
5762
|
instance_norm_template,
|
|
@@ -4814,16 +5768,20 @@ class CEmitter:
|
|
|
4814
5768
|
lstm_template,
|
|
4815
5769
|
softmax_template,
|
|
4816
5770
|
logsoftmax_template,
|
|
5771
|
+
hardmax_template,
|
|
4817
5772
|
nllloss_template,
|
|
4818
5773
|
softmax_cross_entropy_loss_template,
|
|
4819
5774
|
maxpool_template,
|
|
4820
5775
|
concat_template,
|
|
4821
5776
|
gather_elements_template,
|
|
4822
5777
|
gather_template,
|
|
5778
|
+
gather_nd_template,
|
|
5779
|
+
scatter_nd_template,
|
|
4823
5780
|
transpose_template,
|
|
4824
5781
|
reshape_template,
|
|
4825
5782
|
identity_template,
|
|
4826
5783
|
eye_like_template,
|
|
5784
|
+
trilu_template,
|
|
4827
5785
|
tile_template,
|
|
4828
5786
|
pad_template,
|
|
4829
5787
|
depth_to_space_template,
|
|
@@ -4835,12 +5793,15 @@ class CEmitter:
|
|
|
4835
5793
|
reduce_template,
|
|
4836
5794
|
reduce_dynamic_template,
|
|
4837
5795
|
arg_reduce_template,
|
|
5796
|
+
topk_template,
|
|
4838
5797
|
constant_of_shape_template,
|
|
4839
5798
|
shape_template,
|
|
4840
5799
|
size_template,
|
|
5800
|
+
nonzero_template,
|
|
4841
5801
|
expand_template,
|
|
4842
5802
|
cumsum_template,
|
|
4843
5803
|
range_template,
|
|
5804
|
+
one_hot_template,
|
|
4844
5805
|
split_template,
|
|
4845
5806
|
scalar_registry: ScalarFunctionRegistry | None = None,
|
|
4846
5807
|
dim_args: str = "",
|
|
@@ -4885,21 +5846,27 @@ class CEmitter:
|
|
|
4885
5846
|
output_dim_names = _dim_names_for(op.output)
|
|
4886
5847
|
shape = CEmitter._shape_dim_exprs(op.shape, output_dim_names)
|
|
4887
5848
|
loop_vars = CEmitter._loop_vars(op.shape)
|
|
4888
|
-
|
|
5849
|
+
output_suffix = self._param_array_suffix(op.shape, output_dim_names)
|
|
5850
|
+
input0_suffix = self._param_array_suffix(
|
|
5851
|
+
op.input0_shape, _dim_names_for(op.input0)
|
|
5852
|
+
)
|
|
5853
|
+
input1_suffix = self._param_array_suffix(
|
|
5854
|
+
op.input1_shape, _dim_names_for(op.input1)
|
|
5855
|
+
)
|
|
4889
5856
|
input_c_type = op.input_dtype.c_type
|
|
4890
5857
|
output_c_type = op.dtype.c_type
|
|
4891
5858
|
param_decls = self._build_param_decls(
|
|
4892
5859
|
[
|
|
4893
|
-
(params["input0"], input_c_type,
|
|
4894
|
-
(params["input1"], input_c_type,
|
|
4895
|
-
(params["output"], output_c_type,
|
|
5860
|
+
(params["input0"], input_c_type, input0_suffix, True),
|
|
5861
|
+
(params["input1"], input_c_type, input1_suffix, True),
|
|
5862
|
+
(params["output"], output_c_type, output_suffix, False),
|
|
4896
5863
|
]
|
|
4897
5864
|
)
|
|
4898
5865
|
common = {
|
|
4899
5866
|
"model_name": model.name,
|
|
4900
5867
|
"op_name": op_name,
|
|
4901
5868
|
"element_count": CEmitter._element_count_expr(shape),
|
|
4902
|
-
"array_suffix":
|
|
5869
|
+
"array_suffix": output_suffix,
|
|
4903
5870
|
"shape": shape,
|
|
4904
5871
|
"loop_vars": loop_vars,
|
|
4905
5872
|
"input_c_type": input_c_type,
|
|
@@ -4908,11 +5875,17 @@ class CEmitter:
|
|
|
4908
5875
|
"dim_args": dim_args,
|
|
4909
5876
|
"params": param_decls,
|
|
4910
5877
|
}
|
|
4911
|
-
left_expr =
|
|
4912
|
-
|
|
5878
|
+
left_expr = CEmitter._broadcast_index_expr(
|
|
5879
|
+
params["input0"],
|
|
5880
|
+
op.input0_shape,
|
|
5881
|
+
op.shape,
|
|
5882
|
+
loop_vars,
|
|
4913
5883
|
)
|
|
4914
|
-
right_expr =
|
|
4915
|
-
|
|
5884
|
+
right_expr = CEmitter._broadcast_index_expr(
|
|
5885
|
+
params["input1"],
|
|
5886
|
+
op.input1_shape,
|
|
5887
|
+
op.shape,
|
|
5888
|
+
loop_vars,
|
|
4916
5889
|
)
|
|
4917
5890
|
operator_expr = None
|
|
4918
5891
|
operator = op_spec.operator
|
|
@@ -5177,39 +6150,170 @@ class CEmitter:
|
|
|
5177
6150
|
k=op.k,
|
|
5178
6151
|
).rstrip()
|
|
5179
6152
|
return with_node_comment(rendered)
|
|
5180
|
-
if isinstance(op,
|
|
6153
|
+
if isinstance(op, EinsumOp):
|
|
5181
6154
|
params = self._shared_param_map(
|
|
5182
6155
|
[
|
|
5183
|
-
(
|
|
5184
|
-
|
|
5185
|
-
|
|
6156
|
+
*(
|
|
6157
|
+
(f"input{idx}", name)
|
|
6158
|
+
for idx, name in enumerate(op.inputs)
|
|
6159
|
+
),
|
|
5186
6160
|
("output", op.output),
|
|
5187
6161
|
]
|
|
5188
6162
|
)
|
|
5189
|
-
|
|
5190
|
-
|
|
5191
|
-
|
|
5192
|
-
input_b_suffix = self._param_array_suffix(input_b_shape)
|
|
5193
|
-
output_suffix = self._param_array_suffix((op.m, op.n))
|
|
5194
|
-
c_suffix = (
|
|
5195
|
-
self._param_array_suffix(op.c_shape)
|
|
5196
|
-
if op.c_shape is not None
|
|
5197
|
-
else ""
|
|
6163
|
+
output_dim_names = _dim_names_for(op.output)
|
|
6164
|
+
output_shape = CEmitter._shape_dim_exprs(
|
|
6165
|
+
op.output_shape, output_dim_names
|
|
5198
6166
|
)
|
|
6167
|
+
output_loop_vars = CEmitter._loop_vars(op.output_shape)
|
|
6168
|
+
if output_loop_vars:
|
|
6169
|
+
output_expr = f"{params['output']}" + "".join(
|
|
6170
|
+
f"[{var}]" for var in output_loop_vars
|
|
6171
|
+
)
|
|
6172
|
+
else:
|
|
6173
|
+
output_expr = f"{params['output']}[0]"
|
|
6174
|
+
input_shapes = op.input_shapes
|
|
6175
|
+
input_dim_names = [
|
|
6176
|
+
_dim_names_for(name) for name in op.inputs
|
|
6177
|
+
]
|
|
6178
|
+
input_suffixes = [
|
|
6179
|
+
self._param_array_suffix(shape, dim_names)
|
|
6180
|
+
for shape, dim_names in zip(input_shapes, input_dim_names)
|
|
6181
|
+
]
|
|
5199
6182
|
param_decls = self._build_param_decls(
|
|
5200
6183
|
[
|
|
5201
|
-
(
|
|
5202
|
-
|
|
6184
|
+
*(
|
|
6185
|
+
(
|
|
6186
|
+
params[f"input{idx}"],
|
|
6187
|
+
op.input_dtype.c_type,
|
|
6188
|
+
input_suffixes[idx],
|
|
6189
|
+
True,
|
|
6190
|
+
)
|
|
6191
|
+
for idx in range(len(op.inputs))
|
|
6192
|
+
),
|
|
5203
6193
|
(
|
|
5204
|
-
params["
|
|
5205
|
-
c_type,
|
|
5206
|
-
|
|
5207
|
-
|
|
5208
|
-
)
|
|
5209
|
-
|
|
5210
|
-
|
|
5211
|
-
|
|
5212
|
-
|
|
6194
|
+
params["output"],
|
|
6195
|
+
op.dtype.c_type,
|
|
6196
|
+
self._param_array_suffix(op.output_shape, output_dim_names),
|
|
6197
|
+
False,
|
|
6198
|
+
),
|
|
6199
|
+
]
|
|
6200
|
+
)
|
|
6201
|
+
input_loop_vars: tuple[str, ...] = ()
|
|
6202
|
+
input_loop_bounds: tuple[str | int, ...] = ()
|
|
6203
|
+
reduce_loop_var = "k"
|
|
6204
|
+
reduce_loop_bound: str | int | None = None
|
|
6205
|
+
input_expr = None
|
|
6206
|
+
input0_expr = None
|
|
6207
|
+
input1_expr = None
|
|
6208
|
+
if op.kind == EinsumKind.REDUCE_ALL:
|
|
6209
|
+
input_loop_vars = CEmitter._loop_vars(input_shapes[0])
|
|
6210
|
+
input_loop_bounds = tuple(
|
|
6211
|
+
CEmitter._shape_dim_exprs(
|
|
6212
|
+
input_shapes[0], input_dim_names[0]
|
|
6213
|
+
)
|
|
6214
|
+
)
|
|
6215
|
+
if input_loop_vars:
|
|
6216
|
+
input_expr = f"{params['input0']}" + "".join(
|
|
6217
|
+
f"[{var}]" for var in input_loop_vars
|
|
6218
|
+
)
|
|
6219
|
+
else:
|
|
6220
|
+
input_expr = f"{params['input0']}[0]"
|
|
6221
|
+
elif op.kind == EinsumKind.SUM_J:
|
|
6222
|
+
input_shape_exprs = CEmitter._shape_dim_exprs(
|
|
6223
|
+
input_shapes[0], input_dim_names[0]
|
|
6224
|
+
)
|
|
6225
|
+
reduce_loop_bound = input_shape_exprs[1]
|
|
6226
|
+
input_expr = (
|
|
6227
|
+
f"{params['input0']}"
|
|
6228
|
+
f"[{output_loop_vars[0]}][{reduce_loop_var}]"
|
|
6229
|
+
)
|
|
6230
|
+
elif op.kind == EinsumKind.TRANSPOSE:
|
|
6231
|
+
input_expr = (
|
|
6232
|
+
f"{params['input0']}"
|
|
6233
|
+
f"[{output_loop_vars[1]}][{output_loop_vars[0]}]"
|
|
6234
|
+
)
|
|
6235
|
+
elif op.kind == EinsumKind.DOT:
|
|
6236
|
+
input_shape_exprs = CEmitter._shape_dim_exprs(
|
|
6237
|
+
input_shapes[0], input_dim_names[0]
|
|
6238
|
+
)
|
|
6239
|
+
reduce_loop_bound = input_shape_exprs[0]
|
|
6240
|
+
input0_expr = f"{params['input0']}[{reduce_loop_var}]"
|
|
6241
|
+
input1_expr = f"{params['input1']}[{reduce_loop_var}]"
|
|
6242
|
+
elif op.kind == EinsumKind.BATCH_MATMUL:
|
|
6243
|
+
input_shape_exprs = CEmitter._shape_dim_exprs(
|
|
6244
|
+
input_shapes[0], input_dim_names[0]
|
|
6245
|
+
)
|
|
6246
|
+
reduce_loop_bound = input_shape_exprs[2]
|
|
6247
|
+
input0_expr = (
|
|
6248
|
+
f"{params['input0']}"
|
|
6249
|
+
f"[{output_loop_vars[0]}]"
|
|
6250
|
+
f"[{output_loop_vars[1]}][{reduce_loop_var}]"
|
|
6251
|
+
)
|
|
6252
|
+
input1_expr = (
|
|
6253
|
+
f"{params['input1']}"
|
|
6254
|
+
f"[{output_loop_vars[0]}]"
|
|
6255
|
+
f"[{reduce_loop_var}][{output_loop_vars[2]}]"
|
|
6256
|
+
)
|
|
6257
|
+
elif op.kind == EinsumKind.BATCH_DIAGONAL:
|
|
6258
|
+
diag_var = output_loop_vars[-1]
|
|
6259
|
+
prefix_vars = output_loop_vars[:-1]
|
|
6260
|
+
input_expr = f"{params['input0']}" + "".join(
|
|
6261
|
+
f"[{var}]" for var in prefix_vars
|
|
6262
|
+
)
|
|
6263
|
+
input_expr += f"[{diag_var}][{diag_var}]"
|
|
6264
|
+
rendered = einsum_template.render(
|
|
6265
|
+
model_name=model.name,
|
|
6266
|
+
op_name=op_name,
|
|
6267
|
+
params=param_decls,
|
|
6268
|
+
dim_args=dim_args,
|
|
6269
|
+
kind=op.kind.value,
|
|
6270
|
+
output_loop_vars=output_loop_vars,
|
|
6271
|
+
output_loop_bounds=output_shape,
|
|
6272
|
+
output_expr=output_expr,
|
|
6273
|
+
acc_type=op.dtype.c_type,
|
|
6274
|
+
zero_literal=zero_literal,
|
|
6275
|
+
input_loop_vars=input_loop_vars,
|
|
6276
|
+
input_loop_bounds=input_loop_bounds,
|
|
6277
|
+
reduce_loop_var=reduce_loop_var,
|
|
6278
|
+
reduce_loop_bound=reduce_loop_bound,
|
|
6279
|
+
input_expr=input_expr,
|
|
6280
|
+
input0_expr=input0_expr,
|
|
6281
|
+
input1_expr=input1_expr,
|
|
6282
|
+
).rstrip()
|
|
6283
|
+
return with_node_comment(rendered)
|
|
6284
|
+
if isinstance(op, GemmOp):
|
|
6285
|
+
params = self._shared_param_map(
|
|
6286
|
+
[
|
|
6287
|
+
("input_a", op.input_a),
|
|
6288
|
+
("input_b", op.input_b),
|
|
6289
|
+
("input_c", op.input_c),
|
|
6290
|
+
("output", op.output),
|
|
6291
|
+
]
|
|
6292
|
+
)
|
|
6293
|
+
input_a_shape = (op.k, op.m) if op.trans_a else (op.m, op.k)
|
|
6294
|
+
input_b_shape = (op.n, op.k) if op.trans_b else (op.k, op.n)
|
|
6295
|
+
input_a_suffix = self._param_array_suffix(input_a_shape)
|
|
6296
|
+
input_b_suffix = self._param_array_suffix(input_b_shape)
|
|
6297
|
+
output_suffix = self._param_array_suffix((op.m, op.n))
|
|
6298
|
+
c_suffix = (
|
|
6299
|
+
self._param_array_suffix(op.c_shape)
|
|
6300
|
+
if op.c_shape is not None
|
|
6301
|
+
else ""
|
|
6302
|
+
)
|
|
6303
|
+
param_decls = self._build_param_decls(
|
|
6304
|
+
[
|
|
6305
|
+
(params["input_a"], c_type, input_a_suffix, True),
|
|
6306
|
+
(params["input_b"], c_type, input_b_suffix, True),
|
|
6307
|
+
(
|
|
6308
|
+
params["input_c"],
|
|
6309
|
+
c_type,
|
|
6310
|
+
c_suffix,
|
|
6311
|
+
True,
|
|
6312
|
+
)
|
|
6313
|
+
if params["input_c"]
|
|
6314
|
+
else (None, "", "", True),
|
|
6315
|
+
(params["output"], c_type, output_suffix, False),
|
|
6316
|
+
]
|
|
5213
6317
|
)
|
|
5214
6318
|
alpha_literal = CEmitter._format_literal(op.dtype, op.alpha)
|
|
5215
6319
|
beta_literal = CEmitter._format_literal(op.dtype, op.beta)
|
|
@@ -5556,6 +6660,81 @@ class CEmitter:
|
|
|
5556
6660
|
in_indices=in_indices,
|
|
5557
6661
|
).rstrip()
|
|
5558
6662
|
return with_node_comment(rendered)
|
|
6663
|
+
if isinstance(op, ConvTransposeOp):
|
|
6664
|
+
params = self._shared_param_map(
|
|
6665
|
+
[
|
|
6666
|
+
("input0", op.input0),
|
|
6667
|
+
("weights", op.weights),
|
|
6668
|
+
("bias", op.bias),
|
|
6669
|
+
("output", op.output),
|
|
6670
|
+
]
|
|
6671
|
+
)
|
|
6672
|
+
input_shape = (op.batch, op.in_channels, *op.in_spatial)
|
|
6673
|
+
weight_shape = (
|
|
6674
|
+
op.in_channels,
|
|
6675
|
+
op.out_channels // op.group,
|
|
6676
|
+
*op.kernel_shape,
|
|
6677
|
+
)
|
|
6678
|
+
output_shape = (op.batch, op.out_channels, *op.out_spatial)
|
|
6679
|
+
in_indices = tuple(f"id{dim}" for dim in range(op.spatial_rank))
|
|
6680
|
+
kernel_indices = tuple(
|
|
6681
|
+
f"kd{dim}" for dim in range(op.spatial_rank)
|
|
6682
|
+
)
|
|
6683
|
+
out_indices = tuple(f"od{dim}" for dim in range(op.spatial_rank))
|
|
6684
|
+
pad_begin = op.pads[: op.spatial_rank]
|
|
6685
|
+
group_in_channels = op.in_channels // op.group
|
|
6686
|
+
group_out_channels = op.out_channels // op.group
|
|
6687
|
+
input_suffix = self._param_array_suffix(input_shape)
|
|
6688
|
+
weight_suffix = self._param_array_suffix(weight_shape)
|
|
6689
|
+
bias_suffix = self._param_array_suffix((op.out_channels,))
|
|
6690
|
+
output_suffix = self._param_array_suffix(output_shape)
|
|
6691
|
+
param_decls = self._build_param_decls(
|
|
6692
|
+
[
|
|
6693
|
+
(params["input0"], c_type, input_suffix, True),
|
|
6694
|
+
(params["weights"], c_type, weight_suffix, True),
|
|
6695
|
+
(
|
|
6696
|
+
params["bias"],
|
|
6697
|
+
c_type,
|
|
6698
|
+
bias_suffix,
|
|
6699
|
+
True,
|
|
6700
|
+
)
|
|
6701
|
+
if params["bias"]
|
|
6702
|
+
else (None, "", "", True),
|
|
6703
|
+
(params["output"], c_type, output_suffix, False),
|
|
6704
|
+
]
|
|
6705
|
+
)
|
|
6706
|
+
rendered = conv_transpose_template.render(
|
|
6707
|
+
model_name=model.name,
|
|
6708
|
+
op_name=op_name,
|
|
6709
|
+
input0=params["input0"],
|
|
6710
|
+
weights=params["weights"],
|
|
6711
|
+
bias=params["bias"],
|
|
6712
|
+
output=params["output"],
|
|
6713
|
+
params=param_decls,
|
|
6714
|
+
c_type=c_type,
|
|
6715
|
+
zero_literal=zero_literal,
|
|
6716
|
+
input_suffix=input_suffix,
|
|
6717
|
+
weight_suffix=weight_suffix,
|
|
6718
|
+
bias_suffix=bias_suffix,
|
|
6719
|
+
output_suffix=output_suffix,
|
|
6720
|
+
batch=op.batch,
|
|
6721
|
+
in_channels=op.in_channels,
|
|
6722
|
+
out_channels=op.out_channels,
|
|
6723
|
+
spatial_rank=op.spatial_rank,
|
|
6724
|
+
in_spatial=op.in_spatial,
|
|
6725
|
+
out_spatial=op.out_spatial,
|
|
6726
|
+
kernel_shape=op.kernel_shape,
|
|
6727
|
+
strides=op.strides,
|
|
6728
|
+
pads_begin=pad_begin,
|
|
6729
|
+
dilations=op.dilations,
|
|
6730
|
+
group=op.group,
|
|
6731
|
+
group_in_channels=group_in_channels,
|
|
6732
|
+
group_out_channels=group_out_channels,
|
|
6733
|
+
in_indices=in_indices,
|
|
6734
|
+
kernel_indices=kernel_indices,
|
|
6735
|
+
out_indices=out_indices,
|
|
6736
|
+
).rstrip()
|
|
6737
|
+
return with_node_comment(rendered)
|
|
5559
6738
|
if isinstance(op, AveragePoolOp):
|
|
5560
6739
|
params = self._shared_param_map(
|
|
5561
6740
|
[("input0", op.input0), ("output", op.output)]
|
|
@@ -5597,6 +6776,49 @@ class CEmitter:
|
|
|
5597
6776
|
count_include_pad=int(op.count_include_pad),
|
|
5598
6777
|
).rstrip()
|
|
5599
6778
|
return with_node_comment(rendered)
|
|
6779
|
+
if isinstance(op, LpPoolOp):
|
|
6780
|
+
params = self._shared_param_map(
|
|
6781
|
+
[("input0", op.input0), ("output", op.output)]
|
|
6782
|
+
)
|
|
6783
|
+
input_shape = (op.batch, op.channels, op.in_h, op.in_w)
|
|
6784
|
+
output_shape = (op.batch, op.channels, op.out_h, op.out_w)
|
|
6785
|
+
input_suffix = self._param_array_suffix(input_shape)
|
|
6786
|
+
output_suffix = self._param_array_suffix(output_shape)
|
|
6787
|
+
param_decls = self._build_param_decls(
|
|
6788
|
+
[
|
|
6789
|
+
(params["input0"], c_type, input_suffix, True),
|
|
6790
|
+
(params["output"], c_type, output_suffix, False),
|
|
6791
|
+
]
|
|
6792
|
+
)
|
|
6793
|
+
rendered = lp_pool_template.render(
|
|
6794
|
+
model_name=model.name,
|
|
6795
|
+
op_name=op_name,
|
|
6796
|
+
input0=params["input0"],
|
|
6797
|
+
output=params["output"],
|
|
6798
|
+
params=param_decls,
|
|
6799
|
+
c_type=c_type,
|
|
6800
|
+
input_suffix=input_suffix,
|
|
6801
|
+
output_suffix=output_suffix,
|
|
6802
|
+
batch=op.batch,
|
|
6803
|
+
channels=op.channels,
|
|
6804
|
+
in_h=op.in_h,
|
|
6805
|
+
in_w=op.in_w,
|
|
6806
|
+
out_h=op.out_h,
|
|
6807
|
+
out_w=op.out_w,
|
|
6808
|
+
kernel_h=op.kernel_h,
|
|
6809
|
+
kernel_w=op.kernel_w,
|
|
6810
|
+
stride_h=op.stride_h,
|
|
6811
|
+
stride_w=op.stride_w,
|
|
6812
|
+
pad_top=op.pad_top,
|
|
6813
|
+
pad_left=op.pad_left,
|
|
6814
|
+
pad_bottom=op.pad_bottom,
|
|
6815
|
+
pad_right=op.pad_right,
|
|
6816
|
+
p=op.p,
|
|
6817
|
+
zero_literal=zero_literal,
|
|
6818
|
+
abs_fn=CEmitter._math_fn(op.dtype, "fabsf", "fabs"),
|
|
6819
|
+
pow_fn=CEmitter._math_fn(op.dtype, "powf", "pow"),
|
|
6820
|
+
).rstrip()
|
|
6821
|
+
return with_node_comment(rendered)
|
|
5600
6822
|
if isinstance(op, BatchNormOp):
|
|
5601
6823
|
params = self._shared_param_map(
|
|
5602
6824
|
[
|
|
@@ -5769,6 +6991,19 @@ class CEmitter:
|
|
|
5769
6991
|
).rstrip()
|
|
5770
6992
|
return with_node_comment(rendered)
|
|
5771
6993
|
if isinstance(op, LayerNormalizationOp):
|
|
6994
|
+
acc_dtype = (
|
|
6995
|
+
ScalarType.F32
|
|
6996
|
+
if op.dtype in {ScalarType.F16, ScalarType.F32}
|
|
6997
|
+
else op.dtype
|
|
6998
|
+
)
|
|
6999
|
+
acc_type = acc_dtype.c_type
|
|
7000
|
+
acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
|
|
7001
|
+
acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
|
|
7002
|
+
acc_epsilon_literal = CEmitter._format_floating(
|
|
7003
|
+
op.epsilon, acc_dtype
|
|
7004
|
+
)
|
|
7005
|
+
acc_sqrt_fn = CEmitter._math_fn(acc_dtype, "sqrtf", "sqrt")
|
|
7006
|
+
use_kahan = op.dtype in {ScalarType.F16, ScalarType.F32}
|
|
5772
7007
|
params = self._shared_param_map(
|
|
5773
7008
|
[
|
|
5774
7009
|
("input0", op.input0),
|
|
@@ -5878,8 +7113,12 @@ class CEmitter:
|
|
|
5878
7113
|
bias_index_vars=bias_index_vars,
|
|
5879
7114
|
mean_index_vars=mean_index_vars,
|
|
5880
7115
|
inner=op.inner,
|
|
5881
|
-
|
|
5882
|
-
|
|
7116
|
+
acc_type=acc_type,
|
|
7117
|
+
acc_zero_literal=acc_zero_literal,
|
|
7118
|
+
acc_one_literal=acc_one_literal,
|
|
7119
|
+
acc_epsilon_literal=acc_epsilon_literal,
|
|
7120
|
+
acc_sqrt_fn=acc_sqrt_fn,
|
|
7121
|
+
use_kahan=use_kahan,
|
|
5883
7122
|
).rstrip()
|
|
5884
7123
|
return with_node_comment(rendered)
|
|
5885
7124
|
if isinstance(op, MeanVarianceNormalizationOp):
|
|
@@ -6244,7 +7483,41 @@ class CEmitter:
|
|
|
6244
7483
|
log_fn=CEmitter._math_fn(op.dtype, "logf", "log"),
|
|
6245
7484
|
).rstrip()
|
|
6246
7485
|
return with_node_comment(rendered)
|
|
7486
|
+
if isinstance(op, HardmaxOp):
|
|
7487
|
+
params = self._shared_param_map(
|
|
7488
|
+
[("input0", op.input0), ("output", op.output)]
|
|
7489
|
+
)
|
|
7490
|
+
array_suffix = self._param_array_suffix(op.shape)
|
|
7491
|
+
param_decls = self._build_param_decls(
|
|
7492
|
+
[
|
|
7493
|
+
(params["input0"], c_type, array_suffix, True),
|
|
7494
|
+
(params["output"], c_type, array_suffix, False),
|
|
7495
|
+
]
|
|
7496
|
+
)
|
|
7497
|
+
rendered = hardmax_template.render(
|
|
7498
|
+
model_name=model.name,
|
|
7499
|
+
op_name=op_name,
|
|
7500
|
+
input0=params["input0"],
|
|
7501
|
+
output=params["output"],
|
|
7502
|
+
params=param_decls,
|
|
7503
|
+
c_type=c_type,
|
|
7504
|
+
array_suffix=array_suffix,
|
|
7505
|
+
outer=op.outer,
|
|
7506
|
+
axis_size=op.axis_size,
|
|
7507
|
+
inner=op.inner,
|
|
7508
|
+
zero_literal=zero_literal,
|
|
7509
|
+
one_literal=CEmitter._format_literal(op.dtype, 1),
|
|
7510
|
+
).rstrip()
|
|
7511
|
+
return with_node_comment(rendered)
|
|
6247
7512
|
if isinstance(op, NegativeLogLikelihoodLossOp):
|
|
7513
|
+
acc_dtype = (
|
|
7514
|
+
ScalarType.F64
|
|
7515
|
+
if op.dtype in {ScalarType.F16, ScalarType.F32}
|
|
7516
|
+
else op.dtype
|
|
7517
|
+
)
|
|
7518
|
+
acc_type = acc_dtype.c_type
|
|
7519
|
+
acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
|
|
7520
|
+
acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
|
|
6248
7521
|
params = self._shared_param_map(
|
|
6249
7522
|
[
|
|
6250
7523
|
("input0", op.input0),
|
|
@@ -6292,9 +7565,22 @@ class CEmitter:
|
|
|
6292
7565
|
ignore_index=op.ignore_index,
|
|
6293
7566
|
zero_literal=zero_literal,
|
|
6294
7567
|
one_literal=CEmitter._format_literal(op.dtype, 1),
|
|
7568
|
+
acc_type=acc_type,
|
|
7569
|
+
acc_zero_literal=acc_zero_literal,
|
|
7570
|
+
acc_one_literal=acc_one_literal,
|
|
6295
7571
|
).rstrip()
|
|
6296
7572
|
return with_node_comment(rendered)
|
|
6297
7573
|
if isinstance(op, SoftmaxCrossEntropyLossOp):
|
|
7574
|
+
acc_dtype = (
|
|
7575
|
+
ScalarType.F64
|
|
7576
|
+
if op.dtype in {ScalarType.F16, ScalarType.F32}
|
|
7577
|
+
else op.dtype
|
|
7578
|
+
)
|
|
7579
|
+
acc_type = acc_dtype.c_type
|
|
7580
|
+
acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
|
|
7581
|
+
acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
|
|
7582
|
+
acc_exp_fn = CEmitter._math_fn(acc_dtype, "expf", "exp")
|
|
7583
|
+
acc_log_fn = CEmitter._math_fn(acc_dtype, "logf", "log")
|
|
6298
7584
|
params = self._shared_param_map(
|
|
6299
7585
|
[
|
|
6300
7586
|
("input0", op.input0),
|
|
@@ -6361,8 +7647,11 @@ class CEmitter:
|
|
|
6361
7647
|
ignore_index=ignore_index,
|
|
6362
7648
|
zero_literal=zero_literal,
|
|
6363
7649
|
one_literal=CEmitter._format_literal(op.dtype, 1),
|
|
6364
|
-
|
|
6365
|
-
|
|
7650
|
+
acc_type=acc_type,
|
|
7651
|
+
acc_zero_literal=acc_zero_literal,
|
|
7652
|
+
acc_one_literal=acc_one_literal,
|
|
7653
|
+
acc_exp_fn=acc_exp_fn,
|
|
7654
|
+
acc_log_fn=acc_log_fn,
|
|
6366
7655
|
).rstrip()
|
|
6367
7656
|
return with_node_comment(rendered)
|
|
6368
7657
|
if isinstance(op, MaxPoolOp):
|
|
@@ -6569,6 +7858,180 @@ class CEmitter:
                 axis_dim=op.data_shape[op.axis],
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, GatherNDOp):
+            params = self._shared_param_map(
+                [
+                    ("data", op.data),
+                    ("indices", op.indices),
+                    ("output", op.output),
+                ]
+            )
+            indices_dim_names = _dim_names_for(op.indices)
+            data_dim_names = _dim_names_for(op.data)
+            data_shape = CEmitter._shape_dim_exprs(op.data_shape, data_dim_names)
+            indices_shape = CEmitter._shape_dim_exprs(
+                op.indices_shape, indices_dim_names
+            )
+            indices_prefix_shape = indices_shape[:-1]
+            indices_prefix_loop_vars = (
+                CEmitter._loop_vars(op.indices_shape[:-1])
+                if op.indices_shape[:-1]
+                else ()
+            )
+            index_depth = op.indices_shape[-1]
+            tail_shape = data_shape[op.batch_dims + index_depth :]
+            tail_loop_vars = (
+                tuple(f"t{index}" for index in range(len(tail_shape)))
+                if tail_shape
+                else ()
+            )
+            output_loop_vars = (*indices_prefix_loop_vars, *tail_loop_vars)
+            if output_loop_vars:
+                output_index_expr = params["output"] + "".join(
+                    f"[{var}]" for var in output_loop_vars
+                )
+            else:
+                output_index_expr = f"{params['output']}[0]"
+            data_index_vars = (
+                *indices_prefix_loop_vars[: op.batch_dims],
+                *tuple(f"index{idx}" for idx in range(index_depth)),
+                *tail_loop_vars,
+            )
+            data_index_expr = params["data"] + "".join(
+                f"[{var}]" for var in data_index_vars
+            )
+            data_suffix = self._param_array_suffix(op.data_shape)
+            indices_suffix = self._param_array_suffix(op.indices_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["data"], c_type, data_suffix, True),
+                    (
+                        params["indices"],
+                        op.indices_dtype.c_type,
+                        indices_suffix,
+                        True,
+                    ),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = gather_nd_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                data=params["data"],
+                indices=params["indices"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                data_suffix=data_suffix,
+                indices_suffix=indices_suffix,
+                output_suffix=output_suffix,
+                indices_prefix_shape=indices_prefix_shape,
+                indices_prefix_loop_vars=indices_prefix_loop_vars,
+                index_depth=index_depth,
+                tail_shape=tail_shape,
+                tail_loop_vars=tail_loop_vars,
+                output_index_expr=output_index_expr,
+                data_index_expr=data_index_expr,
+                batch_dims=op.batch_dims,
+                data_shape=data_shape,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, ScatterNDOp):
+            params = self._shared_param_map(
+                [
+                    ("data", op.data),
+                    ("indices", op.indices),
+                    ("updates", op.updates),
+                    ("output", op.output),
+                ]
+            )
+            output_dim_names = _dim_names_for(op.output)
+            indices_dim_names = _dim_names_for(op.indices)
+            updates_dim_names = _dim_names_for(op.updates)
+            data_dim_names = _dim_names_for(op.data)
+            output_shape = CEmitter._shape_dim_exprs(
+                op.output_shape, output_dim_names
+            )
+            data_shape = CEmitter._shape_dim_exprs(op.data_shape, data_dim_names)
+            indices_shape = CEmitter._shape_dim_exprs(
+                op.indices_shape, indices_dim_names
+            )
+            output_loop_vars = CEmitter._loop_vars(op.output_shape)
+            indices_prefix_shape = indices_shape[:-1]
+            indices_prefix_loop_vars = (
+                CEmitter._loop_vars(op.indices_shape[:-1])
+                if op.indices_shape[:-1]
+                else ()
+            )
+            index_depth = op.indices_shape[-1]
+            tail_shape = output_shape[index_depth:]
+            tail_loop_vars = (
+                tuple(
+                    f"t{index}"
+                    for index in range(len(op.output_shape[index_depth:]))
+                )
+                if op.output_shape[index_depth:]
+                else ()
+            )
+            index_vars = tuple(f"index{idx}" for idx in range(index_depth))
+            output_index_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in (*index_vars, *tail_loop_vars)
+            )
+            updates_index_vars = (*indices_prefix_loop_vars, *tail_loop_vars)
+            if not op.updates_shape:
+                updates_index_expr = f"{params['updates']}[0]"
+            else:
+                updates_index_expr = f"{params['updates']}" + "".join(
+                    f"[{var}]" for var in updates_index_vars
+                )
+            data_suffix = self._param_array_suffix(
+                op.data_shape, data_dim_names
+            )
+            indices_suffix = self._param_array_suffix(
+                op.indices_shape, indices_dim_names
+            )
+            updates_suffix = self._param_array_suffix(
+                op.updates_shape, updates_dim_names
+            )
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["data"], c_type, data_suffix, True),
+                    (
+                        params["indices"],
+                        op.indices_dtype.c_type,
+                        indices_suffix,
+                        True,
+                    ),
+                    (params["updates"], c_type, updates_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = scatter_nd_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                data=params["data"],
+                indices=params["indices"],
+                updates=params["updates"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                output_shape=output_shape,
+                output_loop_vars=output_loop_vars,
+                indices_prefix_shape=indices_prefix_shape,
+                indices_prefix_loop_vars=indices_prefix_loop_vars,
+                index_depth=index_depth,
+                data_shape=data_shape,
+                tail_shape=tail_shape,
+                tail_loop_vars=tail_loop_vars,
+                output_index_expr=output_index_expr,
+                updates_index_expr=updates_index_expr,
+                reduction=op.reduction,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, TransposeOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
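Note on the new GatherND lowering above: the emitted C loops walk the indices prefix, read an index tuple from the last axis of `indices`, and copy the addressed slice of `data`. A minimal NumPy sketch of that semantics (illustrative only, not code from this package; `batch_dims` follows the ONNX GatherND definition):

import numpy as np

def gather_nd(data, indices, batch_dims=0):
    # Each entry along the last axis of `indices` is one coordinate into `data`.
    data = np.asarray(data)
    indices = np.asarray(indices)
    prefix = indices.shape[:-1]
    out = []
    for idx in np.ndindex(*prefix):
        coord = tuple(idx[:batch_dims]) + tuple(indices[idx])
        out.append(data[coord])
    tail = data.shape[batch_dims + indices.shape[-1]:]
    return np.asarray(out).reshape(prefix + tail)

# gather_nd([[0, 1], [2, 3]], [[0, 0], [1, 1]]) -> array([0, 3])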
@@ -6608,6 +8071,7 @@ class CEmitter:
                 [("input0", op.input0), ("output", op.output)]
             )
             input_suffix = self._param_array_suffix(op.input_shape)
+            output_shape = CEmitter._codegen_shape(op.output_shape)
             output_suffix = self._param_array_suffix(op.output_shape)
             param_decls = self._build_param_decls(
                 [
@@ -6615,6 +8079,7 @@ class CEmitter:
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
+            loop_vars = CEmitter._loop_vars(op.output_shape)
             rendered = reshape_template.render(
                 model_name=model.name,
                 op_name=op_name,
@@ -6625,6 +8090,8 @@ class CEmitter:
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
                 element_count=CEmitter._element_count(op.output_shape),
+                output_shape=output_shape,
+                loop_vars=loop_vars,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, IdentityOp):
@@ -6691,6 +8158,61 @@ class CEmitter:
                 one_literal=f"(({c_type})1)",
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, TriluOp):
+            param_specs = [("input0", op.input0), ("output", op.output)]
+            if op.k_input is not None:
+                param_specs.append(("k_input", op.k_input))
+            params = self._shared_param_map(param_specs)
+            output_dim_names = _dim_names_for(op.output)
+            shape = CEmitter._shape_dim_exprs(op.output_shape, output_dim_names)
+            output_suffix = self._param_array_suffix(op.output_shape, output_dim_names)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.input0)
+            )
+            k_suffix = ""
+            if op.k_input is not None and op.k_input_shape is not None:
+                k_suffix = self._param_array_suffix(
+                    op.k_input_shape, _dim_names_for(op.k_input)
+                )
+            batch_dims = op.output_shape[:-2]
+            batch_size = CEmitter._element_count(batch_dims or (1,))
+            param_decls = [
+                (params["input0"], c_type, input_suffix, True),
+                (params["output"], c_type, output_suffix, False),
+            ]
+            if op.k_input is not None and op.k_input_dtype is not None:
+                param_decls.append(
+                    (
+                        params["k_input"],
+                        op.k_input_dtype.c_type,
+                        k_suffix,
+                        True,
+                    )
+                )
+            rendered = trilu_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                k_input=params.get("k_input"),
+                params=self._build_param_decls(param_decls),
+                c_type=c_type,
+                k_c_type=(
+                    op.k_input_dtype.c_type
+                    if op.k_input_dtype is not None
+                    else ScalarType.I64.c_type
+                ),
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                shape=shape,
+                batch_size=batch_size,
+                rows=op.output_shape[-2],
+                cols=op.output_shape[-1],
+                k_value=op.k_value,
+                upper=op.upper,
+                zero_literal=zero_literal,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, TileOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
@@ -7224,17 +8746,19 @@ class CEmitter:
             update_expr = None
             init_literal = None
             final_expr = "acc"
+            use_kahan = False
+            kahan_value_expr = None
             fabs_fn = CEmitter._math_fn(op.dtype, "fabsf", "fabs")
             exp_fn = CEmitter._math_fn(op.dtype, "expf", "exp")
             log_fn = CEmitter._math_fn(op.dtype, "logf", "log")
             sqrt_fn = CEmitter._math_fn(op.dtype, "sqrtf", "sqrt")
-            count_literal = CEmitter._format_literal(
-                op.dtype, op.reduce_count
-            )
             if op.reduce_kind == "sum":
                 init_literal = zero_literal
                 update_expr = f"acc += {value_expr};"
             elif op.reduce_kind == "mean":
+                count_literal = CEmitter._format_literal(
+                    op.dtype, op.reduce_count
+                )
                 init_literal = zero_literal
                 update_expr = f"acc += {value_expr};"
                 final_expr = f"acc / {count_literal}"
@@ -7269,6 +8793,24 @@ class CEmitter:
                 raise CodegenError(
                     f"Unsupported reduce kind {op.reduce_kind}"
                 )
+            if op.dtype in {ScalarType.F16, ScalarType.F32} and op.reduce_kind in {
+                "sum",
+                "mean",
+                "logsum",
+                "logsumexp",
+                "l1",
+                "l2",
+                "sumsquare",
+            }:
+                use_kahan = True
+                if op.reduce_kind == "logsumexp":
+                    kahan_value_expr = f"{exp_fn}({value_expr})"
+                elif op.reduce_kind == "l1":
+                    kahan_value_expr = f"{fabs_fn}({value_expr})"
+                elif op.reduce_kind in {"l2", "sumsquare"}:
+                    kahan_value_expr = f"{value_expr} * {value_expr}"
+                else:
+                    kahan_value_expr = value_expr
             input_suffix = self._param_array_suffix(op.input_shape)
             output_suffix = self._param_array_suffix(op.output_shape)
             param_decls = self._build_param_decls(
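The `use_kahan` / `kahan_value_expr` flags introduced above select compensated (Kahan) summation for float16/float32 reductions. A minimal Python sketch of the accumulation pattern the reduce template is expected to emit in C (illustrative only):

def kahan_sum(values):
    acc = 0.0    # running sum
    comp = 0.0   # compensation term holding the low-order bits lost so far
    for value in values:
        y = value - comp
        t = acc + y
        comp = (t - acc) - y   # recover what was rounded away in acc + y
        acc = t
    return acc

# Compensated summation keeps the result close to the exact sum even when
# the addends differ by many orders of magnitude, e.g. [1e8, 1.0, -1e8].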
@@ -7292,8 +8834,11 @@ class CEmitter:
                 reduce_dims=reduce_dims,
                 output_index_expr=output_index_expr,
                 init_literal=init_literal,
+                zero_literal=zero_literal,
                 update_expr=update_expr,
                 final_expr=final_expr,
+                use_kahan=use_kahan,
+                kahan_value_expr=kahan_value_expr,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ArgReduceOp):
@@ -7367,6 +8912,83 @@ class CEmitter:
                 dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, TopKOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("output_values", op.output_values),
+                    ("output_indices", op.output_indices),
+                ]
+            )
+            output_shape = CEmitter._codegen_shape(op.output_shape)
+            outer_shape = tuple(
+                dim for axis, dim in enumerate(output_shape) if axis != op.axis
+            )
+            outer_loop_vars = CEmitter._loop_vars(outer_shape)
+            reduce_var = "r0"
+            k_var = "k0"
+            input_indices: list[str] = []
+            output_indices: list[str] = []
+            outer_index = 0
+            for axis in range(len(op.input_shape)):
+                if axis == op.axis:
+                    input_indices.append(reduce_var)
+                    output_indices.append(k_var)
+                else:
+                    input_indices.append(outer_loop_vars[outer_index])
+                    output_indices.append(outer_loop_vars[outer_index])
+                    outer_index += 1
+            input_index_expr = "".join(f"[{var}]" for var in input_indices)
+            output_index_expr = "".join(f"[{var}]" for var in output_indices)
+            compare_expr = (
+                "(a > b) || ((a == b) && (ai < bi))"
+                if op.largest
+                else "(a < b) || ((a == b) && (ai < bi))"
+            )
+            input_suffix = self._param_array_suffix(op.input_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (
+                        params["output_values"],
+                        op.output_values_dtype.c_type,
+                        output_suffix,
+                        False,
+                    ),
+                    (
+                        params["output_indices"],
+                        op.output_indices_dtype.c_type,
+                        output_suffix,
+                        False,
+                    ),
+                ]
+            )
+            rendered = topk_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output_values=params["output_values"],
+                output_indices=params["output_indices"],
+                params=param_decls,
+                input_c_type=op.input_dtype.c_type,
+                output_values_c_type=op.output_values_dtype.c_type,
+                output_indices_c_type=op.output_indices_dtype.c_type,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                output_shape=output_shape,
+                outer_shape=outer_shape,
+                outer_loop_vars=outer_loop_vars,
+                reduce_var=reduce_var,
+                k_var=k_var,
+                axis_dim=op.input_shape[op.axis],
+                k=op.k,
+                input_index_expr=input_index_expr,
+                output_index_expr=output_index_expr,
+                compare_expr=compare_expr,
+                dim_args=dim_args,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, ReduceOp):
             name_params = self._shared_param_map(
                 [
@@ -7545,35 +9167,76 @@ class CEmitter:
                 c_type=c_type,
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
-                values=[
-                    CEmitter._format_literal(op.dtype, value)
-                    for value in op.values
-                ],
+                values=[
+                    CEmitter._format_literal(op.dtype, value)
+                    for value in op.values
+                ],
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, SizeOp):
+            params = self._shared_param_map(
+                [("input0", op.input0), ("output", op.output)]
+            )
+            input_suffix = self._param_array_suffix(op.input_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = size_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                params=param_decls,
+                input_c_type=op.input_dtype.c_type,
+                c_type=c_type,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                value=CEmitter._format_literal(op.dtype, op.value),
             ).rstrip()
             return with_node_comment(rendered)
-        if isinstance(op,
+        if isinstance(op, NonZeroOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-
-
+            input_dim_names = _dim_names_for(op.input0)
+            output_dim_names = _dim_names_for(op.output)
+            input_shape = CEmitter._shape_dim_exprs(
+                op.input_shape, input_dim_names
+            )
+            loop_vars = CEmitter._loop_vars(op.input_shape)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, input_dim_names
+            )
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], op.input_dtype.c_type, input_suffix, True),
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
-
+            input_expr = f"{params['input0']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            rendered = nonzero_template.render(
                 model_name=model.name,
                 op_name=op_name,
                 input0=params["input0"],
                 output=params["output"],
                 params=param_decls,
                 input_c_type=op.input_dtype.c_type,
-
+                output_c_type=c_type,
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
-
+                input_shape=input_shape,
+                loop_vars=loop_vars,
+                input_expr=input_expr,
+                zero_literal=op.input_dtype.zero_literal,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ExpandOp):
@@ -7692,6 +9355,74 @@ class CEmitter:
                 length=op.length,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, OneHotOp):
+            params = self._shared_param_map(
+                [
+                    ("indices", op.indices),
+                    ("depth", op.depth),
+                    ("values", op.values),
+                    ("output", op.output),
+                ]
+            )
+            output_dim_names = _dim_names_for(op.output)
+            indices_dim_names = _dim_names_for(op.indices)
+            values_dim_names = _dim_names_for(op.values)
+            output_shape = CEmitter._codegen_shape(op.output_shape)
+            loop_vars = CEmitter._loop_vars(output_shape)
+            indices_indices = tuple(
+                var for idx, var in enumerate(loop_vars) if idx != op.axis
+            )
+            if not indices_indices:
+                indices_indices = ("0",)
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
+            indices_suffix = self._param_array_suffix(
+                op.indices_shape, indices_dim_names
+            )
+            values_suffix = self._param_array_suffix(
+                op.values_shape, values_dim_names
+            )
+            depth_suffix = self._param_array_suffix(())
+            param_decls = self._build_param_decls(
+                [
+                    (
+                        params["indices"],
+                        op.indices_dtype.c_type,
+                        indices_suffix,
+                        True,
+                    ),
+                    (
+                        params["depth"],
+                        op.depth_dtype.c_type,
+                        depth_suffix,
+                        True,
+                    ),
+                    (params["values"], c_type, values_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = one_hot_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                indices=params["indices"],
+                depth=params["depth"],
+                values=params["values"],
+                output=params["output"],
+                params=param_decls,
+                indices_suffix=indices_suffix,
+                depth_suffix=depth_suffix,
+                values_suffix=values_suffix,
+                output_suffix=output_suffix,
+                output_shape=output_shape,
+                loop_vars=loop_vars,
+                indices_indices=indices_indices,
+                axis_index=loop_vars[op.axis],
+                depth_dim=op.depth_dim,
+                indices_c_type=op.indices_dtype.c_type,
+                c_type=c_type,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, SplitOp):
             output_params = [
                 (f"output_{index}", name)
@@ -7772,6 +9503,86 @@ class CEmitter:
                 dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, QuantizeLinearOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("scale", op.scale),
+                    ("zero_point", op.zero_point),
+                    ("output", op.output),
+                ]
+            )
+            output_dim_names = _dim_names_for(op.output)
+            shape = CEmitter._shape_dim_exprs(op.input_shape, output_dim_names)
+            loop_vars = CEmitter._loop_vars(op.input_shape)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.input0)
+            )
+            scale_shape = (
+                ()
+                if op.axis is None
+                else (op.input_shape[op.axis],)
+            )
+            scale_suffix = self._param_array_suffix(
+                scale_shape, _dim_names_for(op.scale)
+            )
+            zero_point_suffix = self._param_array_suffix(
+                scale_shape, _dim_names_for(op.zero_point or "")
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
+                    (
+                        params["zero_point"],
+                        op.dtype.c_type,
+                        zero_point_suffix,
+                        True,
+                    )
+                    if params["zero_point"]
+                    else (None, "", "", True),
+                    (params["output"], op.dtype.c_type, input_suffix, False),
+                ]
+            )
+            compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
+            round_fn = CEmitter._math_fn(
+                op.input_dtype, "nearbyintf", "nearbyint"
+            )
+            scale_index = "0" if op.axis is None else loop_vars[op.axis]
+            input_expr = f"{params['input0']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            output_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            scale_expr = f"{params['scale']}[{scale_index}]"
+            if params["zero_point"]:
+                zero_expr = f"{params['zero_point']}[{scale_index}]"
+            else:
+                zero_expr = "0"
+            rendered = quantize_linear_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                scale=params["scale"],
+                zero_point=params["zero_point"],
+                output=params["output"],
+                params=param_decls,
+                compute_type=compute_type,
+                input_c_type=op.input_dtype.c_type,
+                output_c_type=op.dtype.c_type,
+                shape=shape,
+                loop_vars=loop_vars,
+                input_expr=input_expr,
+                scale_expr=scale_expr,
+                zero_expr=zero_expr,
+                output_expr=output_expr,
+                round_fn=round_fn,
+                min_literal=op.dtype.min_literal,
+                max_literal=op.dtype.max_literal,
+                dim_args=dim_args,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, ClipOp):
             params = self._shared_param_map(
                 [
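The QuantizeLinear lowering above rounds with nearbyint() and clamps to the output type's range, with per-axis scale/zero-point when `axis` is set. A reference sketch of that math (illustrative only, not code from this package; `lo`/`hi` stand in for the min/max literals of the target integer type):

import numpy as np

def quantize_linear(x, scale, zero_point=0, axis=None, lo=-128, hi=127):
    x = np.asarray(x, dtype=np.float32)
    if axis is not None:
        # Broadcast per-axis scale/zero_point along the quantization axis.
        shape = [1] * x.ndim
        shape[axis] = -1
        scale = np.reshape(scale, shape)
        zero_point = np.reshape(zero_point, shape)
    q = np.rint(x / scale) + zero_point   # round-half-to-even, like nearbyint()
    return np.clip(q, lo, hi).astype(np.int64)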
@@ -7934,11 +9745,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -7950,16 +9765,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -7968,16 +9787,20 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> str:
         if isinstance(op, SplitOp):
             return op.outputs[0]
+        if isinstance(op, TopKOp):
+            return op.output_values
         return op.output

     @staticmethod
@@ -7988,11 +9811,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8004,16 +9831,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -8022,18 +9853,28 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> tuple[tuple[str, tuple[int, ...]], ...]:
         if isinstance(op, BinaryOp):
-            return (
+            return (
+                (op.input0, op.input0_shape),
+                (op.input1, op.input1_shape),
+            )
         if isinstance(op, MultiInputBinaryOp):
             return tuple((name, op.shape) for name in op.inputs)
+        if isinstance(op, EinsumOp):
+            return tuple(
+                (name, shape)
+                for name, shape in zip(op.inputs, op.input_shapes)
+            )
         if isinstance(op, UnaryOp):
             return ((op.input0, op.shape),)
         if isinstance(op, LpNormalizationOp):
@@ -8068,10 +9909,27 @@ class CEmitter:
             return tuple(inputs)
         if isinstance(op, CastOp):
             return ((op.input0, op.shape),)
+        if isinstance(op, NonZeroOp):
+            return ((op.input0, op.input_shape),)
+        if isinstance(op, QuantizeLinearOp):
+            scale_shape = (
+                ()
+                if op.axis is None
+                else (op.input_shape[op.axis],)
+            )
+            inputs = [(op.input0, op.input_shape), (op.scale, scale_shape)]
+            if op.zero_point is not None:
+                inputs.append((op.zero_point, scale_shape))
+            return tuple(inputs)
         if isinstance(op, IdentityOp):
             return ((op.input0, op.shape),)
         if isinstance(op, EyeLikeOp):
             return ((op.input0, op.output_shape),)
+        if isinstance(op, TriluOp):
+            inputs = [(op.input0, op.input_shape)]
+            if op.k_input is not None and op.k_input_shape is not None:
+                inputs.append((op.k_input, op.k_input_shape))
+            return tuple(inputs)
         if isinstance(op, GridSampleOp):
             return ((op.input0, op.input_shape), (op.grid, op.grid_shape))
         if isinstance(op, PadOp):
@@ -8083,8 +9941,22 @@ class CEmitter:
             if op.value_input is not None and op.value_shape is not None:
                 inputs.append((op.value_input, op.value_shape))
             return tuple(inputs)
+        if isinstance(op, ScatterNDOp):
+            return ((op.data, op.data_shape),)
         if isinstance(op, CumSumOp):
             return ((op.input0, op.input_shape),)
+        if isinstance(op, RangeOp):
+            return ((op.start, ()), (op.limit, ()), (op.delta, ()))
+        if isinstance(op, OneHotOp):
+            return (
+                (op.indices, op.indices_shape),
+                (op.depth, ()),
+                (op.values, op.values_shape),
+            )
+        if isinstance(op, SplitOp):
+            return ((op.input0, op.input_shape),)
+        if isinstance(op, TopKOp):
+            return ((op.input0, op.input_shape),)
         return ()

     def _propagate_tensor_dim_names(
@@ -8096,11 +9968,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8112,16 +9988,19 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -8130,11 +10009,14 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         ],
         tensor_dim_names: dict[str, dict[int, str]],
@@ -8157,11 +10039,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8173,16 +10059,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -8191,11 +10081,14 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> tuple[tuple[str, tuple[int, ...], str], ...]:
         if isinstance(op, AttentionOp):
@@ -8292,6 +10185,19 @@ class CEmitter:
         )
         if isinstance(op, ArgReduceOp):
             return ((op.output, CEmitter._op_output_shape(op), op.output_dtype),)
+        if isinstance(op, TopKOp):
+            return (
+                (
+                    op.output_values,
+                    CEmitter._op_output_shape(op),
+                    op.output_values_dtype,
+                ),
+                (
+                    op.output_indices,
+                    CEmitter._op_output_shape(op),
+                    op.output_indices_dtype,
+                ),
+            )
         return ((op.output, CEmitter._op_output_shape(op), op.dtype),)

     @staticmethod
@@ -8303,6 +10209,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
@@ -8318,25 +10225,34 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
+        | IdentityOp
+        | EyeLikeOp
+        | TriluOp
+        | TileOp
         | SliceOp
         | ResizeOp
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         | PadOp,
     ) -> tuple[int, ...]:
@@ -8350,16 +10266,24 @@ class CEmitter:
             return op.shape
         if isinstance(op, ClipOp):
             return op.output_shape
+        if isinstance(op, QuantizeLinearOp):
+            return op.input_shape
         if isinstance(op, CastOp):
             return op.shape
         if isinstance(op, MatMulOp):
-            return
+            return op.output_shape
+        if isinstance(op, EinsumOp):
+            return op.output_shape
         if isinstance(op, GemmOp):
             return (op.m, op.n)
         if isinstance(op, ConvOp):
             return (op.batch, op.out_channels, *op.out_spatial)
+        if isinstance(op, ConvTransposeOp):
+            return (op.batch, op.out_channels, *op.out_spatial)
         if isinstance(op, AveragePoolOp):
             return (op.batch, op.channels, op.out_h, op.out_w)
+        if isinstance(op, LpPoolOp):
+            return (op.batch, op.channels, op.out_h, op.out_w)
         if isinstance(op, BatchNormOp):
             return op.shape
         if isinstance(
@@ -8380,6 +10304,8 @@ class CEmitter:
             return op.shape
         if isinstance(op, LogSoftmaxOp):
             return op.shape
+        if isinstance(op, HardmaxOp):
+            return op.shape
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return op.output_shape
         if isinstance(op, SoftmaxCrossEntropyLossOp):
@@ -8392,6 +10318,10 @@ class CEmitter:
             return op.output_shape
         if isinstance(op, GatherOp):
             return op.output_shape
+        if isinstance(op, GatherNDOp):
+            return op.output_shape
+        if isinstance(op, ScatterNDOp):
+            return op.output_shape
         if isinstance(op, TransposeOp):
             return op.output_shape
         if isinstance(op, ReshapeOp):
@@ -8400,6 +10330,8 @@ class CEmitter:
             return op.shape
         if isinstance(op, EyeLikeOp):
             return op.output_shape
+        if isinstance(op, TriluOp):
+            return op.output_shape
         if isinstance(op, TileOp):
             return op.output_shape
         if isinstance(op, PadOp):
@@ -8418,18 +10350,24 @@ class CEmitter:
             return op.output_shape
         if isinstance(op, ArgReduceOp):
             return op.output_shape
+        if isinstance(op, TopKOp):
+            return op.output_shape
         if isinstance(op, ConstantOfShapeOp):
             return op.shape
         if isinstance(op, ShapeOp):
             return op.output_shape
         if isinstance(op, SizeOp):
             return op.output_shape
+        if isinstance(op, NonZeroOp):
+            return op.output_shape
         if isinstance(op, ExpandOp):
             return op.output_shape
         if isinstance(op, CumSumOp):
             return op.input_shape
         if isinstance(op, RangeOp):
             return op.output_shape
+        if isinstance(op, OneHotOp):
+            return op.output_shape
         if op.output_rank == 3:
             return (op.batch, op.q_seq, op.q_heads * op.v_head_size)
         return (op.batch, op.q_heads, op.q_seq, op.v_head_size)
@@ -8441,11 +10379,16 @@ class CEmitter:
         | WhereOp
         | UnaryOp
         | ClipOp
+        | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8455,14 +10398,20 @@ class CEmitter:
         | RMSNormalizationOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
+        | IdentityOp
+        | EyeLikeOp
+        | TriluOp
+        | TileOp
         | ResizeOp
         | GridSampleOp
         | ReduceOp
@@ -8470,21 +10419,25 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         | PadOp,
     ) -> ScalarType:
         if isinstance(op, ArgReduceOp):
             return op.output_dtype
+        if isinstance(op, TopKOp):
+            return op.output_values_dtype
         return op.dtype

     @staticmethod
     def _codegen_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
         if not shape:
             return (1,)
-        return shape
+        return tuple(max(1, dim) if isinstance(dim, int) else dim for dim in shape)

     @staticmethod
     def _array_suffix(shape: tuple[int, ...]) -> str:
@@ -8623,6 +10576,8 @@ class CEmitter:
         dim_names: Mapping[int, str] | None,
     ) -> tuple[str | int, ...]:
         dim_names = dim_names or {}
+        if not shape:
+            shape = (1,)
         return tuple(
             dim_names.get(index, dim) for index, dim in enumerate(shape)
         )
@@ -8677,7 +10632,8 @@ class CEmitter:

     @staticmethod
     def _element_count(shape: tuple[int, ...]) -> int:
-
+        if not shape:
+            return 1
         count = 1
         for dim in shape:
             if dim < 0:
@@ -8745,6 +10701,7 @@ class CEmitter:
         testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
         dim_order: Sequence[str],
         dim_values: Mapping[str, int],
+        weight_data_filename: str,
     ) -> str:
         input_counts = tuple(
             self._element_count(shape) for shape in model.input_shapes
@@ -8755,7 +10712,8 @@ class CEmitter:
             model.input_names, model.input_shapes, input_counts, model.input_dtypes
         ):
             codegen_shape = self._codegen_shape(shape)
-
+            loop_shape = (1,) if not shape else shape
+            loop_vars = self._loop_vars(loop_shape)
             if dtype in {ScalarType.F16, ScalarType.F32}:
                 random_expr = "rng_next_float()"
             elif dtype == ScalarType.F64:
@@ -8769,20 +10727,26 @@ class CEmitter:
             constant_lines = None
             if constant_values is not None:
                 constant_name = f"{name}_testbench_data"
-
-
-
-
+                if constant_values:
+                    constant_lines = [
+                        self._format_value(value, dtype)
+                        for value in constant_values
+                    ]
+                else:
+                    constant_lines = [self._format_value(0, dtype)]
             inputs.append(
                 {
                     "name": name,
-                    "shape":
+                    "shape": loop_shape,
                     "shape_literal": ",".join(str(dim) for dim in shape),
                     "count": count,
                     "array_suffix": self._array_suffix(codegen_shape),
+                    "array_index_expr": "".join(
+                        f"[{var}]" for var in loop_vars
+                    ),
                     "loop_vars": loop_vars,
-                    "rank": len(
-                    "index_expr": self._index_expr(
+                    "rank": len(loop_shape),
+                    "index_expr": self._index_expr(loop_shape, loop_vars),
                     "dtype": dtype,
                     "c_type": dtype.c_type,
                     "random_expr": random_expr,
@@ -8797,17 +10761,21 @@ class CEmitter:
             model.output_names, model.output_shapes, model.output_dtypes
         ):
             codegen_shape = self._codegen_shape(shape)
-
+            loop_shape = (1,) if not shape else shape
+            output_loop_vars = self._loop_vars(loop_shape)
             outputs.append(
                 {
                     "name": name,
-                    "shape":
+                    "shape": loop_shape,
                     "shape_literal": ",".join(str(dim) for dim in shape),
-                    "count": self._element_count(
+                    "count": self._element_count(shape),
                     "array_suffix": self._array_suffix(codegen_shape),
+                    "array_index_expr": "".join(
+                        f"[{var}]" for var in output_loop_vars
+                    ),
                     "loop_vars": output_loop_vars,
-                    "rank": len(
-                    "index_expr": self._index_expr(
+                    "rank": len(loop_shape),
+                    "index_expr": self._index_expr(loop_shape, output_loop_vars),
                     "dtype": dtype,
                     "c_type": dtype.c_type,
                     "print_format": self._print_format(dtype),
@@ -8822,9 +10790,87 @@ class CEmitter:
             ],
             inputs=inputs,
             outputs=outputs,
+            weight_data_filename=weight_data_filename,
         ).rstrip()
         return _format_c_indentation(rendered)

+    @staticmethod
+    def _testbench_requires_math(
+        model: LoweredModel,
+        testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None,
+    ) -> bool:
+        if not testbench_inputs:
+            return False
+        dtype_map = dict(zip(model.input_names, model.input_dtypes))
+        float_dtypes = {ScalarType.F16, ScalarType.F32, ScalarType.F64}
+        for name, values in testbench_inputs.items():
+            if dtype_map.get(name) not in float_dtypes:
+                continue
+            for value in values:
+                if not math.isfinite(float(value)):
+                    return True
+        return False
+
+    def _partition_constants(
+        self, constants: tuple[ConstTensor, ...]
+    ) -> tuple[tuple[ConstTensor, ...], tuple[ConstTensor, ...]]:
+        if self._large_weight_threshold <= 0:
+            return (), constants
+        inline: list[ConstTensor] = []
+        large: list[ConstTensor] = []
+        for const in constants:
+            if self._element_count(const.shape) > self._large_weight_threshold:
+                large.append(const)
+            else:
+                inline.append(const)
+        return tuple(inline), tuple(large)
+
+    @staticmethod
+    def _weight_data_filename(model: LoweredModel) -> str:
+        return f"{model.name}.bin"
+
+    def _emit_weight_loader(
+        self, model: LoweredModel, large_constants: tuple[ConstTensor, ...]
+    ) -> str:
+        lines = [f"_Bool {model.name}_load(const char *path) {{"]
+        if not large_constants:
+            lines.append(" (void)path;")
+            lines.append(" return 1;")
+            lines.append("}")
+            return _format_c_indentation("\n".join(lines))
+        lines.append(" FILE *file = fopen(path, \"rb\");")
+        lines.append(" if (!file) {")
+        lines.append(" return 0;")
+        lines.append(" }")
+        lines.append(
+            f" _Bool ok = {model.name}_load_file(file);"
+        )
+        lines.append(" fclose(file);")
+        lines.append(" return ok;")
+        lines.append("}")
+        lines.append("")
+        lines.append(f"static _Bool {model.name}_load_file(FILE *file) {{")
+        for const in large_constants:
+            shape = self._codegen_shape(const.shape)
+            loop_vars = self._loop_vars(shape)
+            for depth, var in enumerate(loop_vars):
+                lines.append(
+                    f" for (idx_t {var} = 0; {var} < {shape[depth]}; ++{var}) {{"
+                )
+            index_expr = "".join(f"[{var}]" for var in loop_vars)
+            zero_index = "[0]" * len(shape)
+            lines.append(
+                f" if (fread(&{const.name}{index_expr}, "
+                f"sizeof({const.name}{zero_index}), 1, file) != 1) {{"
+            )
+            lines.append(" return 0;")
+            lines.append(" }")
+            for _ in loop_vars[::-1]:
+                lines.append(" }")
+        lines.append(" return 1;")
+        lines.append("}")
+        return _format_c_indentation("\n".join(lines))
+
     def _emit_constant_definitions(
         self,
         constants: tuple[ConstTensor, ...],
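The loader emitted above fread()s each large constant, in declaration order, from the file passed to <model>_load(). A hypothetical driver sketch (not part of this diff; the `constants` attribute name is an assumption) showing how the matching blob produced by collect_weight_data would be written next to the generated C sources under the _weight_data_filename(model) name:

from pathlib import Path

def write_weight_blob(emitter, model, out_dir):
    # collect_weight_data returns None when every constant fits inline.
    blob = emitter.collect_weight_data(model.constants)  # assumed attribute
    if blob is not None:
        Path(out_dir, f"{model.name}.bin").write_bytes(blob)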
@@ -8834,26 +10880,31 @@ class CEmitter:
         if not constants:
             return ""
         lines: list[str] = []
-        for const in constants:
+        for index, const in enumerate(constants, start=1):
+            lines.append(self._emit_constant_comment(const, index))
             c_type = const.dtype.c_type
-
+            shape = self._codegen_shape(const.shape)
+            array_suffix = self._array_suffix(shape)
             values = [
-                self.
+                self._format_weight_value(value, const.dtype)
+                for value in const.data
             ]
             lines.append(
                 f"{storage_prefix} {c_type} {const.name}{array_suffix} = {{"
             )
             if values:
-
-
-                values
-
-
-
-
-
-
-                lines.
+                if (
+                    self._truncate_weights_after is not None
+                    and len(values) > self._truncate_weights_after
+                ):
+                    truncated_lines, _, _, _ = (
+                        self._emit_initializer_lines_truncated(
+                            values, shape, self._truncate_weights_after
+                        )
+                    )
+                    lines.extend(truncated_lines)
+                else:
+                    lines.extend(self._emit_initializer_lines(values, shape))
             lines.append("};")
             lines.append("")
         if lines and not lines[-1]:
@@ -8866,12 +10917,44 @@ class CEmitter:
         if not constants:
             return ""
         lines = []
-        for const in constants:
+        for index, const in enumerate(constants, start=1):
+            lines.append(self._emit_constant_comment(const, index))
             c_type = const.dtype.c_type
             array_suffix = self._array_suffix(const.shape)
             lines.append(f"extern const {c_type} {const.name}{array_suffix};")
         return "\n".join(lines)

+    def _emit_constant_storage_definitions(
+        self,
+        constants: tuple[ConstTensor, ...],
+        *,
+        storage_prefix: str = "static",
+    ) -> str:
+        if not constants:
+            return ""
+        lines: list[str] = []
+        for index, const in enumerate(constants, start=1):
+            lines.append(self._emit_constant_comment(const, index))
+            c_type = const.dtype.c_type
+            array_suffix = self._array_suffix(const.shape)
+            lines.append(f"{storage_prefix} {c_type} {const.name}{array_suffix};")
+            lines.append("")
+        if lines and not lines[-1]:
+            lines.pop()
+        return "\n".join(lines)
+
+    def collect_weight_data(
+        self, constants: tuple[ConstTensor, ...]
+    ) -> bytes | None:
+        _, large_constants = self._partition_constants(constants)
+        if not large_constants:
+            return None
+        chunks: list[bytes] = []
+        for const in large_constants:
+            array = np.asarray(const.data, dtype=const.dtype.np_dtype)
+            chunks.append(array.tobytes(order="C"))
+        return b"".join(chunks)
+
     @staticmethod
     def _index_expr(shape: tuple[int, ...], loop_vars: tuple[str, ...]) -> str:
         shape = CEmitter._codegen_shape(shape)
@@ -8886,6 +10969,10 @@ class CEmitter:
|
|
|
8886
10969
|
|
|
8887
10970
|
@staticmethod
|
|
8888
10971
|
def _format_float(value: float) -> str:
|
|
10972
|
+
if math.isnan(value):
|
|
10973
|
+
return "NAN"
|
|
10974
|
+
if math.isinf(value):
|
|
10975
|
+
return "-INFINITY" if value < 0 else "INFINITY"
|
|
8889
10976
|
formatted = f"{value:.9g}"
|
|
8890
10977
|
if "e" not in formatted and "E" not in formatted and "." not in formatted:
|
|
8891
10978
|
formatted = f"{formatted}.0"
|
|
@@ -8897,11 +10984,57 @@ class CEmitter:

     @staticmethod
     def _format_double(value: float) -> str:
+        if math.isnan(value):
+            return "NAN"
+        if math.isinf(value):
+            return "-INFINITY" if value < 0 else "INFINITY"
         formatted = f"{value:.17g}"
         if "e" not in formatted and "E" not in formatted and "." not in formatted:
             formatted = f"{formatted}.0"
         return formatted

+    @staticmethod
+    def _format_float32_hex(value: float) -> str:
+        bits = struct.unpack("<I", struct.pack("<f", float(value)))[0]
+        sign = "-" if (bits >> 31) else ""
+        exponent = (bits >> 23) & 0xFF
+        mantissa = bits & 0x7FFFFF
+        if exponent == 0 and mantissa == 0:
+            return f"{sign}0x0.0p+0"
+        if exponent == 0xFF:
+            if mantissa == 0:
+                return f"{sign}INFINITY"
+            return "NAN"
+        if exponent == 0:
+            shift = mantissa.bit_length() - 1
+            exponent_val = shift - 149
+            fraction = (mantissa - (1 << shift)) << (24 - shift)
+        else:
+            exponent_val = exponent - 127
+            fraction = mantissa << 1
+        return f"{sign}0x1.{fraction:06x}p{exponent_val:+d}"
+
+    @staticmethod
+    def _format_float64_hex(value: float) -> str:
+        bits = struct.unpack("<Q", struct.pack("<d", float(value)))[0]
+        sign = "-" if (bits >> 63) else ""
+        exponent = (bits >> 52) & 0x7FF
+        mantissa = bits & 0xFFFFFFFFFFFFF
+        if exponent == 0 and mantissa == 0:
+            return f"{sign}0x0.0p+0"
+        if exponent == 0x7FF:
+            if mantissa == 0:
+                return f"{sign}INFINITY"
+            return "NAN"
+        if exponent == 0:
+            shift = mantissa.bit_length() - 1
+            exponent_val = shift - 1074
+            fraction = (mantissa - (1 << shift)) << (52 - shift)
+        else:
+            exponent_val = exponent - 1023
+            fraction = mantissa
+        return f"{sign}0x1.{fraction:013x}p{exponent_val:+d}"
+
     @staticmethod
     def _format_floating(value: float, dtype: ScalarType) -> str:
         if dtype == ScalarType.F64:
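The _format_float32_hex and _format_float64_hex helpers above decode the IEEE-754 bit pattern and print a C99 hexadecimal float literal, which a C compiler parses back to exactly the same value (decimal literals can round). The same effect can be approximated with the standard library alone; the helper name below is hypothetical and only illustrates the idea:

    import struct

    def float32_hex_literal(value):
        # Snap to the nearest float32, then reuse Python's exact hex formatting;
        # the trailing "f" makes it a C float constant.
        as_f32 = struct.unpack("<f", struct.pack("<f", value))[0]
        return f"{as_f32.hex()}f"

    print(float32_hex_literal(0.1))  # 0x1.99999a0000000p-4f, bit-exact for float32(0.1)
    print(float32_hex_literal(1.5))  # 0x1.8000000000000p+0f

Non-finite values still need the NAN/INFINITY spellings, which is why the helpers in the diff check for those cases before taking the hex path.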
@@ -8992,14 +11125,139 @@ class CEmitter:
             return self._format_int(int(value), 8, "INT8_MIN")
         raise CodegenError(f"Unsupported dtype {dtype.onnx_name}")

+    def _format_weight_value(
+        self, value: float | int | bool, dtype: ScalarType
+    ) -> str:
+        if dtype == ScalarType.F16:
+            formatted = self._format_float32_hex(float(value))
+            if formatted == "NAN" or formatted.endswith("INFINITY"):
+                return f"(_Float16){formatted}"
+            return f"(_Float16){formatted}f"
+        if dtype == ScalarType.F32:
+            formatted = self._format_float32_hex(float(value))
+            if formatted == "NAN" or formatted.endswith("INFINITY"):
+                return formatted
+            return f"{formatted}f"
+        if dtype == ScalarType.F64:
+            return self._format_float64_hex(float(value))
+        if dtype == ScalarType.BOOL:
+            return "true" if bool(value) else "false"
+        if dtype == ScalarType.U64:
+            return self._format_uint(int(value), 64, "UINT64_MAX")
+        if dtype == ScalarType.U32:
+            return self._format_uint(int(value), 32, "UINT32_MAX")
+        if dtype == ScalarType.U16:
+            return self._format_uint(int(value), 16, "UINT16_MAX")
+        if dtype == ScalarType.U8:
+            return self._format_uint(int(value), 8, "UINT8_MAX")
+        if dtype == ScalarType.I64:
+            return self._format_int64(int(value))
+        if dtype == ScalarType.I32:
+            return self._format_int(int(value), 32, "INT32_MIN")
+        if dtype == ScalarType.I16:
+            return self._format_int(int(value), 16, "INT16_MIN")
+        if dtype == ScalarType.I8:
+            return self._format_int(int(value), 8, "INT8_MIN")
+        raise CodegenError(f"Unsupported dtype {dtype.onnx_name}")
+
+    @staticmethod
+    def _emit_initializer_lines(
+        values: Sequence[str],
+        shape: tuple[int, ...],
+        indent: str = " ",
+        per_line: int = 8,
+    ) -> list[str]:
+        if len(shape) == 1:
+            lines: list[str] = []
+            for index in range(0, len(values), per_line):
+                chunk = ", ".join(values[index : index + per_line])
+                lines.append(f"{indent}{chunk},")
+            if lines:
+                lines[-1] = lines[-1].rstrip(",")
+            return lines
+        sub_shape = shape[1:]
+        sub_size = prod(sub_shape)
+        lines = []
+        for index in range(shape[0]):
+            start = index * sub_size
+            end = start + sub_size
+            lines.append(f"{indent}{{")
+            lines.extend(
+                CEmitter._emit_initializer_lines(
+                    values[start:end],
+                    sub_shape,
+                    indent + " ",
+                    per_line,
+                )
+            )
+            lines.append(f"{indent}}},")
+        if lines:
+            lines[-1] = lines[-1].rstrip(",")
+        return lines
+
+    @staticmethod
+    def _emit_initializer_lines_truncated(
+        values: Sequence[str],
+        shape: tuple[int, ...],
+        truncate_after: int,
+        indent: str = " ",
+        per_line: int = 8,
+        start_index: int = 0,
+        emitted: int = 0,
+    ) -> tuple[list[str], int, int, bool]:
+        if len(shape) == 1:
+            items: list[str] = []
+            truncated = False
+            index = start_index
+            for _ in range(shape[0]):
+                if emitted >= truncate_after:
+                    items.append("...")
+                    truncated = True
+                    break
+                items.append(values[index])
+                index += 1
+                emitted += 1
+            lines: list[str] = []
+            for item_index in range(0, len(items), per_line):
+                chunk = ", ".join(items[item_index : item_index + per_line])
+                lines.append(f"{indent}{chunk},")
+            if lines:
+                lines[-1] = lines[-1].rstrip(",")
+            return lines, index, emitted, truncated
+        sub_shape = shape[1:]
+        lines: list[str] = []
+        index = start_index
+        truncated = False
+        for _ in range(shape[0]):
+            lines.append(f"{indent}{{")
+            sub_lines, index, emitted, sub_truncated = (
+                CEmitter._emit_initializer_lines_truncated(
+                    values,
+                    sub_shape,
+                    truncate_after,
+                    indent + " ",
+                    per_line,
+                    index,
+                    emitted,
+                )
+            )
+            lines.extend(sub_lines)
+            lines.append(f"{indent}}},")
+            if sub_truncated:
+                truncated = True
+                break
+        if lines:
+            lines[-1] = lines[-1].rstrip(",")
+        return lines, index, emitted, truncated
+
     @staticmethod
     def _print_format(dtype: ScalarType) -> str:
         if dtype == ScalarType.F16:
-            return "
+            return "\\\"%a\\\""
         if dtype == ScalarType.F32:
-            return "
+            return "\\\"%a\\\""
         if dtype == ScalarType.F64:
-            return "
+            return "\\\"%a\\\""
         if dtype == ScalarType.BOOL:
             return "%d"
         if dtype == ScalarType.U64:
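The recursive _emit_initializer_lines helpers in this hunk render a flat value list as nested C brace initializers, one brace level per tensor dimension, with an optional "..." marker when a large constant is truncated for display. A small standalone sketch of that nesting pattern, with hypothetical naming and simplified indentation (not the package code):

    from math import prod

    def nested_initializer(values, shape, indent=" "):
        # One brace level per dimension; the innermost level joins the values.
        if len(shape) == 1:
            return [indent + ", ".join(values)]
        sub = prod(shape[1:])
        lines = []
        for i in range(shape[0]):
            lines.append(f"{indent}{{")
            lines.extend(nested_initializer(values[i * sub:(i + 1) * sub], shape[1:], indent + " "))
            lines.append(f"{indent}}},")
        lines[-1] = lines[-1].rstrip(",")
        return lines

    print("\n".join(nested_initializer([str(v) for v in range(6)], (2, 3))))
    # Output:
    #  {
    #   0, 1, 2
    #  },
    #  {
    #   3, 4, 5
    #  }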