emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,17 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from enum import Enum
 import itertools
+import math
+from math import prod
 from pathlib import Path
 import re
+import struct
 from typing import Mapping, Sequence
 
 from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
+import numpy as np
 
 from ..errors import CodegenError
 from ..ops import (
@@ -24,6 +29,38 @@ from shared.scalar_types import ScalarFunctionError, ScalarType
 
 
 def _format_c_indentation(source: str, *, indent: str = " ") -> str:
+    def strip_string_literals(line: str) -> str:
+        sanitized: list[str] = []
+        in_string = False
+        in_char = False
+        escape = False
+        for char in line:
+            if escape:
+                escape = False
+                if not (in_string or in_char):
+                    sanitized.append(char)
+                continue
+            if in_string:
+                if char == "\\":
+                    escape = True
+                elif char == '"':
+                    in_string = False
+                continue
+            if in_char:
+                if char == "\\":
+                    escape = True
+                elif char == "'":
+                    in_char = False
+                continue
+            if char == '"':
+                in_string = True
+                continue
+            if char == "'":
+                in_char = True
+                continue
+            sanitized.append(char)
+        return "".join(sanitized)
+
     formatted_lines: list[str] = []
     indent_level = 0
     for line in source.splitlines():
@@ -34,8 +71,9 @@ def _format_c_indentation(source: str, *, indent: str = " ") -> str:
         if stripped.startswith("}"):
             indent_level = max(indent_level - 1, 0)
         formatted_lines.append(f"{indent * indent_level}{stripped}")
-        open_count = stripped.count("{")
-        close_count = stripped.count("}")
+        sanitized = strip_string_literals(stripped)
+        open_count = sanitized.count("{")
+        close_count = sanitized.count("}")
         if stripped.startswith("}"):
             close_count = max(close_count - 1, 0)
         indent_level += open_count - close_count
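
Note: this hunk fixes a real formatting bug: braces inside C string or character literals used to be counted as block delimiters, skewing the emitted indentation. A minimal standalone sketch of the failure and the fix (simplified to double-quoted strings only; the shipped helper also handles character literals):

# A brace inside a C string literal used to shift the indent level.
line = 'printf("unbalanced { brace");'
naive_open = line.count("{")  # 1: counts the brace inside the string

def strip_literals(text: str) -> str:
    out: list[str] = []
    in_str = escape = False
    for ch in text:
        if escape:
            escape = False
            continue
        if in_str:
            if ch == "\\":
                escape = True
            elif ch == '"':
                in_str = False
            continue
        if ch == '"':
            in_str = True
            continue
        out.append(ch)
    return "".join(out)

assert naive_open == 1 and strip_literals(line).count("{") == 0
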
@@ -119,6 +157,8 @@ class BinaryOp:
     output: str
     function: ScalarFunction
     operator_kind: OperatorKind
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
     shape: tuple[int, ...]
     dtype: ScalarType
     input_dtype: ScalarType
@@ -211,6 +251,26 @@ class MatMulOp:
     dtype: ScalarType
 
 
+class EinsumKind(str, Enum):
+    REDUCE_ALL = "reduce_all"
+    SUM_J = "sum_j"
+    TRANSPOSE = "transpose"
+    DOT = "dot"
+    BATCH_MATMUL = "batch_matmul"
+    BATCH_DIAGONAL = "batch_diagonal"
+
+
+@dataclass(frozen=True)
+class EinsumOp:
+    inputs: tuple[str, ...]
+    output: str
+    kind: EinsumKind
+    input_shapes: tuple[tuple[int, ...], ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class GemmOp:
     input_a: str
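
Note: the EinsumKind members suggest the fixed set of equation shapes the new lowering recognizes (the matching logic lives in the new emx_onnx_cgen/lowering/einsum.py, not shown in this hunk). Plausible NumPy equivalents; the pairing of names to equations below is an assumption, not taken from the diff:

import numpy as np

x = np.arange(6.0).reshape(2, 3)
np.einsum("ij->", x)             # REDUCE_ALL: collapse to a scalar
np.einsum("ij->i", x)            # SUM_J: sum out the trailing axis
np.einsum("ij->ji", x)           # TRANSPOSE
np.einsum("i,i->", x[0], x[0])   # DOT
a, c = np.ones((2, 3, 5)), np.ones((2, 5, 4))
np.einsum("bij,bjk->bik", a, c)  # BATCH_MATMUL
d = np.ones((2, 3, 3))
np.einsum("bii->bi", d)          # BATCH_DIAGONAL
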
@@ -305,6 +365,27 @@ class ConvOp:
         return self.out_spatial[1]
 
 
+@dataclass(frozen=True)
+class ConvTransposeOp:
+    input0: str
+    weights: str
+    bias: str | None
+    output: str
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    output_padding: tuple[int, ...]
+    group: int
+    dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class AveragePoolOp:
     input0: str
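
Note: out_spatial is not free; it must satisfy the ONNX ConvTranspose output-size rule. A sketch of that rule with names mirroring the ConvTransposeOp fields (this is the ONNX operator spec, not code from the package):

def conv_transpose_out_size(
    in_size: int, kernel: int, stride: int,
    pad_begin: int, pad_end: int, dilation: int, out_pad: int,
) -> int:
    return (
        stride * (in_size - 1)
        + out_pad
        + ((kernel - 1) * dilation + 1)
        - pad_begin
        - pad_end
    )

# e.g. a 4-wide axis, 3-wide kernel, stride 2, pad 1 -> 7-wide output
assert conv_transpose_out_size(4, 3, 2, 1, 1, 1, 0) == 7
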
@@ -327,6 +408,41 @@ class AveragePoolOp:
     dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class LpPoolOp:
+    input0: str
+    output: str
+    batch: int
+    channels: int
+    in_h: int
+    in_w: int
+    out_h: int
+    out_w: int
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_top: int
+    pad_left: int
+    pad_bottom: int
+    pad_right: int
+    p: int
+    dtype: ScalarType
+
+
+@dataclass(frozen=True)
+class QuantizeLinearOp:
+    input0: str
+    scale: str
+    zero_point: str | None
+    output: str
+    input_shape: tuple[int, ...]
+    axis: int | None
+    dtype: ScalarType
+    input_dtype: ScalarType
+    scale_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class SoftmaxOp:
     input0: str
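
Note: reference semantics for the two new op records, per the ONNX specs, which the new C templates must reproduce (NumPy sketch):

import numpy as np

# LpPool: p-norm over each pooling window.
window = np.array([1.0, -2.0, 2.0])
p = 2
lp = np.sum(np.abs(window) ** p) ** (1.0 / p)  # 3.0

# QuantizeLinear: y = saturate(round(x / scale) + zero_point), with
# round-half-to-even; np.rint uses that rounding mode.
x = np.array([-1.0, 0.4, 2.0], dtype=np.float32)
scale, zero_point = np.float32(0.5), 10
q = np.clip(np.rint(x / scale) + zero_point, 0, 255).astype(np.uint8)
# q == [ 8, 11, 14]
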
@@ -351,6 +467,18 @@ class LogSoftmaxOp:
     dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class HardmaxOp:
+    input0: str
+    output: str
+    outer: int
+    axis_size: int
+    inner: int
+    axis: int
+    shape: tuple[int, ...]
+    dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class NegativeLogLikelihoodLossOp:
     input0: str
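
Note: Hardmax writes 1 at the first maximum along the chosen axis and 0 elsewhere; outer, axis_size and inner factor the tensor so the generated C can walk outer*inner independent rows of length axis_size. NumPy sketch of the semantics:

import numpy as np

x = np.array([[1.0, 3.0, 3.0], [2.0, 0.0, 1.0]])
out = np.zeros_like(x)
out[np.arange(x.shape[0]), x.argmax(axis=1)] = 1.0  # argmax takes the first max
# out == [[0., 1., 0.],
#         [1., 0., 0.]]
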
@@ -595,6 +723,34 @@ class GatherOp:
     indices_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class GatherNDOp:
+    data: str
+    indices: str
+    output: str
+    batch_dims: int
+    data_shape: tuple[int, ...]
+    indices_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+    indices_dtype: ScalarType
+
+
+@dataclass(frozen=True)
+class ScatterNDOp:
+    data: str
+    indices: str
+    updates: str
+    output: str
+    data_shape: tuple[int, ...]
+    indices_shape: tuple[int, ...]
+    updates_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    reduction: str
+    dtype: ScalarType
+    indices_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class TransposeOp:
     input0: str
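
Note: the semantics these two records encode, as a NumPy sketch for the batch_dims=0 case (reduction values other than "none" combine instead of overwrite, not exercised here):

import numpy as np

# GatherND: each row of `indices` is a coordinate tuple into `data`.
data = np.arange(12).reshape(3, 4)
indices = np.array([[0, 1], [2, 3]])
gathered = data[tuple(indices.T)]       # [1, 11]

# ScatterND: a copy of `data` with `updates` written at each coordinate.
out = data.copy()
out[tuple(indices.T)] = np.array([-1, -2])
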
@@ -635,6 +791,21 @@ class EyeLikeOp:
     input_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class TriluOp:
+    input0: str
+    output: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    upper: bool
+    k_value: int
+    k_input: str | None
+    k_input_shape: tuple[int, ...] | None
+    k_input_dtype: ScalarType | None
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class TileOp:
     input0: str
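
Note: Trilu keeps the upper or lower triangle of the trailing 2-D slices, offset by a diagonal k. k_value carries a compile-time constant; k_input/k_input_shape/k_input_dtype are set instead when k arrives as a runtime tensor. NumPy equivalents:

import numpy as np

x = np.arange(9).reshape(3, 3)
np.triu(x, k=1)   # upper=True: keep elements strictly above the diagonal
np.tril(x, k=-1)  # upper=False: keep elements strictly below the diagonal
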
@@ -800,6 +971,22 @@ class ArgReduceOp:
     output_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class TopKOp:
+    input0: str
+    output_values: str
+    output_indices: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    axis: int
+    k: int
+    largest: bool
+    sorted: bool
+    input_dtype: ScalarType
+    output_values_dtype: ScalarType
+    output_indices_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class ConstantOfShapeOp:
     input0: str
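
Note: TopK reference behaviour the generated C must match: the k largest (or smallest, when largest is false) values along axis, plus their indices; sorted controls whether the k results come back ordered. NumPy sketch:

import numpy as np

x = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
k = 2
idx = np.argsort(-x, kind="stable")[:k]  # largest=True -> [4, 2]
values = x[idx]                          # [5., 4.]
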
@@ -833,6 +1020,16 @@ class SizeOp:
     input_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class NonZeroOp:
+    input0: str
+    output: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class ExpandOp:
     input0: str
@@ -871,6 +1068,22 @@ class RangeOp:
     input_dtype: ScalarType
 
 
+@dataclass(frozen=True)
+class OneHotOp:
+    indices: str
+    depth: str
+    values: str
+    output: str
+    axis: int
+    indices_shape: tuple[int, ...]
+    values_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    depth_dim: int
+    dtype: ScalarType
+    indices_dtype: ScalarType
+    depth_dtype: ScalarType
+
+
 @dataclass(frozen=True)
 class SplitOp:
     input0: str
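
Note: OneHot per the ONNX spec: values holds [off_value, on_value], depth_dim is the static size of the axis inserted at `axis`, and negative indices wrap modulo depth. NumPy sketch:

import numpy as np

indices = np.array([0, 2, -1])
depth, off, on = 3, 0.0, 1.0
out = np.full((indices.size, depth), off)
out[np.arange(indices.size), indices % depth] = on
# out == [[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 0., 1.]]
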
@@ -937,11 +1150,15 @@ class LoweredModel:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -953,16 +1170,20 @@ class LoweredModel:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -972,12 +1193,15 @@ class LoweredModel:
         | GridSampleOp
         | ReduceOp
        | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
         ...,
     ]
@@ -986,7 +1210,15 @@ class LoweredModel:
 
 
 class CEmitter:
-    def __init__(self, template_dir: Path, *, restrict_arrays: bool = True) -> None:
+    def __init__(
+        self,
+        template_dir: Path,
+        *,
+        restrict_arrays: bool = True,
+        truncate_weights_after: int | None = None,
+        large_temp_threshold_bytes: int = 1024,
+        large_weight_threshold: int = 1024,
+    ) -> None:
         self._env = Environment(
             loader=FileSystemLoader(str(template_dir)),
             autoescape=select_autoescape(enabled_extensions=()),
@@ -994,6 +1226,15 @@ class CEmitter:
             lstrip_blocks=True,
         )
         self._restrict_arrays = restrict_arrays
+        if truncate_weights_after is not None and truncate_weights_after < 1:
+            raise CodegenError("truncate_weights_after must be >= 1")
+        self._truncate_weights_after = truncate_weights_after
+        if large_temp_threshold_bytes < 0:
+            raise CodegenError("large_temp_threshold_bytes must be >= 0")
+        self._large_temp_threshold_bytes = large_temp_threshold_bytes
+        if large_weight_threshold < 0:
+            raise CodegenError("large_weight_threshold must be >= 0")
+        self._large_weight_threshold = large_weight_threshold
 
     @staticmethod
     def _sanitize_identifier(name: str) -> str:
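
Note: a hypothetical construction showing the new knobs and the bounds enforced above (parameter semantics inferred from this validation and the weight-partitioning paths later in the file):

emitter = CEmitter(
    template_dir=Path("templates"),
    restrict_arrays=True,
    truncate_weights_after=8,         # None or >= 1
    large_temp_threshold_bytes=1024,  # >= 0
    large_weight_threshold=1024,      # >= 0
)
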
@@ -1006,10 +1247,8 @@ class CEmitter:
 
     def _op_function_name(self, model: LoweredModel, index: int) -> str:
         node_info = model.node_infos[index]
-        parts = [f"node{index}", node_info.op_type]
-        if node_info.name:
-            parts.append(node_info.name)
-        base_name = "_".join(parts)
+        suffix = node_info.name or node_info.op_type
+        base_name = f"node{index}_{suffix}".lower()
         return self._sanitize_identifier(base_name)
 
     @staticmethod
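
Note: effect of the renaming, for a hypothetical node named "Relu_3" with op_type "Relu" at index 0:

suffix = "Relu_3" or "Relu"            # node name wins over op_type
assert f"node{0}_{suffix}".lower() == "node0_relu_3"
# 0.2.0 would have produced node0_Relu_Relu_3 instead.
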
@@ -1094,7 +1333,9 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
@@ -1110,16 +1351,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -1129,12 +1374,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> tuple[str, ...]:
         if isinstance(op, BinaryOp):
@@ -1155,8 +1403,16 @@
             return tuple(names)
         if isinstance(op, CastOp):
             return (op.input0, op.output)
+        if isinstance(op, QuantizeLinearOp):
+            names = [op.input0, op.scale]
+            if op.zero_point is not None:
+                names.append(op.zero_point)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, MatMulOp):
             return (op.input0, op.input1, op.output)
+        if isinstance(op, EinsumOp):
+            return (*op.inputs, op.output)
         if isinstance(op, GemmOp):
             names = [op.input_a, op.input_b]
             if op.input_c is not None:
@@ -1187,8 +1443,16 @@
                 names.append(op.bias)
             names.append(op.output)
             return tuple(names)
+        if isinstance(op, ConvTransposeOp):
+            names = [op.input0, op.weights]
+            if op.bias is not None:
+                names.append(op.bias)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, AveragePoolOp):
             return (op.input0, op.output)
+        if isinstance(op, LpPoolOp):
+            return (op.input0, op.output)
         if isinstance(op, BatchNormOp):
             return (op.input0, op.scale, op.bias, op.mean, op.variance, op.output)
         if isinstance(op, LpNormalizationOp):
@@ -1230,7 +1494,7 @@
             if op.output_y_c is not None:
                 names.append(op.output_y_c)
             return tuple(names)
-        if isinstance(op, (SoftmaxOp, LogSoftmaxOp)):
+        if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
             return (op.input0, op.output)
         if isinstance(op, NegativeLogLikelihoodLossOp):
             names = [op.input0, op.target]
@@ -1255,6 +1519,10 @@
             return (op.data, op.indices, op.output)
         if isinstance(op, GatherOp):
             return (op.data, op.indices, op.output)
+        if isinstance(op, GatherNDOp):
+            return (op.data, op.indices, op.output)
+        if isinstance(op, ScatterNDOp):
+            return (op.data, op.indices, op.updates, op.output)
         if isinstance(op, ConcatOp):
             return (*op.inputs, op.output)
         if isinstance(op, ConstantOfShapeOp):
@@ -1263,6 +1531,8 @@
             return (op.input0, op.output)
         if isinstance(op, SizeOp):
             return (op.input0, op.output)
+        if isinstance(op, NonZeroOp):
+            return (op.input0, op.output)
         if isinstance(op, ExpandOp):
             return (op.input0, op.output)
         if isinstance(op, CumSumOp):
@@ -1273,6 +1543,8 @@
             return tuple(names)
         if isinstance(op, RangeOp):
             return (op.start, op.limit, op.delta, op.output)
+        if isinstance(op, OneHotOp):
+            return (op.indices, op.depth, op.values, op.output)
         if isinstance(op, SplitOp):
             return (op.input0, *op.outputs)
         if isinstance(op, ReshapeOp):
@@ -1281,6 +1553,12 @@
             return (op.input0, op.output)
         if isinstance(op, EyeLikeOp):
             return (op.input0, op.output)
+        if isinstance(op, TriluOp):
+            names = [op.input0]
+            if op.k_input is not None:
+                names.append(op.k_input)
+            names.append(op.output)
+            return tuple(names)
         if isinstance(op, TileOp):
             return (op.input0, op.output)
         if isinstance(op, PadOp):
@@ -1320,6 +1598,8 @@
             return tuple(names)
         if isinstance(op, GridSampleOp):
             return (op.input0, op.grid, op.output)
+        if isinstance(op, TopKOp):
+            return (op.input0, op.output_values, op.output_indices)
         if isinstance(op, ReduceOp):
             names = [op.input0]
             if op.axes_input is not None:
@@ -1331,12 +1611,14 @@
     def _build_name_map(self, model: LoweredModel) -> dict[str, str]:
         used: set[str] = set()
         name_map: dict[str, str] = {}
+        constant_names = {const.name for const in model.constants}
         names = [model.name]
         names.extend(model.input_names)
         names.extend(model.output_names)
-        names.extend(const.name for const in model.constants)
         for op in model.ops:
-            names.extend(self._op_names(op))
+            names.extend(
+                name for name in self._op_names(op) if name not in constant_names
+            )
         for name in names:
             if name in name_map:
                 continue
@@ -1344,6 +1626,14 @@
             unique = self._ensure_unique_identifier(sanitized, used)
             name_map[name] = unique
             used.add(unique)
+        for index, const in enumerate(model.constants, start=1):
+            if const.name in name_map:
+                continue
+            base_name = self._sanitize_identifier(const.name.lower())
+            weight_name = f"weight{index}_{base_name}"
+            unique = self._ensure_unique_identifier(weight_name, used)
+            name_map[const.name] = unique
+            used.add(unique)
         return name_map
 
     @staticmethod
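
Note: constants now receive deterministic C identifiers of the form weight<index>_<sanitized name>, with index starting at 1. For a hypothetical initializer named "conv1.weight":

index, base_name = 1, "conv1_weight"   # after _sanitize_identifier
assert f"weight{index}_{base_name}" == "weight1_conv1_weight"
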
@@ -1362,11 +1652,15 @@
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -1378,16 +1672,20 @@
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -1397,12 +1695,15 @@
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
         name_map: dict[str, str],
     ) -> (
@@ -1412,11 +1713,15 @@
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -1428,16 +1733,20 @@
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -1447,12 +1756,15 @@
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
     ):
         if isinstance(op, BinaryOp):
@@ -1462,6 +1774,8 @@
                 output=name_map.get(op.output, op.output),
                 function=op.function,
                 operator_kind=op.operator_kind,
+                input0_shape=op.input0_shape,
+                input1_shape=op.input1_shape,
                 shape=op.shape,
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
@@ -1518,6 +1832,18 @@
                 input_dtype=op.input_dtype,
                 dtype=op.dtype,
             )
+        if isinstance(op, QuantizeLinearOp):
+            return QuantizeLinearOp(
+                input0=name_map.get(op.input0, op.input0),
+                scale=name_map.get(op.scale, op.scale),
+                zero_point=self._map_optional_name(name_map, op.zero_point),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                axis=op.axis,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+                scale_dtype=op.scale_dtype,
+            )
         if isinstance(op, MatMulOp):
             return MatMulOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1536,6 +1862,16 @@
                 right_vector=op.right_vector,
                 dtype=op.dtype,
             )
+        if isinstance(op, EinsumOp):
+            return EinsumOp(
+                inputs=tuple(name_map.get(name, name) for name in op.inputs),
+                output=name_map.get(op.output, op.output),
+                kind=op.kind,
+                input_shapes=op.input_shapes,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, GemmOp):
             return GemmOp(
                 input_a=name_map.get(op.input_a, op.input_a),
@@ -1629,6 +1965,26 @@
                 group=op.group,
                 dtype=op.dtype,
             )
+        if isinstance(op, ConvTransposeOp):
+            return ConvTransposeOp(
+                input0=name_map.get(op.input0, op.input0),
+                weights=name_map.get(op.weights, op.weights),
+                bias=self._map_optional_name(name_map, op.bias),
+                output=name_map.get(op.output, op.output),
+                batch=op.batch,
+                in_channels=op.in_channels,
+                out_channels=op.out_channels,
+                spatial_rank=op.spatial_rank,
+                in_spatial=op.in_spatial,
+                out_spatial=op.out_spatial,
+                kernel_shape=op.kernel_shape,
+                strides=op.strides,
+                pads=op.pads,
+                dilations=op.dilations,
+                output_padding=op.output_padding,
+                group=op.group,
+                dtype=op.dtype,
+            )
         if isinstance(op, AveragePoolOp):
             return AveragePoolOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1650,6 +2006,27 @@
                 count_include_pad=op.count_include_pad,
                 dtype=op.dtype,
             )
+        if isinstance(op, LpPoolOp):
+            return LpPoolOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                batch=op.batch,
+                channels=op.channels,
+                in_h=op.in_h,
+                in_w=op.in_w,
+                out_h=op.out_h,
+                out_w=op.out_w,
+                kernel_h=op.kernel_h,
+                kernel_w=op.kernel_w,
+                stride_h=op.stride_h,
+                stride_w=op.stride_w,
+                pad_top=op.pad_top,
+                pad_left=op.pad_left,
+                pad_bottom=op.pad_bottom,
+                pad_right=op.pad_right,
+                p=op.p,
+                dtype=op.dtype,
+            )
         if isinstance(op, BatchNormOp):
             return BatchNormOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1813,6 +2190,17 @@
                 shape=op.shape,
                 dtype=op.dtype,
             )
+        if isinstance(op, HardmaxOp):
+            return HardmaxOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                outer=op.outer,
+                axis_size=op.axis_size,
+                inner=op.inner,
+                axis=op.axis,
+                shape=op.shape,
+                dtype=op.dtype,
+            )
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return NegativeLogLikelihoodLossOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1909,6 +2297,32 @@
                 dtype=op.dtype,
                 indices_dtype=op.indices_dtype,
             )
+        if isinstance(op, GatherNDOp):
+            return GatherNDOp(
+                data=name_map.get(op.data, op.data),
+                indices=name_map.get(op.indices, op.indices),
+                output=name_map.get(op.output, op.output),
+                batch_dims=op.batch_dims,
+                data_shape=op.data_shape,
+                indices_shape=op.indices_shape,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+            )
+        if isinstance(op, ScatterNDOp):
+            return ScatterNDOp(
+                data=name_map.get(op.data, op.data),
+                indices=name_map.get(op.indices, op.indices),
+                updates=name_map.get(op.updates, op.updates),
+                output=name_map.get(op.output, op.output),
+                data_shape=op.data_shape,
+                indices_shape=op.indices_shape,
+                updates_shape=op.updates_shape,
+                output_shape=op.output_shape,
+                reduction=op.reduction,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+            )
         if isinstance(op, TransposeOp):
             return TransposeOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -1945,6 +2359,20 @@
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, TriluOp):
+            return TriluOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                upper=op.upper,
+                k_value=op.k_value,
+                k_input=self._map_optional_name(name_map, op.k_input),
+                k_input_shape=op.k_input_shape,
+                k_input_dtype=op.k_input_dtype,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, TileOp):
             return TileOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2101,6 +2529,21 @@
                 input_dtype=op.input_dtype,
                 output_dtype=op.output_dtype,
             )
+        if isinstance(op, TopKOp):
+            return TopKOp(
+                input0=name_map.get(op.input0, op.input0),
+                output_values=name_map.get(op.output_values, op.output_values),
+                output_indices=name_map.get(op.output_indices, op.output_indices),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                axis=op.axis,
+                k=op.k,
+                largest=op.largest,
+                sorted=op.sorted,
+                input_dtype=op.input_dtype,
+                output_values_dtype=op.output_values_dtype,
+                output_indices_dtype=op.output_indices_dtype,
+            )
         if isinstance(op, ConstantOfShapeOp):
             return ConstantOfShapeOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2131,6 +2574,15 @@
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, NonZeroOp):
+            return NonZeroOp(
+                input0=name_map.get(op.input0, op.input0),
+                output=name_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, ExpandOp):
             return ExpandOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2166,6 +2618,21 @@
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, OneHotOp):
+            return OneHotOp(
+                indices=name_map.get(op.indices, op.indices),
+                depth=name_map.get(op.depth, op.depth),
+                values=name_map.get(op.values, op.values),
+                output=name_map.get(op.output, op.output),
+                axis=op.axis,
+                indices_shape=op.indices_shape,
+                values_shape=op.values_shape,
+                output_shape=op.output_shape,
+                depth_dim=op.depth_dim,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+                depth_dtype=op.depth_dtype,
+            )
         if isinstance(op, SplitOp):
             return SplitOp(
                 input0=name_map.get(op.input0, op.input0),
@@ -2246,11 +2713,19 @@
             "unary": self._env.get_template("unary_op.c.j2"),
             "clip": self._env.get_template("clip_op.c.j2"),
             "cast": self._env.get_template("cast_op.c.j2"),
+            "quantize_linear": self._env.get_template(
+                "quantize_linear_op.c.j2"
+            ),
             "matmul": self._env.get_template("matmul_op.c.j2"),
+            "einsum": self._env.get_template("einsum_op.c.j2"),
             "gemm": self._env.get_template("gemm_op.c.j2"),
             "attention": self._env.get_template("attention_op.c.j2"),
             "conv": self._env.get_template("conv_op.c.j2"),
+            "conv_transpose": self._env.get_template(
+                "conv_transpose_op.c.j2"
+            ),
             "avg_pool": self._env.get_template("average_pool_op.c.j2"),
+            "lp_pool": self._env.get_template("lp_pool_op.c.j2"),
             "batch_norm": self._env.get_template("batch_norm_op.c.j2"),
             "lp_norm": self._env.get_template("lp_normalization_op.c.j2"),
             "instance_norm": self._env.get_template(
@@ -2270,6 +2745,7 @@
             "lstm": self._env.get_template("lstm_op.c.j2"),
             "softmax": self._env.get_template("softmax_op.c.j2"),
             "logsoftmax": self._env.get_template("logsoftmax_op.c.j2"),
+            "hardmax": self._env.get_template("hardmax_op.c.j2"),
             "nllloss": self._env.get_template(
                 "negative_log_likelihood_loss_op.c.j2"
             ),
@@ -2280,10 +2756,13 @@
             "concat": self._env.get_template("concat_op.c.j2"),
             "gather_elements": self._env.get_template("gather_elements_op.c.j2"),
             "gather": self._env.get_template("gather_op.c.j2"),
+            "gather_nd": self._env.get_template("gather_nd_op.c.j2"),
+            "scatter_nd": self._env.get_template("scatter_nd_op.c.j2"),
             "transpose": self._env.get_template("transpose_op.c.j2"),
             "reshape": self._env.get_template("reshape_op.c.j2"),
             "identity": self._env.get_template("identity_op.c.j2"),
             "eye_like": self._env.get_template("eye_like_op.c.j2"),
+            "trilu": self._env.get_template("trilu_op.c.j2"),
             "tile": self._env.get_template("tile_op.c.j2"),
             "pad": self._env.get_template("pad_op.c.j2"),
             "depth_to_space": self._env.get_template("depth_to_space_op.c.j2"),
@@ -2299,14 +2778,17 @@
                 "reduce_op_dynamic.c.j2"
             ),
             "arg_reduce": self._env.get_template("arg_reduce_op.c.j2"),
+            "topk": self._env.get_template("topk_op.c.j2"),
             "constant_of_shape": self._env.get_template(
                 "constant_of_shape_op.c.j2"
             ),
             "shape": self._env.get_template("shape_op.c.j2"),
             "size": self._env.get_template("size_op.c.j2"),
+            "nonzero": self._env.get_template("nonzero_op.c.j2"),
             "expand": self._env.get_template("expand_op.c.j2"),
             "cumsum": self._env.get_template("cumsum_op.c.j2"),
             "range": self._env.get_template("range_op.c.j2"),
+            "one_hot": self._env.get_template("one_hot_op.c.j2"),
             "split": self._env.get_template("split_op.c.j2"),
         }
         if emit_testbench:
@@ -2328,6 +2810,9 @@
         testbench_inputs = self._sanitize_testbench_inputs(
             testbench_inputs, name_map
         )
+        inline_constants, large_constants = self._partition_constants(
+            model.constants
+        )
         (
             dim_order,
             input_dim_names,
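
Note: _partition_constants itself is not shown in this diff; a plausible sketch of what it does, inferred from the large_weight_threshold validation and the inline/large emission paths below (an assumption, not the package's code):

from math import prod

def partition_constants(constants, threshold=1024):
    # prod(()) == 1, so scalars count as one element
    inline = tuple(c for c in constants if prod(c.shape) <= threshold)
    large = tuple(c for c in constants if prod(c.shape) > threshold)
    return inline, large
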
@@ -2353,11 +2838,15 @@
         unary_template = templates["unary"]
         clip_template = templates["clip"]
         cast_template = templates["cast"]
+        quantize_linear_template = templates["quantize_linear"]
         matmul_template = templates["matmul"]
+        einsum_template = templates["einsum"]
         gemm_template = templates["gemm"]
         attention_template = templates["attention"]
         conv_template = templates["conv"]
+        conv_transpose_template = templates["conv_transpose"]
         avg_pool_template = templates["avg_pool"]
+        lp_pool_template = templates["lp_pool"]
         batch_norm_template = templates["batch_norm"]
         lp_norm_template = templates["lp_norm"]
         instance_norm_template = templates["instance_norm"]
@@ -2369,16 +2858,20 @@
         lstm_template = templates["lstm"]
         softmax_template = templates["softmax"]
         logsoftmax_template = templates["logsoftmax"]
+        hardmax_template = templates["hardmax"]
         nllloss_template = templates["nllloss"]
         softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
         maxpool_template = templates["maxpool"]
         concat_template = templates["concat"]
         gather_elements_template = templates["gather_elements"]
         gather_template = templates["gather"]
+        gather_nd_template = templates["gather_nd"]
+        scatter_nd_template = templates["scatter_nd"]
         transpose_template = templates["transpose"]
         reshape_template = templates["reshape"]
         identity_template = templates["identity"]
         eye_like_template = templates["eye_like"]
+        trilu_template = templates["trilu"]
         tile_template = templates["tile"]
         pad_template = templates["pad"]
         depth_to_space_template = templates["depth_to_space"]
@@ -2390,12 +2883,15 @@
         reduce_template = templates["reduce"]
         reduce_dynamic_template = templates["reduce_dynamic"]
         arg_reduce_template = templates["arg_reduce"]
+        topk_template = templates["topk"]
         constant_of_shape_template = templates["constant_of_shape"]
         shape_template = templates["shape"]
         size_template = templates["size"]
+        nonzero_template = templates["nonzero"]
         expand_template = templates["expand"]
         cumsum_template = templates["cumsum"]
         range_template = templates["range"]
+        one_hot_template = templates["one_hot"]
         split_template = templates["split"]
         testbench_template = templates.get("testbench")
         reserved_names = {
@@ -2427,11 +2923,15 @@
             unary_template=unary_template,
             clip_template=clip_template,
             cast_template=cast_template,
+            quantize_linear_template=quantize_linear_template,
             matmul_template=matmul_template,
+            einsum_template=einsum_template,
             gemm_template=gemm_template,
             attention_template=attention_template,
             conv_template=conv_template,
+            conv_transpose_template=conv_transpose_template,
             avg_pool_template=avg_pool_template,
+            lp_pool_template=lp_pool_template,
             batch_norm_template=batch_norm_template,
             lp_norm_template=lp_norm_template,
             instance_norm_template=instance_norm_template,
@@ -2443,16 +2943,20 @@
             lstm_template=lstm_template,
             softmax_template=softmax_template,
             logsoftmax_template=logsoftmax_template,
+            hardmax_template=hardmax_template,
             nllloss_template=nllloss_template,
             softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
             maxpool_template=maxpool_template,
             concat_template=concat_template,
             gather_elements_template=gather_elements_template,
             gather_template=gather_template,
+            gather_nd_template=gather_nd_template,
+            scatter_nd_template=scatter_nd_template,
             transpose_template=transpose_template,
             reshape_template=reshape_template,
             identity_template=identity_template,
             eye_like_template=eye_like_template,
+            trilu_template=trilu_template,
             tile_template=tile_template,
             pad_template=pad_template,
             depth_to_space_template=depth_to_space_template,
@@ -2464,12 +2968,15 @@
             reduce_template=reduce_template,
             reduce_dynamic_template=reduce_dynamic_template,
             arg_reduce_template=arg_reduce_template,
+            topk_template=topk_template,
             constant_of_shape_template=constant_of_shape_template,
             shape_template=shape_template,
             size_template=size_template,
+            nonzero_template=nonzero_template,
             expand_template=expand_template,
             cumsum_template=cumsum_template,
             range_template=range_template,
+            one_hot_template=one_hot_template,
             split_template=split_template,
             scalar_registry=scalar_registry,
             dim_args=dim_args,
@@ -2495,25 +3002,45 @@
         scalar_preamble = [
             line for line in scalar_include_lines if not line.startswith("#include ")
         ]
+        testbench_math_include = set()
+        if emit_testbench and self._testbench_requires_math(
+            model, testbench_inputs
+        ):
+            testbench_math_include.add("#include <math.h>")
         includes = self._collect_includes(
             model,
             resolved_ops,
             emit_testbench=emit_testbench,
-            extra_includes=scalar_includes,
+            extra_includes=scalar_includes | testbench_math_include,
+            needs_weight_loader=bool(large_constants),
         )
-        sections = [self._emit_header_comment(model.header), "", *includes]
+        sections = [
+            self._emit_header_comment(model.header),
+            "",
+            *includes,
+            "",
+            self._emit_index_type_define(),
+        ]
         if scalar_preamble:
             sections.extend(("", *scalar_preamble))
         sections.append("")
-        constants_section = self._emit_constant_definitions(model.constants)
+        constants_section = self._emit_constant_definitions(inline_constants)
         if constants_section:
             sections.extend((constants_section.rstrip(), ""))
+        large_constants_section = self._emit_constant_storage_definitions(
+            large_constants
+        )
+        if large_constants_section:
+            sections.extend((large_constants_section.rstrip(), ""))
         if scalar_functions:
             sections.extend(("\n".join(scalar_functions), ""))
+        weight_loader = self._emit_weight_loader(model, large_constants)
         sections.extend(
             (
                 operator_fns.rstrip(),
                 "",
+                weight_loader.rstrip(),
+                "",
                 wrapper_fn,
             )
         )
@@ -2527,6 +3054,7 @@
                     testbench_inputs=testbench_inputs,
                     dim_order=dim_order,
                     dim_values=dim_values,
+                    weight_data_filename=self._weight_data_filename(model),
                 ),
             )
         )
@@ -2549,6 +3077,9 @@
         testbench_inputs = self._sanitize_testbench_inputs(
             testbench_inputs, name_map
         )
+        inline_constants, large_constants = self._partition_constants(
+            model.constants
+        )
         (
             dim_order,
             input_dim_names,
@@ -2574,11 +3105,15 @@
         unary_template = templates["unary"]
         clip_template = templates["clip"]
         cast_template = templates["cast"]
+        quantize_linear_template = templates["quantize_linear"]
         matmul_template = templates["matmul"]
+        einsum_template = templates["einsum"]
         gemm_template = templates["gemm"]
         attention_template = templates["attention"]
         conv_template = templates["conv"]
+        conv_transpose_template = templates["conv_transpose"]
         avg_pool_template = templates["avg_pool"]
+        lp_pool_template = templates["lp_pool"]
         batch_norm_template = templates["batch_norm"]
         lp_norm_template = templates["lp_norm"]
         instance_norm_template = templates["instance_norm"]
@@ -2590,16 +3125,20 @@
         lstm_template = templates["lstm"]
         softmax_template = templates["softmax"]
         logsoftmax_template = templates["logsoftmax"]
+        hardmax_template = templates["hardmax"]
         nllloss_template = templates["nllloss"]
         softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
         maxpool_template = templates["maxpool"]
         concat_template = templates["concat"]
         gather_elements_template = templates["gather_elements"]
         gather_template = templates["gather"]
+        gather_nd_template = templates["gather_nd"]
+        scatter_nd_template = templates["scatter_nd"]
         transpose_template = templates["transpose"]
         reshape_template = templates["reshape"]
         identity_template = templates["identity"]
         eye_like_template = templates["eye_like"]
+        trilu_template = templates["trilu"]
         tile_template = templates["tile"]
         pad_template = templates["pad"]
         depth_to_space_template = templates["depth_to_space"]
@@ -2611,12 +3150,15 @@
         reduce_template = templates["reduce"]
         reduce_dynamic_template = templates["reduce_dynamic"]
         arg_reduce_template = templates["arg_reduce"]
+        topk_template = templates["topk"]
         constant_of_shape_template = templates["constant_of_shape"]
         shape_template = templates["shape"]
         size_template = templates["size"]
+        nonzero_template = templates["nonzero"]
         expand_template = templates["expand"]
         cumsum_template = templates["cumsum"]
         range_template = templates["range"]
+        one_hot_template = templates["one_hot"]
         split_template = templates["split"]
         testbench_template = templates.get("testbench")
         reserved_names = {
@@ -2648,11 +3190,15 @@
             unary_template=unary_template,
             clip_template=clip_template,
             cast_template=cast_template,
+            quantize_linear_template=quantize_linear_template,
             matmul_template=matmul_template,
+            einsum_template=einsum_template,
             gemm_template=gemm_template,
             attention_template=attention_template,
             conv_template=conv_template,
+            conv_transpose_template=conv_transpose_template,
             avg_pool_template=avg_pool_template,
+            lp_pool_template=lp_pool_template,
             batch_norm_template=batch_norm_template,
             lp_norm_template=lp_norm_template,
             instance_norm_template=instance_norm_template,
@@ -2664,16 +3210,20 @@
             lstm_template=lstm_template,
             softmax_template=softmax_template,
             logsoftmax_template=logsoftmax_template,
+            hardmax_template=hardmax_template,
             nllloss_template=nllloss_template,
             softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
             maxpool_template=maxpool_template,
             concat_template=concat_template,
             gather_elements_template=gather_elements_template,
             gather_template=gather_template,
+            gather_nd_template=gather_nd_template,
+            scatter_nd_template=scatter_nd_template,
             transpose_template=transpose_template,
             reshape_template=reshape_template,
             identity_template=identity_template,
             eye_like_template=eye_like_template,
+            trilu_template=trilu_template,
             tile_template=tile_template,
             pad_template=pad_template,
             depth_to_space_template=depth_to_space_template,
@@ -2685,12 +3235,15 @@
             reduce_template=reduce_template,
             reduce_dynamic_template=reduce_dynamic_template,
             arg_reduce_template=arg_reduce_template,
+            topk_template=topk_template,
             constant_of_shape_template=constant_of_shape_template,
             shape_template=shape_template,
             size_template=size_template,
+            nonzero_template=nonzero_template,
             expand_template=expand_template,
             cumsum_template=cumsum_template,
             range_template=range_template,
+            one_hot_template=one_hot_template,
             split_template=split_template,
             scalar_registry=scalar_registry,
             dim_args=dim_args,
@@ -2716,25 +3269,45 @@
         scalar_preamble = [
             line for line in scalar_include_lines if not line.startswith("#include ")
         ]
+        testbench_math_include = set()
+        if emit_testbench and self._testbench_requires_math(
+            model, testbench_inputs
+        ):
+            testbench_math_include.add("#include <math.h>")
         includes = self._collect_includes(
             model,
             resolved_ops,
             emit_testbench=emit_testbench,
-            extra_includes=scalar_includes,
+            extra_includes=scalar_includes | testbench_math_include,
+            needs_weight_loader=bool(large_constants),
         )
-        sections = [self._emit_header_comment(model.header), "", *includes]
+        sections = [
+            self._emit_header_comment(model.header),
+            "",
+            *includes,
+            "",
+            self._emit_index_type_define(),
+        ]
         if scalar_preamble:
             sections.extend(("", *scalar_preamble))
         sections.append("")
-        constants_section = self._emit_constant_declarations(model.constants)
+        constants_section = self._emit_constant_declarations(inline_constants)
         if constants_section:
             sections.extend((constants_section.rstrip(), ""))
+        large_constants_section = self._emit_constant_storage_definitions(
+            large_constants
+        )
+        if large_constants_section:
+            sections.extend((large_constants_section.rstrip(), ""))
         if scalar_functions:
             sections.extend(("\n".join(scalar_functions), ""))
+        weight_loader = self._emit_weight_loader(model, large_constants)
         sections.extend(
             (
                 operator_fns.rstrip(),
                 "",
+                weight_loader.rstrip(),
+                "",
                 wrapper_fn,
             )
         )
@@ -2748,6 +3321,7 @@
                     testbench_inputs=testbench_inputs,
                     dim_order=dim_order,
                     dim_values=dim_values,
+                    weight_data_filename=self._weight_data_filename(model),
                 ),
             )
         )
@@ -2755,14 +3329,14 @@
         main_rendered = "\n".join(sections)
         if not main_rendered.endswith("\n"):
             main_rendered += "\n"
-        data_includes = self._collect_constant_includes(model.constants)
+        data_includes = self._collect_constant_includes(inline_constants)
         data_sections = [self._emit_header_comment(model.header), ""]
         if data_includes:
             data_sections.extend((*data_includes, ""))
         else:
             data_sections.append("")
         data_constants = self._emit_constant_definitions(
-            model.constants, storage_prefix="const"
+            inline_constants, storage_prefix="const"
         )
         if data_constants:
             data_sections.append(data_constants.rstrip())
@@ -2856,6 +3430,23 @@
         comment_lines.append(" */")
         return "\n".join(comment_lines)
 
+    @staticmethod
+    def _emit_constant_comment(constant: ConstTensor, index: int) -> str:
+        shape = constant.shape
+        lines = [
+            f"Weight {index}:",
+            f"Name: {constant.name}",
+            f"Shape: {shape if shape else '[]'}",
+            f"Elements: {CEmitter._element_count(shape)}",
+            f"Dtype: {constant.dtype.onnx_name}",
+        ]
+        comment_lines = ["/*"]
+        comment_lines.extend(
+            f" * {line}" if line else " *" for line in lines
+        )
+        comment_lines.append(" */")
+        return "\n".join(comment_lines)
+
     @staticmethod
     def _collect_constant_includes(constants: tuple[ConstTensor, ...]) -> list[str]:
         if not constants:
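
Note: each emitted weight now carries a banner comment built by _emit_constant_comment. A sketch of the rendered text for a hypothetical 3x3 tensor (the Name and Dtype values below are illustrative, not from the diff):

expected = "\n".join((
    "/*",
    " * Weight 1:",
    " * Name: conv1.weight",
    " * Shape: (3, 3)",
    " * Elements: 9",
    " * Dtype: float",
    " */",
))
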
@@ -2920,6 +3511,7 @@
             ScalarFunction.FMOD,
             ScalarFunction.REMAINDER,
             ScalarFunction.LEAKY_RELU,
+            ScalarFunction.MISH,
             ScalarFunction.MUL,
             ScalarFunction.NEG,
             ScalarFunction.LOGICAL_NOT,
@@ -3005,11 +3597,15 @@
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -3021,16 +3617,20 @@
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | DepthToSpaceOp
         | SpaceToDepthOp
@@ -3039,21 +3639,27 @@
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         ],
         *,
         emit_testbench: bool,
         extra_includes: set[str] | None = None,
+        needs_weight_loader: bool = False,
     ) -> list[str]:
-        includes: set[str] = {"#include <stddef.h>"}
+        includes: set[str] = {"#include <stdint.h>"}
         if emit_testbench:
-            includes.update({"#include <stdio.h>", "#include <stdint.h>"})
+            includes.add("#include <stdio.h>")
+        if needs_weight_loader:
+            includes.add("#include <stdio.h>")
         if extra_includes:
             includes.update(extra_includes)
         if any(
@@ -3074,7 +3680,9 @@
             *constant_of_shape_inputs,
         }
         model_dtypes.update(
-            op.dtype for op in resolved_ops if not isinstance(op, ArgReduceOp)
+            op.dtype
+            for op in resolved_ops
+            if not isinstance(op, (ArgReduceOp, TopKOp))
         )
         arg_reduce_dtypes = {
             dtype
@@ -3083,6 +3691,17 @@
             for dtype in (op.input_dtype, op.output_dtype)
         }
         model_dtypes.update(arg_reduce_dtypes)
+        topk_dtypes = {
+            dtype
+            for op in resolved_ops
+            if isinstance(op, TopKOp)
+            for dtype in (
+                op.input_dtype,
+                op.output_values_dtype,
+                op.output_indices_dtype,
+            )
+        }
+        model_dtypes.update(topk_dtypes)
         slice_input_dtypes = {
             dtype
             for op in resolved_ops
@@ -3095,12 +3714,18 @@
             )
             if dtype is not None
         }
+        trilu_k_dtypes = {
+            op.k_input_dtype
+            for op in resolved_ops
+            if isinstance(op, TriluOp) and op.k_input_dtype is not None
+        }
         maxpool_indices_dtypes = {
             op.indices_dtype
             for op in resolved_ops
             if isinstance(op, MaxPoolOp) and op.indices_dtype is not None
         }
         model_dtypes.update(maxpool_indices_dtypes)
+        model_dtypes.update(trilu_k_dtypes)
         nll_target_dtypes = {
             op.target_dtype
             for op in resolved_ops
@@ -3124,12 +3749,20 @@
             for op in resolved_ops
         ):
             includes.add("#include <stdbool.h>")
+        if any(
+            isinstance(op, SoftmaxCrossEntropyLossOp)
+            and op.ignore_index is not None
+            for op in resolved_ops
+        ):
+            includes.add("#include <stdbool.h>")
         if any(
             isinstance(op, UnaryOp)
             and unary_op_symbol(op.function, dtype=op.dtype) in {"llabs", "abs"}
             for op in resolved_ops
         ):
             includes.add("#include <stdlib.h>")
+        if any(isinstance(op, PadOp) for op in resolved_ops):
+            includes.add("#include <stddef.h>")
         if CEmitter._needs_math(resolved_ops):
             includes.add("#include <math.h>")
         if CEmitter._needs_limits(resolved_ops):
@@ -3140,9 +3773,9 @@
         ):
             includes.add("#include <string.h>")
         ordered_includes = (
-            "#include <stddef.h>",
-            "#include <stdio.h>",
             "#include <stdint.h>",
+            "#include <stdio.h>",
+            "#include <stddef.h>",
             "#include <stdbool.h>",
             "#include <stdlib.h>",
             "#include <math.h>",
@@ -3152,6 +3785,16 @@
         )
         return [include for include in ordered_includes if include in includes]
 
+    @staticmethod
+    def _emit_index_type_define() -> str:
+        return "\n".join(
+            (
+                "#ifndef idx_t",
+                "#define idx_t int32_t",
+                "#endif",
+            )
+        )
+
     @staticmethod
     def _needs_stdint(
         model_dtypes: set[ScalarType],
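
Note: the generated C now opens with an overridable index-type define, rendered verbatim by the helper above:

assert CEmitter._emit_index_type_define() == "\n".join((
    "#ifndef idx_t",
    "#define idx_t int32_t",
    "#endif",
))
# The #ifndef guard lets a build widen loop indices without regenerating
# the source, e.g.: cc -Didx_t=int64_t model.c
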
@@ -3186,11 +3829,15 @@
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -3202,16 +3849,20 @@
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | DepthToSpaceOp
         | SpaceToDepthOp
@@ -3220,12 +3871,15 @@
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         ],
     ) -> bool:
@@ -3322,6 +3976,11 @@
             for op in resolved_ops
         ):
             return True
+        if any(
+            isinstance(op, (LpPoolOp, QuantizeLinearOp))
+            for op in resolved_ops
+        ):
+            return True
         return False
 
     @staticmethod
@@ -3331,11 +3990,15 @@
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -3347,16 +4010,19 @@
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | DepthToSpaceOp
         | SpaceToDepthOp
@@ -3365,12 +4031,15 @@
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         ],
     ) -> bool:
@@ -3400,6 +4069,11 @@
             for op in resolved_ops
         ):
             return True
+        if any(
+            isinstance(op, QuantizeLinearOp) and op.dtype.is_integer
+            for op in resolved_ops
+        ):
+            return True
         return False
 
     def _emit_model_wrapper(
@@ -3411,11 +4085,15 @@
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -3427,16 +4105,19 @@
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
4119
  | EyeLikeOp
4120
+ | TriluOp
3440
4121
  | TileOp
3441
4122
  | DepthToSpaceOp
3442
4123
  | SpaceToDepthOp
@@ -3445,12 +4126,15 @@ class CEmitter:
3445
4126
  | GridSampleOp
3446
4127
  | ReduceOp
3447
4128
  | ArgReduceOp
4129
+ | TopKOp
3448
4130
  | ConstantOfShapeOp
3449
4131
  | ShapeOp
3450
4132
  | SizeOp
4133
+ | NonZeroOp
3451
4134
  | ExpandOp
3452
4135
  | CumSumOp
3453
4136
  | RangeOp
4137
+ | OneHotOp
3454
4138
  | SplitOp
3455
4139
  ],
3456
4140
  temp_buffers: tuple[TempBuffer, ...],
@@ -3480,8 +4164,14 @@ class CEmitter:
3480
4164
  lines = [f"void {model.name}({signature}) {{"]
3481
4165
  for temp in temp_buffers:
3482
4166
  c_type = temp.dtype.c_type
4167
+ storage = (
4168
+ "static "
4169
+ if self._temp_buffer_size_bytes(temp)
4170
+ > self._large_temp_threshold_bytes
4171
+ else ""
4172
+ )
3483
4173
  lines.append(
3484
- f" {c_type} {temp.name}{self._array_suffix(temp.shape)};"
4174
+ f" {storage}{c_type} {temp.name}{self._array_suffix(temp.shape)};"
3485
4175
  )
3486
4176
  for index, op in enumerate(resolved_ops):
3487
4177
  op_name = self._op_function_name(model, index)
@@ -3490,6 +4180,13 @@ class CEmitter:
3490
4180
  lines.append("}")
3491
4181
  return "\n".join(lines)
3492
4182
 
4183
+ @staticmethod
4184
+ def _temp_buffer_size_bytes(temp: TempBuffer) -> int:
4185
+ element_count = 1
4186
+ for dim in temp.shape:
4187
+ element_count *= dim
4188
+ return element_count * temp.dtype.np_dtype.itemsize
4189
+
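Temporaries larger than `_large_temp_threshold_bytes` are now declared `static`, so big intermediates live in the data segment rather than on the stack; the trade-off is that the model function stops being re-entrant for those buffers. A sketch of the decision (the threshold value below is hypothetical, the attribute itself is not shown in this hunk):

    from dataclasses import dataclass
    from math import prod

    LARGE_TEMP_THRESHOLD_BYTES = 16 * 1024  # illustrative value only

    @dataclass
    class Temp:
        name: str
        shape: tuple[int, ...]
        itemsize: int  # bytes per element, e.g. 4 for float

    def declare(temp: Temp, c_type: str) -> str:
        # Element count times element size, as in _temp_buffer_size_bytes.
        size = prod(temp.shape) * temp.itemsize
        storage = "static " if size > LARGE_TEMP_THRESHOLD_BYTES else ""
        suffix = "".join(f"[{dim}]" for dim in temp.shape)
        return f" {storage}{c_type} {temp.name}{suffix};"

    # A 128x128 float buffer (64 KiB) moves off the stack:
    print(declare(Temp("tmp0_x", (128, 128), 4), "float"))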
     @staticmethod
     def _build_op_call(
         op: BinaryOp
@@ -3497,11 +4194,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -3513,16 +4214,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -3532,12 +4237,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
         dim_order: Sequence[str],
     ) -> str:
@@ -3556,6 +4264,9 @@ class CEmitter:
         if isinstance(op, MatMulOp):
             args.extend([op.input0, op.input1, op.output])
             return ", ".join(args)
+        if isinstance(op, EinsumOp):
+            args.extend([*op.inputs, op.output])
+            return ", ".join(args)
         if isinstance(op, GemmOp):
             if op.input_c is None:
                 args.extend([op.input_a, op.input_b, op.output])
@@ -3574,6 +4285,13 @@ class CEmitter:
             call_parts.append(op.output)
             args.extend(call_parts)
             return ", ".join(args)
+        if isinstance(op, QuantizeLinearOp):
+            call_parts = [op.input0, op.scale]
+            if op.zero_point is not None:
+                call_parts.append(op.zero_point)
+            call_parts.append(op.output)
+            args.extend(call_parts)
+            return ", ".join(args)
         if isinstance(op, AttentionOp):
             call_parts = [op.input_q, op.input_k, op.input_v]
             if op.input_attn_mask is not None:
@@ -3599,9 +4317,18 @@ class CEmitter:
                 return ", ".join(args)
             args.extend([op.input0, op.weights, op.bias, op.output])
             return ", ".join(args)
+        if isinstance(op, ConvTransposeOp):
+            if op.bias is None:
+                args.extend([op.input0, op.weights, op.output])
+                return ", ".join(args)
+            args.extend([op.input0, op.weights, op.bias, op.output])
+            return ", ".join(args)
         if isinstance(op, AveragePoolOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
+        if isinstance(op, LpPoolOp):
+            args.extend([op.input0, op.output])
+            return ", ".join(args)
         if isinstance(op, BatchNormOp):
             args.extend(
                 [op.input0, op.scale, op.bias, op.mean, op.variance, op.output]
@@ -3653,7 +4380,7 @@ class CEmitter:
             call_parts.append(op.output_y_c)
             args.extend(call_parts)
             return ", ".join(args)
-        if isinstance(op, (SoftmaxOp, LogSoftmaxOp)):
+        if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
             args.extend([op.input0, op.output])
             return ", ".join(args)
         if isinstance(op, NegativeLogLikelihoodLossOp):
@@ -3684,6 +4411,12 @@ class CEmitter:
         if isinstance(op, GatherOp):
             args.extend([op.data, op.indices, op.output])
             return ", ".join(args)
+        if isinstance(op, GatherNDOp):
+            args.extend([op.data, op.indices, op.output])
+            return ", ".join(args)
+        if isinstance(op, ScatterNDOp):
+            args.extend([op.data, op.indices, op.updates, op.output])
+            return ", ".join(args)
         if isinstance(op, ConcatOp):
             args.extend([*op.inputs, op.output])
             return ", ".join(args)
@@ -3696,9 +4429,18 @@ class CEmitter:
         if isinstance(op, SizeOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
+        if isinstance(op, NonZeroOp):
+            args.extend([op.input0, op.output])
+            return ", ".join(args)
         if isinstance(op, ExpandOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
+        if isinstance(op, TriluOp):
+            call_parts = [op.input0, op.output]
+            if op.k_input is not None:
+                call_parts.append(op.k_input)
+            args.extend(call_parts)
+            return ", ".join(args)
         if isinstance(op, CumSumOp):
             args.append(op.input0)
             if op.axis_input is not None:
@@ -3708,6 +4450,9 @@ class CEmitter:
             return ", ".join(args)
         if isinstance(op, RangeOp):
             args.extend([op.start, op.limit, op.delta, op.output])
+        if isinstance(op, OneHotOp):
+            args.extend([op.indices, op.depth, op.values, op.output])
+            return ", ".join(args)
         if isinstance(op, SplitOp):
             args.extend([op.input0, *op.outputs])
             return ", ".join(args)
@@ -3749,6 +4494,12 @@ class CEmitter:
             call_parts.append(op.output)
             args.extend(call_parts)
             return ", ".join(args)
+        if isinstance(op, TriluOp):
+            call_parts = [op.input0, op.output]
+            if op.k_input is not None:
+                call_parts.append(op.k_input)
+            args.extend(call_parts)
+            return ", ".join(args)
         if isinstance(op, GridSampleOp):
             args.extend([op.input0, op.grid, op.output])
             return ", ".join(args)
@@ -3761,6 +4512,9 @@ class CEmitter:
         if isinstance(op, ArgReduceOp):
             args.extend([op.input0, op.output])
             return ", ".join(args)
+        if isinstance(op, TopKOp):
+            args.extend([op.input0, op.output_values, op.output_indices])
+            return ", ".join(args)
         args.extend([op.input0, op.output])
         return ", ".join(args)

@@ -3792,11 +4546,11 @@ class CEmitter:
             return {}
         if len(intermediates) == 1:
             name, shape, dtype = intermediates[0]
-            temp_name = allocate_temp_name("tmp")
+            temp_name = allocate_temp_name(f"tmp0_{name}")
             return {name: TempBuffer(name=temp_name, shape=shape, dtype=dtype)}
         return {
             name: TempBuffer(
-                name=allocate_temp_name(f"tmp{index}"),
+                name=allocate_temp_name(f"tmp{index}_{name}"),
                 shape=shape,
                 dtype=dtype,
             )
@@ -3811,11 +4565,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -3827,16 +4585,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | DepthToSpaceOp
         | SpaceToDepthOp
@@ -3845,12 +4607,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
         temp_map: dict[str, str],
     ) -> (
@@ -3860,11 +4625,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -3876,16 +4645,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | DepthToSpaceOp
         | SpaceToDepthOp
@@ -3894,12 +4667,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
     ):
         if isinstance(op, BinaryOp):
@@ -3909,6 +4685,8 @@ class CEmitter:
                 output=temp_map.get(op.output, op.output),
                 function=op.function,
                 operator_kind=op.operator_kind,
+                input0_shape=op.input0_shape,
+                input1_shape=op.input1_shape,
                 shape=op.shape,
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
@@ -3979,6 +4757,16 @@ class CEmitter:
                 right_vector=op.right_vector,
                 dtype=op.dtype,
             )
+        if isinstance(op, EinsumOp):
+            return EinsumOp(
+                inputs=tuple(temp_map.get(name, name) for name in op.inputs),
+                output=temp_map.get(op.output, op.output),
+                kind=op.kind,
+                input_shapes=op.input_shapes,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, CastOp):
             return CastOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -3987,6 +4775,22 @@ class CEmitter:
                 input_dtype=op.input_dtype,
                 dtype=op.dtype,
             )
+        if isinstance(op, QuantizeLinearOp):
+            return QuantizeLinearOp(
+                input0=temp_map.get(op.input0, op.input0),
+                scale=temp_map.get(op.scale, op.scale),
+                zero_point=(
+                    temp_map.get(op.zero_point, op.zero_point)
+                    if op.zero_point is not None
+                    else None
+                ),
+                output=temp_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                axis=op.axis,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+                scale_dtype=op.scale_dtype,
+            )
         if isinstance(op, GemmOp):
             return GemmOp(
                 input_a=temp_map.get(op.input_a, op.input_a),
@@ -4160,6 +4964,26 @@ class CEmitter:
                 group=op.group,
                 dtype=op.dtype,
             )
+        if isinstance(op, ConvTransposeOp):
+            return ConvTransposeOp(
+                input0=temp_map.get(op.input0, op.input0),
+                weights=temp_map.get(op.weights, op.weights),
+                bias=temp_map.get(op.bias, op.bias) if op.bias else None,
+                output=temp_map.get(op.output, op.output),
+                batch=op.batch,
+                in_channels=op.in_channels,
+                out_channels=op.out_channels,
+                spatial_rank=op.spatial_rank,
+                in_spatial=op.in_spatial,
+                out_spatial=op.out_spatial,
+                kernel_shape=op.kernel_shape,
+                strides=op.strides,
+                pads=op.pads,
+                dilations=op.dilations,
+                output_padding=op.output_padding,
+                group=op.group,
+                dtype=op.dtype,
+            )
         if isinstance(op, AveragePoolOp):
             return AveragePoolOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4181,6 +5005,27 @@ class CEmitter:
                 count_include_pad=op.count_include_pad,
                 dtype=op.dtype,
             )
+        if isinstance(op, LpPoolOp):
+            return LpPoolOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output=temp_map.get(op.output, op.output),
+                batch=op.batch,
+                channels=op.channels,
+                in_h=op.in_h,
+                in_w=op.in_w,
+                out_h=op.out_h,
+                out_w=op.out_w,
+                kernel_h=op.kernel_h,
+                kernel_w=op.kernel_w,
+                stride_h=op.stride_h,
+                stride_w=op.stride_w,
+                pad_top=op.pad_top,
+                pad_left=op.pad_left,
+                pad_bottom=op.pad_bottom,
+                pad_right=op.pad_right,
+                p=op.p,
+                dtype=op.dtype,
+            )
         if isinstance(op, BatchNormOp):
             return BatchNormOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4318,6 +5163,17 @@ class CEmitter:
                 shape=op.shape,
                 dtype=op.dtype,
             )
+        if isinstance(op, HardmaxOp):
+            return HardmaxOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output=temp_map.get(op.output, op.output),
+                outer=op.outer,
+                axis_size=op.axis_size,
+                inner=op.inner,
+                axis=op.axis,
+                shape=op.shape,
+                dtype=op.dtype,
+            )
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return NegativeLogLikelihoodLossOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4419,6 +5275,32 @@ class CEmitter:
                 dtype=op.dtype,
                 indices_dtype=op.indices_dtype,
             )
+        if isinstance(op, GatherNDOp):
+            return GatherNDOp(
+                data=temp_map.get(op.data, op.data),
+                indices=temp_map.get(op.indices, op.indices),
+                output=temp_map.get(op.output, op.output),
+                batch_dims=op.batch_dims,
+                data_shape=op.data_shape,
+                indices_shape=op.indices_shape,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+            )
+        if isinstance(op, ScatterNDOp):
+            return ScatterNDOp(
+                data=temp_map.get(op.data, op.data),
+                indices=temp_map.get(op.indices, op.indices),
+                updates=temp_map.get(op.updates, op.updates),
+                output=temp_map.get(op.output, op.output),
+                data_shape=op.data_shape,
+                indices_shape=op.indices_shape,
+                updates_shape=op.updates_shape,
+                output_shape=op.output_shape,
+                reduction=op.reduction,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+            )
         if isinstance(op, ConcatOp):
             return ConcatOp(
                 inputs=tuple(temp_map.get(name, name) for name in op.inputs),
@@ -4458,6 +5340,15 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, NonZeroOp):
+            return NonZeroOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output=temp_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, ExpandOp):
             return ExpandOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4493,6 +5384,21 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, OneHotOp):
+            return OneHotOp(
+                indices=temp_map.get(op.indices, op.indices),
+                depth=temp_map.get(op.depth, op.depth),
+                values=temp_map.get(op.values, op.values),
+                output=temp_map.get(op.output, op.output),
+                axis=op.axis,
+                indices_shape=op.indices_shape,
+                values_shape=op.values_shape,
+                output_shape=op.output_shape,
+                depth_dim=op.depth_dim,
+                dtype=op.dtype,
+                indices_dtype=op.indices_dtype,
+                depth_dtype=op.depth_dtype,
+            )
         if isinstance(op, SplitOp):
             return SplitOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4542,6 +5448,24 @@ class CEmitter:
                 dtype=op.dtype,
                 input_dtype=op.input_dtype,
             )
+        if isinstance(op, TriluOp):
+            return TriluOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output=temp_map.get(op.output, op.output),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                upper=op.upper,
+                k_value=op.k_value,
+                k_input=(
+                    temp_map.get(op.k_input, op.k_input)
+                    if op.k_input is not None
+                    else None
+                ),
+                k_input_shape=op.k_input_shape,
+                k_input_dtype=op.k_input_dtype,
+                dtype=op.dtype,
+                input_dtype=op.input_dtype,
+            )
         if isinstance(op, TileOp):
             return TileOp(
                 input0=temp_map.get(op.input0, op.input0),
@@ -4726,6 +5650,21 @@ class CEmitter:
                 input_dtype=op.input_dtype,
                 output_dtype=op.output_dtype,
             )
+        if isinstance(op, TopKOp):
+            return TopKOp(
+                input0=temp_map.get(op.input0, op.input0),
+                output_values=temp_map.get(op.output_values, op.output_values),
+                output_indices=temp_map.get(op.output_indices, op.output_indices),
+                input_shape=op.input_shape,
+                output_shape=op.output_shape,
+                axis=op.axis,
+                k=op.k,
+                largest=op.largest,
+                sorted=op.sorted,
+                input_dtype=op.input_dtype,
+                output_values_dtype=op.output_values_dtype,
+                output_indices_dtype=op.output_indices_dtype,
+            )
         return UnaryOp(
             input0=temp_map.get(op.input0, op.input0),
             output=temp_map.get(op.output, op.output),
@@ -4743,11 +5682,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -4759,16 +5702,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | DepthToSpaceOp
         | SpaceToDepthOp
@@ -4777,12 +5724,15 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
         index: int,
         *,
@@ -4798,11 +5748,15 @@ class CEmitter:
         unary_template,
         clip_template,
         cast_template,
+        quantize_linear_template,
         matmul_template,
+        einsum_template,
         gemm_template,
         attention_template,
         conv_template,
+        conv_transpose_template,
         avg_pool_template,
+        lp_pool_template,
         batch_norm_template,
         lp_norm_template,
         instance_norm_template,
@@ -4814,16 +5768,20 @@ class CEmitter:
         lstm_template,
         softmax_template,
         logsoftmax_template,
+        hardmax_template,
         nllloss_template,
         softmax_cross_entropy_loss_template,
         maxpool_template,
         concat_template,
         gather_elements_template,
         gather_template,
+        gather_nd_template,
+        scatter_nd_template,
         transpose_template,
         reshape_template,
         identity_template,
         eye_like_template,
+        trilu_template,
         tile_template,
         pad_template,
         depth_to_space_template,
@@ -4835,12 +5793,15 @@ class CEmitter:
         reduce_template,
         reduce_dynamic_template,
         arg_reduce_template,
+        topk_template,
         constant_of_shape_template,
         shape_template,
         size_template,
+        nonzero_template,
         expand_template,
         cumsum_template,
         range_template,
+        one_hot_template,
         split_template,
         scalar_registry: ScalarFunctionRegistry | None = None,
         dim_args: str = "",
@@ -4885,21 +5846,27 @@ class CEmitter:
             output_dim_names = _dim_names_for(op.output)
             shape = CEmitter._shape_dim_exprs(op.shape, output_dim_names)
             loop_vars = CEmitter._loop_vars(op.shape)
-            array_suffix = self._param_array_suffix(op.shape, output_dim_names)
+            output_suffix = self._param_array_suffix(op.shape, output_dim_names)
+            input0_suffix = self._param_array_suffix(
+                op.input0_shape, _dim_names_for(op.input0)
+            )
+            input1_suffix = self._param_array_suffix(
+                op.input1_shape, _dim_names_for(op.input1)
+            )
             input_c_type = op.input_dtype.c_type
             output_c_type = op.dtype.c_type
             param_decls = self._build_param_decls(
                 [
-                    (params["input0"], input_c_type, array_suffix, True),
-                    (params["input1"], input_c_type, array_suffix, True),
-                    (params["output"], output_c_type, array_suffix, False),
+                    (params["input0"], input_c_type, input0_suffix, True),
+                    (params["input1"], input_c_type, input1_suffix, True),
+                    (params["output"], output_c_type, output_suffix, False),
                 ]
             )
             common = {
                 "model_name": model.name,
                 "op_name": op_name,
                 "element_count": CEmitter._element_count_expr(shape),
-                "array_suffix": array_suffix,
+                "array_suffix": output_suffix,
                 "shape": shape,
                 "loop_vars": loop_vars,
                 "input_c_type": input_c_type,
@@ -4908,11 +5875,17 @@ class CEmitter:
                 "dim_args": dim_args,
                 "params": param_decls,
             }
-            left_expr = f"{params['input0']}" + "".join(
-                f"[{var}]" for var in loop_vars
+            left_expr = CEmitter._broadcast_index_expr(
+                params["input0"],
+                op.input0_shape,
+                op.shape,
+                loop_vars,
             )
-            right_expr = f"{params['input1']}" + "".join(
-                f"[{var}]" for var in loop_vars
+            right_expr = CEmitter._broadcast_index_expr(
+                params["input1"],
+                op.input1_shape,
+                op.shape,
+                loop_vars,
             )
             operator_expr = None
             operator = op_spec.operator
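BinaryOp now records both input shapes and indexes each operand through `CEmitter._broadcast_index_expr`, so a size-1 dimension reads `[0]` while matching dimensions reuse the output loop variable, giving NumPy-style broadcasting instead of assuming both inputs share the output shape. A sketch of the indexing rule under the usual right-aligned convention (the real helper may differ in detail):

    def broadcast_index_expr(name, input_shape, output_shape, loop_vars):
        # Right-align the input's dims against the output's; a size-1 dim
        # is pinned to [0], and leading dims the input lacks are skipped.
        offset = len(output_shape) - len(input_shape)
        parts = [
            "[0]" if dim == 1 else f"[{loop_vars[axis + offset]}]"
            for axis, dim in enumerate(input_shape)
        ]
        return name + "".join(parts)

    # A (2, 1, 4) operand broadcast against a (2, 3, 4) output:
    assert broadcast_index_expr("a", (2, 1, 4), (2, 3, 4), ("i0", "i1", "i2")) == "a[i0][0][i2]"
    assert broadcast_index_expr("b", (4,), (2, 3, 4), ("i0", "i1", "i2")) == "b[i2]"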
@@ -5177,39 +6150,170 @@ class CEmitter:
                 k=op.k,
             ).rstrip()
             return with_node_comment(rendered)
-        if isinstance(op, GemmOp):
+        if isinstance(op, EinsumOp):
             params = self._shared_param_map(
                 [
-                    ("input_a", op.input_a),
-                    ("input_b", op.input_b),
-                    ("input_c", op.input_c),
+                    *(
+                        (f"input{idx}", name)
+                        for idx, name in enumerate(op.inputs)
+                    ),
                     ("output", op.output),
                 ]
             )
-            input_a_shape = (op.k, op.m) if op.trans_a else (op.m, op.k)
-            input_b_shape = (op.n, op.k) if op.trans_b else (op.k, op.n)
-            input_a_suffix = self._param_array_suffix(input_a_shape)
-            input_b_suffix = self._param_array_suffix(input_b_shape)
-            output_suffix = self._param_array_suffix((op.m, op.n))
-            c_suffix = (
-                self._param_array_suffix(op.c_shape)
-                if op.c_shape is not None
-                else ""
+            output_dim_names = _dim_names_for(op.output)
+            output_shape = CEmitter._shape_dim_exprs(
+                op.output_shape, output_dim_names
             )
+            output_loop_vars = CEmitter._loop_vars(op.output_shape)
+            if output_loop_vars:
+                output_expr = f"{params['output']}" + "".join(
+                    f"[{var}]" for var in output_loop_vars
+                )
+            else:
+                output_expr = f"{params['output']}[0]"
+            input_shapes = op.input_shapes
+            input_dim_names = [
+                _dim_names_for(name) for name in op.inputs
+            ]
+            input_suffixes = [
+                self._param_array_suffix(shape, dim_names)
+                for shape, dim_names in zip(input_shapes, input_dim_names)
+            ]
             param_decls = self._build_param_decls(
                 [
-                    (params["input_a"], c_type, input_a_suffix, True),
-                    (params["input_b"], c_type, input_b_suffix, True),
+                    *(
+                        (
+                            params[f"input{idx}"],
+                            op.input_dtype.c_type,
+                            input_suffixes[idx],
+                            True,
+                        )
+                        for idx in range(len(op.inputs))
+                    ),
                     (
-                        params["input_c"],
-                        c_type,
-                        c_suffix,
-                        True,
-                    )
-                    if params["input_c"]
-                    else (None, "", "", True),
-                    (params["output"], c_type, output_suffix, False),
-                ]
+                        params["output"],
+                        op.dtype.c_type,
+                        self._param_array_suffix(op.output_shape, output_dim_names),
+                        False,
+                    ),
+                ]
+            )
+            input_loop_vars: tuple[str, ...] = ()
+            input_loop_bounds: tuple[str | int, ...] = ()
+            reduce_loop_var = "k"
+            reduce_loop_bound: str | int | None = None
+            input_expr = None
+            input0_expr = None
+            input1_expr = None
+            if op.kind == EinsumKind.REDUCE_ALL:
+                input_loop_vars = CEmitter._loop_vars(input_shapes[0])
+                input_loop_bounds = tuple(
+                    CEmitter._shape_dim_exprs(
+                        input_shapes[0], input_dim_names[0]
+                    )
+                )
+                if input_loop_vars:
+                    input_expr = f"{params['input0']}" + "".join(
+                        f"[{var}]" for var in input_loop_vars
+                    )
+                else:
+                    input_expr = f"{params['input0']}[0]"
+            elif op.kind == EinsumKind.SUM_J:
+                input_shape_exprs = CEmitter._shape_dim_exprs(
+                    input_shapes[0], input_dim_names[0]
+                )
+                reduce_loop_bound = input_shape_exprs[1]
+                input_expr = (
+                    f"{params['input0']}"
+                    f"[{output_loop_vars[0]}][{reduce_loop_var}]"
+                )
+            elif op.kind == EinsumKind.TRANSPOSE:
+                input_expr = (
+                    f"{params['input0']}"
+                    f"[{output_loop_vars[1]}][{output_loop_vars[0]}]"
+                )
+            elif op.kind == EinsumKind.DOT:
+                input_shape_exprs = CEmitter._shape_dim_exprs(
+                    input_shapes[0], input_dim_names[0]
+                )
+                reduce_loop_bound = input_shape_exprs[0]
+                input0_expr = f"{params['input0']}[{reduce_loop_var}]"
+                input1_expr = f"{params['input1']}[{reduce_loop_var}]"
+            elif op.kind == EinsumKind.BATCH_MATMUL:
+                input_shape_exprs = CEmitter._shape_dim_exprs(
+                    input_shapes[0], input_dim_names[0]
+                )
+                reduce_loop_bound = input_shape_exprs[2]
+                input0_expr = (
+                    f"{params['input0']}"
+                    f"[{output_loop_vars[0]}]"
+                    f"[{output_loop_vars[1]}][{reduce_loop_var}]"
+                )
+                input1_expr = (
+                    f"{params['input1']}"
+                    f"[{output_loop_vars[0]}]"
+                    f"[{reduce_loop_var}][{output_loop_vars[2]}]"
+                )
+            elif op.kind == EinsumKind.BATCH_DIAGONAL:
+                diag_var = output_loop_vars[-1]
+                prefix_vars = output_loop_vars[:-1]
+                input_expr = f"{params['input0']}" + "".join(
+                    f"[{var}]" for var in prefix_vars
+                )
+                input_expr += f"[{diag_var}][{diag_var}]"
+            rendered = einsum_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                params=param_decls,
+                dim_args=dim_args,
+                kind=op.kind.value,
+                output_loop_vars=output_loop_vars,
+                output_loop_bounds=output_shape,
+                output_expr=output_expr,
+                acc_type=op.dtype.c_type,
+                zero_literal=zero_literal,
+                input_loop_vars=input_loop_vars,
+                input_loop_bounds=input_loop_bounds,
+                reduce_loop_var=reduce_loop_var,
+                reduce_loop_bound=reduce_loop_bound,
+                input_expr=input_expr,
+                input0_expr=input0_expr,
+                input1_expr=input1_expr,
+            ).rstrip()
+            return with_node_comment(rendered)
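The Einsum lowering is pattern-based: only the equation shapes enumerated by `EinsumKind` are handled, and each kind fixes one loop nest in the template. A NumPy reference for what those nests compute (the kind names below mirror the branches above; the enum's string values are an assumption):

    import numpy as np

    def einsum_kind_reference(kind: str, *inputs: np.ndarray) -> np.ndarray:
        if kind == "reduce_all":      # e.g. "ij->": sum everything
            return np.asarray(inputs[0].sum())
        if kind == "sum_j":           # "ij->i": reduce the second axis
            return inputs[0].sum(axis=1)
        if kind == "transpose":       # "ij->ji"
            return inputs[0].T.copy()
        if kind == "dot":             # "i,i->"
            return np.asarray((inputs[0] * inputs[1]).sum())
        if kind == "batch_matmul":    # "bij,bjk->bik"
            return inputs[0] @ inputs[1]
        if kind == "batch_diagonal":  # "...ii->...i"
            return np.diagonal(inputs[0], axis1=-2, axis2=-1).copy()
        raise ValueError(f"unsupported kind: {kind}")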
+        if isinstance(op, GemmOp):
+            params = self._shared_param_map(
+                [
+                    ("input_a", op.input_a),
+                    ("input_b", op.input_b),
+                    ("input_c", op.input_c),
+                    ("output", op.output),
+                ]
+            )
+            input_a_shape = (op.k, op.m) if op.trans_a else (op.m, op.k)
+            input_b_shape = (op.n, op.k) if op.trans_b else (op.k, op.n)
+            input_a_suffix = self._param_array_suffix(input_a_shape)
+            input_b_suffix = self._param_array_suffix(input_b_shape)
+            output_suffix = self._param_array_suffix((op.m, op.n))
+            c_suffix = (
+                self._param_array_suffix(op.c_shape)
+                if op.c_shape is not None
+                else ""
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["input_a"], c_type, input_a_suffix, True),
+                    (params["input_b"], c_type, input_b_suffix, True),
+                    (
+                        params["input_c"],
+                        c_type,
+                        c_suffix,
+                        True,
+                    )
+                    if params["input_c"]
+                    else (None, "", "", True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
             )
             alpha_literal = CEmitter._format_literal(op.dtype, op.alpha)
             beta_literal = CEmitter._format_literal(op.dtype, op.beta)
@@ -5556,6 +6660,81 @@ class CEmitter:
                 in_indices=in_indices,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, ConvTransposeOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("weights", op.weights),
+                    ("bias", op.bias),
+                    ("output", op.output),
+                ]
+            )
+            input_shape = (op.batch, op.in_channels, *op.in_spatial)
+            weight_shape = (
+                op.in_channels,
+                op.out_channels // op.group,
+                *op.kernel_shape,
+            )
+            output_shape = (op.batch, op.out_channels, *op.out_spatial)
+            in_indices = tuple(f"id{dim}" for dim in range(op.spatial_rank))
+            kernel_indices = tuple(
+                f"kd{dim}" for dim in range(op.spatial_rank)
+            )
+            out_indices = tuple(f"od{dim}" for dim in range(op.spatial_rank))
+            pad_begin = op.pads[: op.spatial_rank]
+            group_in_channels = op.in_channels // op.group
+            group_out_channels = op.out_channels // op.group
+            input_suffix = self._param_array_suffix(input_shape)
+            weight_suffix = self._param_array_suffix(weight_shape)
+            bias_suffix = self._param_array_suffix((op.out_channels,))
+            output_suffix = self._param_array_suffix(output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], c_type, input_suffix, True),
+                    (params["weights"], c_type, weight_suffix, True),
+                    (
+                        params["bias"],
+                        c_type,
+                        bias_suffix,
+                        True,
+                    )
+                    if params["bias"]
+                    else (None, "", "", True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = conv_transpose_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                weights=params["weights"],
+                bias=params["bias"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                zero_literal=zero_literal,
+                input_suffix=input_suffix,
+                weight_suffix=weight_suffix,
+                bias_suffix=bias_suffix,
+                output_suffix=output_suffix,
+                batch=op.batch,
+                in_channels=op.in_channels,
+                out_channels=op.out_channels,
+                spatial_rank=op.spatial_rank,
+                in_spatial=op.in_spatial,
+                out_spatial=op.out_spatial,
+                kernel_shape=op.kernel_shape,
+                strides=op.strides,
+                pads_begin=pad_begin,
+                dilations=op.dilations,
+                group=op.group,
+                group_in_channels=group_in_channels,
+                group_out_channels=group_out_channels,
+                in_indices=in_indices,
+                kernel_indices=kernel_indices,
+                out_indices=out_indices,
+            ).rstrip()
+            return with_node_comment(rendered)
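`out_spatial` arrives precomputed on the op, but the relation the lowering must satisfy per spatial axis is the standard ONNX ConvTranspose one, reproduced here for reference (standard formula, not taken from this diff):

    def conv_transpose_out_dim(in_dim, stride, kernel, dilation=1,
                               pad_begin=0, pad_end=0, output_padding=0):
        # Standard ONNX ConvTranspose output-shape relation per axis.
        return ((in_dim - 1) * stride - pad_begin - pad_end
                + dilation * (kernel - 1) + output_padding + 1)

    # A 4-wide input, stride 2, 3-tap kernel, no padding: 9 outputs.
    assert conv_transpose_out_dim(4, 2, 3) == 9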
         if isinstance(op, AveragePoolOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
@@ -5597,6 +6776,49 @@ class CEmitter:
                 count_include_pad=int(op.count_include_pad),
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, LpPoolOp):
+            params = self._shared_param_map(
+                [("input0", op.input0), ("output", op.output)]
+            )
+            input_shape = (op.batch, op.channels, op.in_h, op.in_w)
+            output_shape = (op.batch, op.channels, op.out_h, op.out_w)
+            input_suffix = self._param_array_suffix(input_shape)
+            output_suffix = self._param_array_suffix(output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], c_type, input_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = lp_pool_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                batch=op.batch,
+                channels=op.channels,
+                in_h=op.in_h,
+                in_w=op.in_w,
+                out_h=op.out_h,
+                out_w=op.out_w,
+                kernel_h=op.kernel_h,
+                kernel_w=op.kernel_w,
+                stride_h=op.stride_h,
+                stride_w=op.stride_w,
+                pad_top=op.pad_top,
+                pad_left=op.pad_left,
+                pad_bottom=op.pad_bottom,
+                pad_right=op.pad_right,
+                p=op.p,
+                zero_literal=zero_literal,
+                abs_fn=CEmitter._math_fn(op.dtype, "fabsf", "fabs"),
+                pow_fn=CEmitter._math_fn(op.dtype, "powf", "pow"),
+            ).rstrip()
+            return with_node_comment(rendered)
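LpPool reduces each window with the p-norm, which is why the template pulls `abs_fn` and `pow_fn` from <math.h>: y = (sum |x|^p)^(1/p) over the window. A small reference for one output element, with padded taps contributing nothing (a sketch, not the emitted loop):

    def lp_pool_element(x, oh, ow, kernel_h, kernel_w,
                        stride_h, stride_w, pad_top, pad_left, p):
        # x is a single (in_h, in_w) channel as a list of lists.
        in_h, in_w = len(x), len(x[0])
        acc = 0.0
        for kh in range(kernel_h):
            for kw in range(kernel_w):
                ih = oh * stride_h + kh - pad_top
                iw = ow * stride_w + kw - pad_left
                if 0 <= ih < in_h and 0 <= iw < in_w:
                    acc += abs(x[ih][iw]) ** p
        return acc ** (1.0 / p)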
         if isinstance(op, BatchNormOp):
             params = self._shared_param_map(
                 [
@@ -5769,6 +6991,19 @@ class CEmitter:
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, LayerNormalizationOp):
+            acc_dtype = (
+                ScalarType.F32
+                if op.dtype in {ScalarType.F16, ScalarType.F32}
+                else op.dtype
+            )
+            acc_type = acc_dtype.c_type
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
+            acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
+            acc_epsilon_literal = CEmitter._format_floating(
+                op.epsilon, acc_dtype
+            )
+            acc_sqrt_fn = CEmitter._math_fn(acc_dtype, "sqrtf", "sqrt")
+            use_kahan = op.dtype in {ScalarType.F16, ScalarType.F32}
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -5878,8 +7113,12 @@ class CEmitter:
                 bias_index_vars=bias_index_vars,
                 mean_index_vars=mean_index_vars,
                 inner=op.inner,
-                epsilon_literal=CEmitter._format_floating(op.epsilon, op.dtype),
-                sqrt_fn=CEmitter._math_fn(op.dtype, "sqrtf", "sqrt"),
+                acc_type=acc_type,
+                acc_zero_literal=acc_zero_literal,
+                acc_one_literal=acc_one_literal,
+                acc_epsilon_literal=acc_epsilon_literal,
+                acc_sqrt_fn=acc_sqrt_fn,
+                use_kahan=use_kahan,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, MeanVarianceNormalizationOp):
@@ -6244,7 +7483,41 @@ class CEmitter:
                 log_fn=CEmitter._math_fn(op.dtype, "logf", "log"),
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, HardmaxOp):
+            params = self._shared_param_map(
+                [("input0", op.input0), ("output", op.output)]
+            )
+            array_suffix = self._param_array_suffix(op.shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], c_type, array_suffix, True),
+                    (params["output"], c_type, array_suffix, False),
+                ]
+            )
+            rendered = hardmax_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                array_suffix=array_suffix,
+                outer=op.outer,
+                axis_size=op.axis_size,
+                inner=op.inner,
+                zero_literal=zero_literal,
+                one_literal=CEmitter._format_literal(op.dtype, 1),
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, NegativeLogLikelihoodLossOp):
+            acc_dtype = (
+                ScalarType.F64
+                if op.dtype in {ScalarType.F16, ScalarType.F32}
+                else op.dtype
+            )
+            acc_type = acc_dtype.c_type
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
+            acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -6292,9 +7565,22 @@ class CEmitter:
                 ignore_index=op.ignore_index,
                 zero_literal=zero_literal,
                 one_literal=CEmitter._format_literal(op.dtype, 1),
+                acc_type=acc_type,
+                acc_zero_literal=acc_zero_literal,
+                acc_one_literal=acc_one_literal,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, SoftmaxCrossEntropyLossOp):
+            acc_dtype = (
+                ScalarType.F64
+                if op.dtype in {ScalarType.F16, ScalarType.F32}
+                else op.dtype
+            )
+            acc_type = acc_dtype.c_type
+            acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
+            acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
+            acc_exp_fn = CEmitter._math_fn(acc_dtype, "expf", "exp")
+            acc_log_fn = CEmitter._math_fn(acc_dtype, "logf", "log")
             params = self._shared_param_map(
                 [
                     ("input0", op.input0),
@@ -6361,8 +7647,11 @@ class CEmitter:
                 ignore_index=ignore_index,
                 zero_literal=zero_literal,
                 one_literal=CEmitter._format_literal(op.dtype, 1),
-                exp_fn=CEmitter._math_fn(op.dtype, "expf", "exp"),
-                log_fn=CEmitter._math_fn(op.dtype, "logf", "log"),
+                acc_type=acc_type,
+                acc_zero_literal=acc_zero_literal,
+                acc_one_literal=acc_one_literal,
+                acc_exp_fn=acc_exp_fn,
+                acc_log_fn=acc_log_fn,
             ).rstrip()
             return with_node_comment(rendered)
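Both loss emitters now accumulate in a wider type (float64 when the tensor dtype is f16/f32) and only cast back at the end, so the reduction over many samples does not drift as terms are folded in. The failure mode being guarded against, in miniature:

    import numpy as np

    losses = np.full(100_000, 1.0e-3, dtype=np.float32)  # true sum: 100.0

    acc32 = np.float32(0.0)
    for v in losses:                 # narrow f32 accumulator
        acc32 += v
    acc64 = np.sum(losses, dtype=np.float64)  # widened accumulator

    # The f32 running sum drifts from the f64 one; the widened
    # accumulator stays at the mathematically expected value.
    print(float(acc32), float(acc64))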
         if isinstance(op, MaxPoolOp):
@@ -6569,6 +7858,180 @@ class CEmitter:
                 axis_dim=op.data_shape[op.axis],
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, GatherNDOp):
+            params = self._shared_param_map(
+                [
+                    ("data", op.data),
+                    ("indices", op.indices),
+                    ("output", op.output),
+                ]
+            )
+            indices_dim_names = _dim_names_for(op.indices)
+            data_dim_names = _dim_names_for(op.data)
+            data_shape = CEmitter._shape_dim_exprs(op.data_shape, data_dim_names)
+            indices_shape = CEmitter._shape_dim_exprs(
+                op.indices_shape, indices_dim_names
+            )
+            indices_prefix_shape = indices_shape[:-1]
+            indices_prefix_loop_vars = (
+                CEmitter._loop_vars(op.indices_shape[:-1])
+                if op.indices_shape[:-1]
+                else ()
+            )
+            index_depth = op.indices_shape[-1]
+            tail_shape = data_shape[op.batch_dims + index_depth :]
+            tail_loop_vars = (
+                tuple(f"t{index}" for index in range(len(tail_shape)))
+                if tail_shape
+                else ()
+            )
+            output_loop_vars = (*indices_prefix_loop_vars, *tail_loop_vars)
+            if output_loop_vars:
+                output_index_expr = params["output"] + "".join(
+                    f"[{var}]" for var in output_loop_vars
+                )
+            else:
+                output_index_expr = f"{params['output']}[0]"
+            data_index_vars = (
+                *indices_prefix_loop_vars[: op.batch_dims],
+                *tuple(f"index{idx}" for idx in range(index_depth)),
+                *tail_loop_vars,
+            )
+            data_index_expr = params["data"] + "".join(
+                f"[{var}]" for var in data_index_vars
+            )
+            data_suffix = self._param_array_suffix(op.data_shape)
+            indices_suffix = self._param_array_suffix(op.indices_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["data"], c_type, data_suffix, True),
+                    (
+                        params["indices"],
+                        op.indices_dtype.c_type,
+                        indices_suffix,
+                        True,
+                    ),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = gather_nd_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                data=params["data"],
+                indices=params["indices"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                data_suffix=data_suffix,
+                indices_suffix=indices_suffix,
+                output_suffix=output_suffix,
+                indices_prefix_shape=indices_prefix_shape,
+                indices_prefix_loop_vars=indices_prefix_loop_vars,
+                index_depth=index_depth,
+                tail_shape=tail_shape,
+                tail_loop_vars=tail_loop_vars,
+                output_index_expr=output_index_expr,
+                data_index_expr=data_index_expr,
+                batch_dims=op.batch_dims,
+                data_shape=data_shape,
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, ScatterNDOp):
+            params = self._shared_param_map(
+                [
+                    ("data", op.data),
+                    ("indices", op.indices),
+                    ("updates", op.updates),
+                    ("output", op.output),
+                ]
+            )
+            output_dim_names = _dim_names_for(op.output)
+            indices_dim_names = _dim_names_for(op.indices)
+            updates_dim_names = _dim_names_for(op.updates)
+            data_dim_names = _dim_names_for(op.data)
+            output_shape = CEmitter._shape_dim_exprs(
+                op.output_shape, output_dim_names
+            )
+            data_shape = CEmitter._shape_dim_exprs(op.data_shape, data_dim_names)
+            indices_shape = CEmitter._shape_dim_exprs(
+                op.indices_shape, indices_dim_names
+            )
+            output_loop_vars = CEmitter._loop_vars(op.output_shape)
+            indices_prefix_shape = indices_shape[:-1]
+            indices_prefix_loop_vars = (
+                CEmitter._loop_vars(op.indices_shape[:-1])
+                if op.indices_shape[:-1]
+                else ()
+            )
+            index_depth = op.indices_shape[-1]
+            tail_shape = output_shape[index_depth:]
+            tail_loop_vars = (
+                tuple(
+                    f"t{index}"
+                    for index in range(len(op.output_shape[index_depth:]))
+                )
+                if op.output_shape[index_depth:]
+                else ()
+            )
+            index_vars = tuple(f"index{idx}" for idx in range(index_depth))
+            output_index_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in (*index_vars, *tail_loop_vars)
+            )
+            updates_index_vars = (*indices_prefix_loop_vars, *tail_loop_vars)
+            if not op.updates_shape:
+                updates_index_expr = f"{params['updates']}[0]"
+            else:
+                updates_index_expr = f"{params['updates']}" + "".join(
+                    f"[{var}]" for var in updates_index_vars
+                )
+            data_suffix = self._param_array_suffix(
+                op.data_shape, data_dim_names
+            )
+            indices_suffix = self._param_array_suffix(
+                op.indices_shape, indices_dim_names
+            )
+            updates_suffix = self._param_array_suffix(
+                op.updates_shape, updates_dim_names
+            )
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["data"], c_type, data_suffix, True),
+                    (
+                        params["indices"],
+                        op.indices_dtype.c_type,
+                        indices_suffix,
+                        True,
+                    ),
+                    (params["updates"], c_type, updates_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = scatter_nd_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                data=params["data"],
+                indices=params["indices"],
+                updates=params["updates"],
+                output=params["output"],
+                params=param_decls,
+                c_type=c_type,
+                output_shape=output_shape,
+                output_loop_vars=output_loop_vars,
+                indices_prefix_shape=indices_prefix_shape,
+                indices_prefix_loop_vars=indices_prefix_loop_vars,
+                index_depth=index_depth,
+                data_shape=data_shape,
+                tail_shape=tail_shape,
+                tail_loop_vars=tail_loop_vars,
+                output_index_expr=output_index_expr,
+                updates_index_expr=updates_index_expr,
+                reduction=op.reduction,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, TransposeOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
@@ -6608,6 +8071,7 @@ class CEmitter:
                 [("input0", op.input0), ("output", op.output)]
             )
             input_suffix = self._param_array_suffix(op.input_shape)
+            output_shape = CEmitter._codegen_shape(op.output_shape)
             output_suffix = self._param_array_suffix(op.output_shape)
             param_decls = self._build_param_decls(
                 [
@@ -6615,6 +8079,7 @@ class CEmitter:
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
+            loop_vars = CEmitter._loop_vars(op.output_shape)
             rendered = reshape_template.render(
                 model_name=model.name,
                 op_name=op_name,
@@ -6625,6 +8090,8 @@ class CEmitter:
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
                 element_count=CEmitter._element_count(op.output_shape),
+                output_shape=output_shape,
+                loop_vars=loop_vars,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, IdentityOp):
@@ -6691,6 +8158,61 @@ class CEmitter:
                 one_literal=f"(({c_type})1)",
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, TriluOp):
+            param_specs = [("input0", op.input0), ("output", op.output)]
+            if op.k_input is not None:
+                param_specs.append(("k_input", op.k_input))
+            params = self._shared_param_map(param_specs)
+            output_dim_names = _dim_names_for(op.output)
+            shape = CEmitter._shape_dim_exprs(op.output_shape, output_dim_names)
+            output_suffix = self._param_array_suffix(op.output_shape, output_dim_names)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.input0)
+            )
+            k_suffix = ""
+            if op.k_input is not None and op.k_input_shape is not None:
+                k_suffix = self._param_array_suffix(
+                    op.k_input_shape, _dim_names_for(op.k_input)
+                )
+            batch_dims = op.output_shape[:-2]
+            batch_size = CEmitter._element_count(batch_dims or (1,))
+            param_decls = [
+                (params["input0"], c_type, input_suffix, True),
+                (params["output"], c_type, output_suffix, False),
+            ]
+            if op.k_input is not None and op.k_input_dtype is not None:
+                param_decls.append(
+                    (
+                        params["k_input"],
+                        op.k_input_dtype.c_type,
+                        k_suffix,
+                        True,
+                    )
+                )
+            rendered = trilu_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                k_input=params.get("k_input"),
+                params=self._build_param_decls(param_decls),
+                c_type=c_type,
+                k_c_type=(
+                    op.k_input_dtype.c_type
+                    if op.k_input_dtype is not None
+                    else ScalarType.I64.c_type
+                ),
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                shape=shape,
+                batch_size=batch_size,
+                rows=op.output_shape[-2],
+                cols=op.output_shape[-1],
+                k_value=op.k_value,
+                upper=op.upper,
+                zero_literal=zero_literal,
+            ).rstrip()
+            return with_node_comment(rendered)
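The Trilu template's per-element rule compares the column-minus-row offset against k: keep when on the selected side of the k-th diagonal, write the zero literal otherwise; the `k_input` variant only changes where k comes from at runtime. Reference semantics:

    import numpy as np

    def trilu_reference(x: np.ndarray, k: int, upper: bool) -> np.ndarray:
        rows, cols = x.shape[-2], x.shape[-1]
        r = np.arange(rows)[:, None]
        c = np.arange(cols)[None, :]
        keep = (c - r >= k) if upper else (c - r <= k)
        return np.where(keep, x, np.zeros_like(x))

    x = np.arange(9.0).reshape(3, 3)
    assert np.array_equal(trilu_reference(x, 0, True), np.triu(x))
    assert np.array_equal(trilu_reference(x, -1, False), np.tril(x, -1))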
         if isinstance(op, TileOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
@@ -7224,17 +8746,19 @@ class CEmitter:
             update_expr = None
             init_literal = None
             final_expr = "acc"
+            use_kahan = False
+            kahan_value_expr = None
             fabs_fn = CEmitter._math_fn(op.dtype, "fabsf", "fabs")
             exp_fn = CEmitter._math_fn(op.dtype, "expf", "exp")
             log_fn = CEmitter._math_fn(op.dtype, "logf", "log")
             sqrt_fn = CEmitter._math_fn(op.dtype, "sqrtf", "sqrt")
-            count_literal = CEmitter._format_literal(
-                op.dtype, op.reduce_count
-            )
             if op.reduce_kind == "sum":
                 init_literal = zero_literal
                 update_expr = f"acc += {value_expr};"
             elif op.reduce_kind == "mean":
+                count_literal = CEmitter._format_literal(
+                    op.dtype, op.reduce_count
+                )
                 init_literal = zero_literal
                 update_expr = f"acc += {value_expr};"
                 final_expr = f"acc / {count_literal}"
@@ -7269,6 +8793,24 @@ class CEmitter:
                 raise CodegenError(
                     f"Unsupported reduce kind {op.reduce_kind}"
                 )
+            if op.dtype in {ScalarType.F16, ScalarType.F32} and op.reduce_kind in {
+                "sum",
+                "mean",
+                "logsum",
+                "logsumexp",
+                "l1",
+                "l2",
+                "sumsquare",
+            }:
+                use_kahan = True
+                if op.reduce_kind == "logsumexp":
+                    kahan_value_expr = f"{exp_fn}({value_expr})"
+                elif op.reduce_kind == "l1":
+                    kahan_value_expr = f"{fabs_fn}({value_expr})"
+                elif op.reduce_kind in {"l2", "sumsquare"}:
+                    kahan_value_expr = f"{value_expr} * {value_expr}"
+                else:
+                    kahan_value_expr = value_expr
             input_suffix = self._param_array_suffix(op.input_shape)
             output_suffix = self._param_array_suffix(op.output_shape)
             param_decls = self._build_param_decls(
@@ -7292,8 +8834,11 @@ class CEmitter:
                 reduce_dims=reduce_dims,
                 output_index_expr=output_index_expr,
                 init_literal=init_literal,
+                zero_literal=zero_literal,
                 update_expr=update_expr,
                 final_expr=final_expr,
+                use_kahan=use_kahan,
+                kahan_value_expr=kahan_value_expr,
             ).rstrip()
             return with_node_comment(rendered)
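For f16/f32 inputs the sum-shaped reductions (sum, mean, logsum, logsumexp, l1, l2, sumsquare) now go through Kahan compensated summation, with `kahan_value_expr` supplying the per-element term (|x|, x*x, exp(x), or x itself). The recurrence the template unrolls, for reference:

    def kahan_sum(values):
        # c carries the low-order bits lost when folding each term in.
        acc, c = 0.0, 0.0
        for value in values:
            y = value - c
            t = acc + y
            c = (t - acc) - y
            acc = t
        return acc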
         if isinstance(op, ArgReduceOp):
@@ -7367,6 +8912,83 @@ class CEmitter:
                 dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, TopKOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("output_values", op.output_values),
+                    ("output_indices", op.output_indices),
+                ]
+            )
+            output_shape = CEmitter._codegen_shape(op.output_shape)
+            outer_shape = tuple(
+                dim for axis, dim in enumerate(output_shape) if axis != op.axis
+            )
+            outer_loop_vars = CEmitter._loop_vars(outer_shape)
+            reduce_var = "r0"
+            k_var = "k0"
+            input_indices: list[str] = []
+            output_indices: list[str] = []
+            outer_index = 0
+            for axis in range(len(op.input_shape)):
+                if axis == op.axis:
+                    input_indices.append(reduce_var)
+                    output_indices.append(k_var)
+                else:
+                    input_indices.append(outer_loop_vars[outer_index])
+                    output_indices.append(outer_loop_vars[outer_index])
+                    outer_index += 1
+            input_index_expr = "".join(f"[{var}]" for var in input_indices)
+            output_index_expr = "".join(f"[{var}]" for var in output_indices)
+            compare_expr = (
+                "(a > b) || ((a == b) && (ai < bi))"
+                if op.largest
+                else "(a < b) || ((a == b) && (ai < bi))"
+            )
+            input_suffix = self._param_array_suffix(op.input_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (
+                        params["output_values"],
+                        op.output_values_dtype.c_type,
+                        output_suffix,
+                        False,
+                    ),
+                    (
+                        params["output_indices"],
+                        op.output_indices_dtype.c_type,
+                        output_suffix,
+                        False,
+                    ),
+                ]
+            )
+            rendered = topk_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output_values=params["output_values"],
+                output_indices=params["output_indices"],
+                params=param_decls,
+                input_c_type=op.input_dtype.c_type,
+                output_values_c_type=op.output_values_dtype.c_type,
+                output_indices_c_type=op.output_indices_dtype.c_type,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                output_shape=output_shape,
+                outer_shape=outer_shape,
+                outer_loop_vars=outer_loop_vars,
+                reduce_var=reduce_var,
+                k_var=k_var,
+                axis_dim=op.input_shape[op.axis],
+                k=op.k,
+                input_index_expr=input_index_expr,
+                output_index_expr=output_index_expr,
+                compare_expr=compare_expr,
+                dim_args=dim_args,
+            ).rstrip()
+            return with_node_comment(rendered)
7370
8992
  if isinstance(op, ReduceOp):
7371
8993
  name_params = self._shared_param_map(
7372
8994
  [
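Note: `compare_expr` above implements ONNX TopK ordering, with ties broken toward the smaller input index (`ai < bi`). A selection loop consistent with that comparator, for a single slice along the reduction axis, might look like the following; this is a hand-written analogue for illustration, not the actual `topk_template` output:

    #include <stddef.h>
    #include <stdint.h>

    /* Keep the k best candidates in sorted order (largest=1), inserting each
     * new element by shifting worse entries down. Scanning indices in
     * ascending order makes the (a == b && ai < bi) tie-break hold for free. */
    static void topk_slice(const float *in, size_t n, size_t k,
                           float *values, int64_t *indices) {
        size_t filled = 0;
        for (size_t r = 0; r < n; ++r) {          /* reduce_var loop */
            float a = in[r];
            size_t pos = filled;
            while (pos > 0 && a > values[pos - 1]) {
                --pos;                            /* find insertion point */
            }
            if (pos >= k) {
                continue;                         /* not in the top k */
            }
            size_t last = (filled < k) ? filled : k - 1;
            for (size_t j = last; j > pos; --j) { /* shift worse entries */
                values[j] = values[j - 1];
                indices[j] = indices[j - 1];
            }
            values[pos] = a;
            indices[pos] = (int64_t)r;
            if (filled < k) {
                ++filled;
            }
        }
    }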
@@ -7545,35 +9167,76 @@ class CEmitter:
                 c_type=c_type,
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
-                values=[
-                    CEmitter._format_literal(op.dtype, value)
-                    for value in op.values
-                ],
+                values=[
+                    CEmitter._format_literal(op.dtype, value)
+                    for value in op.values
+                ],
+            ).rstrip()
+            return with_node_comment(rendered)
+        if isinstance(op, SizeOp):
+            params = self._shared_param_map(
+                [("input0", op.input0), ("output", op.output)]
+            )
+            input_suffix = self._param_array_suffix(op.input_shape)
+            output_suffix = self._param_array_suffix(op.output_shape)
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = size_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                output=params["output"],
+                params=param_decls,
+                input_c_type=op.input_dtype.c_type,
+                c_type=c_type,
+                input_suffix=input_suffix,
+                output_suffix=output_suffix,
+                value=CEmitter._format_literal(op.dtype, op.value),
             ).rstrip()
             return with_node_comment(rendered)
-        if isinstance(op, SizeOp):
+        if isinstance(op, NonZeroOp):
             params = self._shared_param_map(
                 [("input0", op.input0), ("output", op.output)]
             )
-            input_suffix = self._param_array_suffix(op.input_shape)
-            output_suffix = self._param_array_suffix(op.output_shape)
+            input_dim_names = _dim_names_for(op.input0)
+            output_dim_names = _dim_names_for(op.output)
+            input_shape = CEmitter._shape_dim_exprs(
+                op.input_shape, input_dim_names
+            )
+            loop_vars = CEmitter._loop_vars(op.input_shape)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, input_dim_names
+            )
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
             param_decls = self._build_param_decls(
                 [
                     (params["input0"], op.input_dtype.c_type, input_suffix, True),
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
-            rendered = size_template.render(
+            input_expr = f"{params['input0']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            rendered = nonzero_template.render(
                 model_name=model.name,
                 op_name=op_name,
                 input0=params["input0"],
                 output=params["output"],
                 params=param_decls,
                 input_c_type=op.input_dtype.c_type,
-                c_type=c_type,
+                output_c_type=c_type,
                 input_suffix=input_suffix,
                 output_suffix=output_suffix,
-                value=CEmitter._format_literal(op.dtype, op.value),
+                input_shape=input_shape,
+                loop_vars=loop_vars,
+                input_expr=input_expr,
+                zero_literal=op.input_dtype.zero_literal,
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, ExpandOp):
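Note: NonZero produces an int64 tensor of shape (rank, nnz) holding the coordinates of every nonzero element, and the template compares each element against `zero_literal`. The sketch below is a hand-written analogue for a 2x3 input, assuming the lowering has already fixed nnz in `op.output_shape` (here the worst case of six) — dimensions and names are illustrative:

    #include <stdint.h>

    /* Row-major scan; out[axis][j] is the coordinate of the j-th nonzero
     * element along that axis, matching the ONNX NonZero layout. */
    static void nonzero_2x3(const float in[2][3], int64_t out[2][6]) {
        int64_t count = 0;
        for (int64_t i0 = 0; i0 < 2; ++i0) {
            for (int64_t i1 = 0; i1 < 3; ++i1) {
                if (in[i0][i1] != 0.0f) {  /* input_expr != zero_literal */
                    out[0][count] = i0;
                    out[1][count] = i1;
                    ++count;
                }
            }
        }
    }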
@@ -7692,6 +9355,74 @@ class CEmitter:
                 length=op.length,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, OneHotOp):
+            params = self._shared_param_map(
+                [
+                    ("indices", op.indices),
+                    ("depth", op.depth),
+                    ("values", op.values),
+                    ("output", op.output),
+                ]
+            )
+            output_dim_names = _dim_names_for(op.output)
+            indices_dim_names = _dim_names_for(op.indices)
+            values_dim_names = _dim_names_for(op.values)
+            output_shape = CEmitter._codegen_shape(op.output_shape)
+            loop_vars = CEmitter._loop_vars(output_shape)
+            indices_indices = tuple(
+                var for idx, var in enumerate(loop_vars) if idx != op.axis
+            )
+            if not indices_indices:
+                indices_indices = ("0",)
+            output_suffix = self._param_array_suffix(
+                op.output_shape, output_dim_names
+            )
+            indices_suffix = self._param_array_suffix(
+                op.indices_shape, indices_dim_names
+            )
+            values_suffix = self._param_array_suffix(
+                op.values_shape, values_dim_names
+            )
+            depth_suffix = self._param_array_suffix(())
+            param_decls = self._build_param_decls(
+                [
+                    (
+                        params["indices"],
+                        op.indices_dtype.c_type,
+                        indices_suffix,
+                        True,
+                    ),
+                    (
+                        params["depth"],
+                        op.depth_dtype.c_type,
+                        depth_suffix,
+                        True,
+                    ),
+                    (params["values"], c_type, values_suffix, True),
+                    (params["output"], c_type, output_suffix, False),
+                ]
+            )
+            rendered = one_hot_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                indices=params["indices"],
+                depth=params["depth"],
+                values=params["values"],
+                output=params["output"],
+                params=param_decls,
+                indices_suffix=indices_suffix,
+                depth_suffix=depth_suffix,
+                values_suffix=values_suffix,
+                output_suffix=output_suffix,
+                output_shape=output_shape,
+                loop_vars=loop_vars,
+                indices_indices=indices_indices,
+                axis_index=loop_vars[op.axis],
+                depth_dim=op.depth_dim,
+                indices_c_type=op.indices_dtype.c_type,
+                c_type=c_type,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, SplitOp):
             output_params = [
                 (f"output_{index}", name)
@@ -7772,6 +9503,86 @@ class CEmitter:
                 dim_args=dim_args,
             ).rstrip()
             return with_node_comment(rendered)
+        if isinstance(op, QuantizeLinearOp):
+            params = self._shared_param_map(
+                [
+                    ("input0", op.input0),
+                    ("scale", op.scale),
+                    ("zero_point", op.zero_point),
+                    ("output", op.output),
+                ]
+            )
+            output_dim_names = _dim_names_for(op.output)
+            shape = CEmitter._shape_dim_exprs(op.input_shape, output_dim_names)
+            loop_vars = CEmitter._loop_vars(op.input_shape)
+            input_suffix = self._param_array_suffix(
+                op.input_shape, _dim_names_for(op.input0)
+            )
+            scale_shape = (
+                ()
+                if op.axis is None
+                else (op.input_shape[op.axis],)
+            )
+            scale_suffix = self._param_array_suffix(
+                scale_shape, _dim_names_for(op.scale)
+            )
+            zero_point_suffix = self._param_array_suffix(
+                scale_shape, _dim_names_for(op.zero_point or "")
+            )
+            param_decls = self._build_param_decls(
+                [
+                    (params["input0"], op.input_dtype.c_type, input_suffix, True),
+                    (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
+                    (
+                        params["zero_point"],
+                        op.dtype.c_type,
+                        zero_point_suffix,
+                        True,
+                    )
+                    if params["zero_point"]
+                    else (None, "", "", True),
+                    (params["output"], op.dtype.c_type, input_suffix, False),
+                ]
+            )
+            compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
+            round_fn = CEmitter._math_fn(
+                op.input_dtype, "nearbyintf", "nearbyint"
+            )
+            scale_index = "0" if op.axis is None else loop_vars[op.axis]
+            input_expr = f"{params['input0']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            output_expr = f"{params['output']}" + "".join(
+                f"[{var}]" for var in loop_vars
+            )
+            scale_expr = f"{params['scale']}[{scale_index}]"
+            if params["zero_point"]:
+                zero_expr = f"{params['zero_point']}[{scale_index}]"
+            else:
+                zero_expr = "0"
+            rendered = quantize_linear_template.render(
+                model_name=model.name,
+                op_name=op_name,
+                input0=params["input0"],
+                scale=params["scale"],
+                zero_point=params["zero_point"],
+                output=params["output"],
+                params=param_decls,
+                compute_type=compute_type,
+                input_c_type=op.input_dtype.c_type,
+                output_c_type=op.dtype.c_type,
+                shape=shape,
+                loop_vars=loop_vars,
+                input_expr=input_expr,
+                scale_expr=scale_expr,
+                zero_expr=zero_expr,
+                output_expr=output_expr,
+                round_fn=round_fn,
+                min_literal=op.dtype.min_literal,
+                max_literal=op.dtype.max_literal,
+                dim_args=dim_args,
+            ).rstrip()
+            return with_node_comment(rendered)
         if isinstance(op, ClipOp):
             params = self._shared_param_map(
                 [
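Note: the block above wires up the standard QuantizeLinear formula, y = saturate(round(x / scale) + zero_point), computing in float or double (`compute_type`), rounding via nearbyintf/nearbyint (`round_fn`, round-to-nearest-even under the default mode), and clamping to the target type's `min_literal`/`max_literal`; per-axis quantization just swaps the `[0]` subscript on scale and zero point for `loop_vars[op.axis]`. A per-tensor int8 sketch, illustrative rather than the literal template output:

    #include <math.h>
    #include <stdint.h>

    static void quantize_linear_i8(const float *in, float scale,
                                   int8_t zero_point, int8_t *out,
                                   int64_t count) {
        for (int64_t i = 0; i < count; ++i) {
            /* round-to-nearest-even, then shift by the zero point */
            float v = nearbyintf(in[i] / scale) + (float)zero_point;
            if (v < -128.0f) {  /* min_literal */
                v = -128.0f;
            }
            if (v > 127.0f) {   /* max_literal */
                v = 127.0f;
            }
            out[i] = (int8_t)v;
        }
    }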
@@ -7934,11 +9745,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -7950,16 +9765,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -7968,16 +9787,20 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> str:
         if isinstance(op, SplitOp):
             return op.outputs[0]
+        if isinstance(op, TopKOp):
+            return op.output_values
         return op.output
 
     @staticmethod
@@ -7988,11 +9811,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8004,16 +9831,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -8022,18 +9853,28 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> tuple[tuple[str, tuple[int, ...]], ...]:
         if isinstance(op, BinaryOp):
-            return ((op.input0, op.shape), (op.input1, op.shape))
+            return (
+                (op.input0, op.input0_shape),
+                (op.input1, op.input1_shape),
+            )
         if isinstance(op, MultiInputBinaryOp):
             return tuple((name, op.shape) for name in op.inputs)
+        if isinstance(op, EinsumOp):
+            return tuple(
+                (name, shape)
+                for name, shape in zip(op.inputs, op.input_shapes)
+            )
         if isinstance(op, UnaryOp):
             return ((op.input0, op.shape),)
         if isinstance(op, LpNormalizationOp):
@@ -8068,10 +9909,27 @@ class CEmitter:
             return tuple(inputs)
         if isinstance(op, CastOp):
             return ((op.input0, op.shape),)
+        if isinstance(op, NonZeroOp):
+            return ((op.input0, op.input_shape),)
+        if isinstance(op, QuantizeLinearOp):
+            scale_shape = (
+                ()
+                if op.axis is None
+                else (op.input_shape[op.axis],)
+            )
+            inputs = [(op.input0, op.input_shape), (op.scale, scale_shape)]
+            if op.zero_point is not None:
+                inputs.append((op.zero_point, scale_shape))
+            return tuple(inputs)
         if isinstance(op, IdentityOp):
             return ((op.input0, op.shape),)
         if isinstance(op, EyeLikeOp):
             return ((op.input0, op.output_shape),)
+        if isinstance(op, TriluOp):
+            inputs = [(op.input0, op.input_shape)]
+            if op.k_input is not None and op.k_input_shape is not None:
+                inputs.append((op.k_input, op.k_input_shape))
+            return tuple(inputs)
         if isinstance(op, GridSampleOp):
             return ((op.input0, op.input_shape), (op.grid, op.grid_shape))
         if isinstance(op, PadOp):
@@ -8083,8 +9941,22 @@ class CEmitter:
             if op.value_input is not None and op.value_shape is not None:
                 inputs.append((op.value_input, op.value_shape))
             return tuple(inputs)
+        if isinstance(op, ScatterNDOp):
+            return ((op.data, op.data_shape),)
         if isinstance(op, CumSumOp):
             return ((op.input0, op.input_shape),)
+        if isinstance(op, RangeOp):
+            return ((op.start, ()), (op.limit, ()), (op.delta, ()))
+        if isinstance(op, OneHotOp):
+            return (
+                (op.indices, op.indices_shape),
+                (op.depth, ()),
+                (op.values, op.values_shape),
+            )
+        if isinstance(op, SplitOp):
+            return ((op.input0, op.input_shape),)
+        if isinstance(op, TopKOp):
+            return ((op.input0, op.input_shape),)
         return ()
 
     def _propagate_tensor_dim_names(
@@ -8096,11 +9968,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8112,16 +9988,19 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -8130,11 +10009,14 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         ],
         tensor_dim_names: dict[str, dict[int, str]],
@@ -8157,11 +10039,15 @@ class CEmitter:
         | UnaryOp
         | ClipOp
         | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8173,16 +10059,20 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
+        | ScatterNDOp
         | TransposeOp
         | ReshapeOp
         | IdentityOp
         | EyeLikeOp
+        | TriluOp
         | TileOp
         | PadOp
         | DepthToSpaceOp
@@ -8191,11 +10081,14 @@ class CEmitter:
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | RangeOp
+        | OneHotOp
         | SplitOp,
     ) -> tuple[tuple[str, tuple[int, ...], str], ...]:
         if isinstance(op, AttentionOp):
@@ -8292,6 +10185,19 @@ class CEmitter:
             )
         if isinstance(op, ArgReduceOp):
             return ((op.output, CEmitter._op_output_shape(op), op.output_dtype),)
+        if isinstance(op, TopKOp):
+            return (
+                (
+                    op.output_values,
+                    CEmitter._op_output_shape(op),
+                    op.output_values_dtype,
+                ),
+                (
+                    op.output_indices,
+                    CEmitter._op_output_shape(op),
+                    op.output_indices_dtype,
+                ),
+            )
         return ((op.output, CEmitter._op_output_shape(op), op.dtype),)
 
     @staticmethod
@@ -8303,6 +10209,7 @@ class CEmitter:
         | ClipOp
         | CastOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
@@ -8318,25 +10225,34 @@ class CEmitter:
         | LstmOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
+        | IdentityOp
+        | EyeLikeOp
+        | TriluOp
+        | TileOp
         | SliceOp
         | ResizeOp
         | GridSampleOp
         | ReduceOp
         | ArgReduceOp
+        | TopKOp
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         | PadOp,
     ) -> tuple[int, ...]:
@@ -8350,16 +10266,24 @@ class CEmitter:
             return op.shape
         if isinstance(op, ClipOp):
             return op.output_shape
+        if isinstance(op, QuantizeLinearOp):
+            return op.input_shape
         if isinstance(op, CastOp):
             return op.shape
         if isinstance(op, MatMulOp):
-            return (op.m, op.n)
+            return op.output_shape
+        if isinstance(op, EinsumOp):
+            return op.output_shape
         if isinstance(op, GemmOp):
             return (op.m, op.n)
         if isinstance(op, ConvOp):
             return (op.batch, op.out_channels, *op.out_spatial)
+        if isinstance(op, ConvTransposeOp):
+            return (op.batch, op.out_channels, *op.out_spatial)
         if isinstance(op, AveragePoolOp):
             return (op.batch, op.channels, op.out_h, op.out_w)
+        if isinstance(op, LpPoolOp):
+            return (op.batch, op.channels, op.out_h, op.out_w)
         if isinstance(op, BatchNormOp):
             return op.shape
         if isinstance(
@@ -8380,6 +10304,8 @@ class CEmitter:
             return op.shape
         if isinstance(op, LogSoftmaxOp):
             return op.shape
+        if isinstance(op, HardmaxOp):
+            return op.shape
         if isinstance(op, NegativeLogLikelihoodLossOp):
             return op.output_shape
         if isinstance(op, SoftmaxCrossEntropyLossOp):
@@ -8392,6 +10318,10 @@ class CEmitter:
             return op.output_shape
         if isinstance(op, GatherOp):
             return op.output_shape
+        if isinstance(op, GatherNDOp):
+            return op.output_shape
+        if isinstance(op, ScatterNDOp):
+            return op.output_shape
         if isinstance(op, TransposeOp):
             return op.output_shape
         if isinstance(op, ReshapeOp):
@@ -8400,6 +10330,8 @@ class CEmitter:
             return op.shape
         if isinstance(op, EyeLikeOp):
             return op.output_shape
+        if isinstance(op, TriluOp):
+            return op.output_shape
         if isinstance(op, TileOp):
             return op.output_shape
         if isinstance(op, PadOp):
@@ -8418,18 +10350,24 @@ class CEmitter:
             return op.output_shape
         if isinstance(op, ArgReduceOp):
             return op.output_shape
+        if isinstance(op, TopKOp):
+            return op.output_shape
         if isinstance(op, ConstantOfShapeOp):
             return op.shape
         if isinstance(op, ShapeOp):
             return op.output_shape
         if isinstance(op, SizeOp):
             return op.output_shape
+        if isinstance(op, NonZeroOp):
+            return op.output_shape
         if isinstance(op, ExpandOp):
             return op.output_shape
         if isinstance(op, CumSumOp):
             return op.input_shape
         if isinstance(op, RangeOp):
             return op.output_shape
+        if isinstance(op, OneHotOp):
+            return op.output_shape
         if op.output_rank == 3:
             return (op.batch, op.q_seq, op.q_heads * op.v_head_size)
         return (op.batch, op.q_heads, op.q_seq, op.v_head_size)
@@ -8441,11 +10379,16 @@ class CEmitter:
         | WhereOp
         | UnaryOp
         | ClipOp
+        | CastOp
+        | QuantizeLinearOp
         | MatMulOp
+        | EinsumOp
         | GemmOp
         | AttentionOp
         | ConvOp
+        | ConvTransposeOp
         | AveragePoolOp
+        | LpPoolOp
         | BatchNormOp
         | LpNormalizationOp
         | InstanceNormalizationOp
@@ -8455,14 +10398,20 @@ class CEmitter:
         | RMSNormalizationOp
         | SoftmaxOp
         | LogSoftmaxOp
+        | HardmaxOp
         | NegativeLogLikelihoodLossOp
         | SoftmaxCrossEntropyLossOp
         | MaxPoolOp
         | ConcatOp
         | GatherElementsOp
         | GatherOp
+        | GatherNDOp
         | TransposeOp
         | ReshapeOp
+        | IdentityOp
+        | EyeLikeOp
+        | TriluOp
+        | TileOp
         | ResizeOp
         | GridSampleOp
         | ReduceOp
@@ -8470,21 +10419,25 @@ class CEmitter:
         | ConstantOfShapeOp
         | ShapeOp
         | SizeOp
+        | NonZeroOp
         | ExpandOp
         | CumSumOp
         | RangeOp
+        | OneHotOp
         | SplitOp
         | PadOp,
     ) -> ScalarType:
         if isinstance(op, ArgReduceOp):
             return op.output_dtype
+        if isinstance(op, TopKOp):
+            return op.output_values_dtype
         return op.dtype
 
     @staticmethod
     def _codegen_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
         if not shape:
             return (1,)
-        return shape
+        return tuple(max(1, dim) if isinstance(dim, int) else dim for dim in shape)
 
     @staticmethod
     def _array_suffix(shape: tuple[int, ...]) -> str:
@@ -8623,6 +10576,8 @@ class CEmitter:
         dim_names: Mapping[int, str] | None,
     ) -> tuple[str | int, ...]:
         dim_names = dim_names or {}
+        if not shape:
+            shape = (1,)
         return tuple(
             dim_names.get(index, dim) for index, dim in enumerate(shape)
         )
@@ -8677,7 +10632,8 @@ class CEmitter:
 
     @staticmethod
     def _element_count(shape: tuple[int, ...]) -> int:
-        shape = CEmitter._codegen_shape(shape)
+        if not shape:
+            return 1
         count = 1
         for dim in shape:
             if dim < 0:
@@ -8745,6 +10701,7 @@ class CEmitter:
         testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
         dim_order: Sequence[str],
         dim_values: Mapping[str, int],
+        weight_data_filename: str,
     ) -> str:
         input_counts = tuple(
             self._element_count(shape) for shape in model.input_shapes
@@ -8755,7 +10712,8 @@ class CEmitter:
             model.input_names, model.input_shapes, input_counts, model.input_dtypes
         ):
             codegen_shape = self._codegen_shape(shape)
-            loop_vars = self._loop_vars(codegen_shape)
+            loop_shape = (1,) if not shape else shape
+            loop_vars = self._loop_vars(loop_shape)
             if dtype in {ScalarType.F16, ScalarType.F32}:
                 random_expr = "rng_next_float()"
             elif dtype == ScalarType.F64:
@@ -8769,20 +10727,26 @@ class CEmitter:
             constant_lines = None
             if constant_values is not None:
                 constant_name = f"{name}_testbench_data"
-                constant_lines = [
-                    self._format_value(value, dtype)
-                    for value in constant_values
-                ]
+                if constant_values:
+                    constant_lines = [
+                        self._format_value(value, dtype)
+                        for value in constant_values
+                    ]
+                else:
+                    constant_lines = [self._format_value(0, dtype)]
             inputs.append(
                 {
                     "name": name,
-                    "shape": codegen_shape,
+                    "shape": loop_shape,
                     "shape_literal": ",".join(str(dim) for dim in shape),
                     "count": count,
                     "array_suffix": self._array_suffix(codegen_shape),
+                    "array_index_expr": "".join(
+                        f"[{var}]" for var in loop_vars
+                    ),
                     "loop_vars": loop_vars,
-                    "rank": len(codegen_shape),
-                    "index_expr": self._index_expr(codegen_shape, loop_vars),
+                    "rank": len(loop_shape),
+                    "index_expr": self._index_expr(loop_shape, loop_vars),
                     "dtype": dtype,
                     "c_type": dtype.c_type,
                     "random_expr": random_expr,
@@ -8797,17 +10761,21 @@ class CEmitter:
             model.output_names, model.output_shapes, model.output_dtypes
         ):
             codegen_shape = self._codegen_shape(shape)
-            output_loop_vars = self._loop_vars(codegen_shape)
+            loop_shape = (1,) if not shape else shape
+            output_loop_vars = self._loop_vars(loop_shape)
             outputs.append(
                 {
                     "name": name,
-                    "shape": codegen_shape,
+                    "shape": loop_shape,
                     "shape_literal": ",".join(str(dim) for dim in shape),
-                    "count": self._element_count(codegen_shape),
+                    "count": self._element_count(shape),
                     "array_suffix": self._array_suffix(codegen_shape),
+                    "array_index_expr": "".join(
+                        f"[{var}]" for var in output_loop_vars
+                    ),
                     "loop_vars": output_loop_vars,
-                    "rank": len(codegen_shape),
-                    "index_expr": self._index_expr(codegen_shape, output_loop_vars),
+                    "rank": len(loop_shape),
+                    "index_expr": self._index_expr(loop_shape, output_loop_vars),
                     "dtype": dtype,
                     "c_type": dtype.c_type,
                     "print_format": self._print_format(dtype),
@@ -8822,9 +10790,87 @@ class CEmitter:
             ],
             inputs=inputs,
             outputs=outputs,
+            weight_data_filename=weight_data_filename,
         ).rstrip()
         return _format_c_indentation(rendered)
 
+    @staticmethod
+    def _testbench_requires_math(
+        model: LoweredModel,
+        testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None,
+    ) -> bool:
+        if not testbench_inputs:
+            return False
+        dtype_map = dict(zip(model.input_names, model.input_dtypes))
+        float_dtypes = {ScalarType.F16, ScalarType.F32, ScalarType.F64}
+        for name, values in testbench_inputs.items():
+            if dtype_map.get(name) not in float_dtypes:
+                continue
+            for value in values:
+                if not math.isfinite(float(value)):
+                    return True
+        return False
+
+    def _partition_constants(
+        self, constants: tuple[ConstTensor, ...]
+    ) -> tuple[tuple[ConstTensor, ...], tuple[ConstTensor, ...]]:
+        if self._large_weight_threshold <= 0:
+            return (), constants
+        inline: list[ConstTensor] = []
+        large: list[ConstTensor] = []
+        for const in constants:
+            if self._element_count(const.shape) > self._large_weight_threshold:
+                large.append(const)
+            else:
+                inline.append(const)
+        return tuple(inline), tuple(large)
+
+    @staticmethod
+    def _weight_data_filename(model: LoweredModel) -> str:
+        return f"{model.name}.bin"
+
+    def _emit_weight_loader(
+        self, model: LoweredModel, large_constants: tuple[ConstTensor, ...]
+    ) -> str:
+        lines = [f"_Bool {model.name}_load(const char *path) {{"]
+        if not large_constants:
+            lines.append("    (void)path;")
+            lines.append("    return 1;")
+            lines.append("}")
+            return _format_c_indentation("\n".join(lines))
+        lines.append("    FILE *file = fopen(path, \"rb\");")
+        lines.append("    if (!file) {")
+        lines.append("        return 0;")
+        lines.append("    }")
+        lines.append(
+            f"    _Bool ok = {model.name}_load_file(file);"
+        )
+        lines.append("    fclose(file);")
+        lines.append("    return ok;")
+        lines.append("}")
+        lines.append("")
+        lines.append(f"static _Bool {model.name}_load_file(FILE *file) {{")
+        for const in large_constants:
+            shape = self._codegen_shape(const.shape)
+            loop_vars = self._loop_vars(shape)
+            for depth, var in enumerate(loop_vars):
+                lines.append(
+                    f"    for (idx_t {var} = 0; {var} < {shape[depth]}; ++{var}) {{"
+                )
+            index_expr = "".join(f"[{var}]" for var in loop_vars)
+            zero_index = "[0]" * len(shape)
+            lines.append(
+                f"    if (fread(&{const.name}{index_expr}, "
+                f"sizeof({const.name}{zero_index}), 1, file) != 1) {{"
+            )
+            lines.append("        return 0;")
+            lines.append("    }")
+            for _ in loop_vars[::-1]:
+                lines.append("    }")
+        lines.append("    return 1;")
+        lines.append("}")
+        return _format_c_indentation("\n".join(lines))
+
     def _emit_constant_definitions(
         self,
         constants: tuple[ConstTensor, ...],
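Note: `_emit_weight_loader` above writes a `<model>_load(path)` entry point that freads each large constant element by element, in the same order `collect_weight_data` (later in this diff) serialized them. What the generated C plausibly looks like for one split-out constant — `model`, `w0`, and `idx_t` are illustrative stand-ins:

    #include <stdio.h>

    typedef long idx_t;          /* stands in for the generated index typedef */
    static float w0[2][3];       /* large constant left uninitialised */

    static _Bool model_load_file(FILE *file) {
        for (idx_t i0 = 0; i0 < 2; ++i0) {
            for (idx_t i1 = 0; i1 < 3; ++i1) {
                if (fread(&w0[i0][i1], sizeof(w0[0][0]), 1, file) != 1) {
                    return 0;
                }
            }
        }
        return 1;
    }

    _Bool model_load(const char *path) {
        FILE *file = fopen(path, "rb");
        if (!file) {
            return 0;
        }
        _Bool ok = model_load_file(file);
        fclose(file);
        return ok;
    }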
@@ -8834,26 +10880,31 @@ class CEmitter:
         if not constants:
             return ""
         lines: list[str] = []
-        for const in constants:
+        for index, const in enumerate(constants, start=1):
+            lines.append(self._emit_constant_comment(const, index))
             c_type = const.dtype.c_type
-            array_suffix = self._array_suffix(const.shape)
+            shape = self._codegen_shape(const.shape)
+            array_suffix = self._array_suffix(shape)
             values = [
-                self._format_value(value, const.dtype) for value in const.data
+                self._format_weight_value(value, const.dtype)
+                for value in const.data
             ]
             lines.append(
                 f"{storage_prefix} {c_type} {const.name}{array_suffix} = {{"
             )
             if values:
-                chunk_size = 8
-                chunks = [
-                    values[index : index + chunk_size]
-                    for index in range(0, len(values), chunk_size)
-                ]
-                for chunk_index, chunk in enumerate(chunks):
-                    line = "    " + ", ".join(chunk)
-                    if chunk_index != len(chunks) - 1:
-                        line += ","
-                    lines.append(line)
+                if (
+                    self._truncate_weights_after is not None
+                    and len(values) > self._truncate_weights_after
+                ):
+                    truncated_lines, _, _, _ = (
+                        self._emit_initializer_lines_truncated(
+                            values, shape, self._truncate_weights_after
+                        )
+                    )
+                    lines.extend(truncated_lines)
+                else:
+                    lines.extend(self._emit_initializer_lines(values, shape))
             lines.append("};")
             lines.append("")
         if lines and not lines[-1]:
@@ -8866,12 +10917,44 @@ class CEmitter:
         if not constants:
             return ""
         lines = []
-        for const in constants:
+        for index, const in enumerate(constants, start=1):
+            lines.append(self._emit_constant_comment(const, index))
             c_type = const.dtype.c_type
             array_suffix = self._array_suffix(const.shape)
             lines.append(f"extern const {c_type} {const.name}{array_suffix};")
         return "\n".join(lines)
 
+    def _emit_constant_storage_definitions(
+        self,
+        constants: tuple[ConstTensor, ...],
+        *,
+        storage_prefix: str = "static",
+    ) -> str:
+        if not constants:
+            return ""
+        lines: list[str] = []
+        for index, const in enumerate(constants, start=1):
+            lines.append(self._emit_constant_comment(const, index))
+            c_type = const.dtype.c_type
+            array_suffix = self._array_suffix(const.shape)
+            lines.append(f"{storage_prefix} {c_type} {const.name}{array_suffix};")
+            lines.append("")
+        if lines and not lines[-1]:
+            lines.pop()
+        return "\n".join(lines)
+
+    def collect_weight_data(
+        self, constants: tuple[ConstTensor, ...]
+    ) -> bytes | None:
+        _, large_constants = self._partition_constants(constants)
+        if not large_constants:
+            return None
+        chunks: list[bytes] = []
+        for const in large_constants:
+            array = np.asarray(const.data, dtype=const.dtype.np_dtype)
+            chunks.append(array.tobytes(order="C"))
+        return b"".join(chunks)
+
     @staticmethod
     def _index_expr(shape: tuple[int, ...], loop_vars: tuple[str, ...]) -> str:
         shape = CEmitter._codegen_shape(shape)
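Note: `collect_weight_data` concatenates each large constant's raw bytes (native endianness, C row-major order) in emission order, which the CLI presumably writes to the `_weight_data_filename` side-car; the generated loader consumes them with the same nesting, so writer and reader agree as long as both run on same-endian machines. Caller-side usage, with `model_load` and `model.bin` as illustrative names:

    #include <stdio.h>
    #include <stdlib.h>

    extern _Bool model_load(const char *path);  /* generated <model>_load */

    int main(void) {
        if (!model_load("model.bin")) {  /* _weight_data_filename(model) */
            fprintf(stderr, "failed to read weight data\n");
            return EXIT_FAILURE;
        }
        /* ... invoke the generated inference entry point ... */
        return EXIT_SUCCESS;
    }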
@@ -8886,6 +10969,10 @@ class CEmitter:
 
     @staticmethod
     def _format_float(value: float) -> str:
+        if math.isnan(value):
+            return "NAN"
+        if math.isinf(value):
+            return "-INFINITY" if value < 0 else "INFINITY"
         formatted = f"{value:.9g}"
         if "e" not in formatted and "E" not in formatted and "." not in formatted:
             formatted = f"{formatted}.0"
@@ -8897,11 +10984,57 @@ class CEmitter:
 
     @staticmethod
     def _format_double(value: float) -> str:
+        if math.isnan(value):
+            return "NAN"
+        if math.isinf(value):
+            return "-INFINITY" if value < 0 else "INFINITY"
         formatted = f"{value:.17g}"
         if "e" not in formatted and "E" not in formatted and "." not in formatted:
             formatted = f"{formatted}.0"
         return formatted
 
+    @staticmethod
+    def _format_float32_hex(value: float) -> str:
+        bits = struct.unpack("<I", struct.pack("<f", float(value)))[0]
+        sign = "-" if (bits >> 31) else ""
+        exponent = (bits >> 23) & 0xFF
+        mantissa = bits & 0x7FFFFF
+        if exponent == 0 and mantissa == 0:
+            return f"{sign}0x0.0p+0"
+        if exponent == 0xFF:
+            if mantissa == 0:
+                return f"{sign}INFINITY"
+            return "NAN"
+        if exponent == 0:
+            shift = mantissa.bit_length() - 1
+            exponent_val = shift - 149
+            fraction = (mantissa - (1 << shift)) << (24 - shift)
+        else:
+            exponent_val = exponent - 127
+            fraction = mantissa << 1
+        return f"{sign}0x1.{fraction:06x}p{exponent_val:+d}"
+
+    @staticmethod
+    def _format_float64_hex(value: float) -> str:
+        bits = struct.unpack("<Q", struct.pack("<d", float(value)))[0]
+        sign = "-" if (bits >> 63) else ""
+        exponent = (bits >> 52) & 0x7FF
+        mantissa = bits & 0xFFFFFFFFFFFFF
+        if exponent == 0 and mantissa == 0:
+            return f"{sign}0x0.0p+0"
+        if exponent == 0x7FF:
+            if mantissa == 0:
+                return f"{sign}INFINITY"
+            return "NAN"
+        if exponent == 0:
+            shift = mantissa.bit_length() - 1
+            exponent_val = shift - 1074
+            fraction = (mantissa - (1 << shift)) << (52 - shift)
+        else:
+            exponent_val = exponent - 1023
+            fraction = mantissa
+        return f"{sign}0x1.{fraction:013x}p{exponent_val:+d}"
+
     @staticmethod
     def _format_floating(value: float, dtype: ScalarType) -> str:
         if dtype == ScalarType.F64:
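Note: the hex formatters above emit C99 hexadecimal float literals (`0x1.<mantissa>p<exp>`), which a conforming compiler reconstructs bit-exactly, unlike shortest-decimal output. Two checkable examples:

    #include <assert.h>

    int main(void) {
        /* 0.1f has bit pattern 0x3DCCCCCD -> mantissa 0x99999a, exponent -4 */
        float a = 0x1.99999ap-4f;
        assert(a == 0.1f);
        /* the double closest to pi */
        double b = 0x1.921fb54442d18p+1;
        assert(b > 3.141592 && b < 3.141593);
        return 0;
    }

Non-finite values fall back to the NAN/INFINITY macros, which is why `_format_weight_value` below special-cases those strings before appending the `f` suffix.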
@@ -8992,14 +11125,139 @@ class CEmitter:
             return self._format_int(int(value), 8, "INT8_MIN")
         raise CodegenError(f"Unsupported dtype {dtype.onnx_name}")
 
+    def _format_weight_value(
+        self, value: float | int | bool, dtype: ScalarType
+    ) -> str:
+        if dtype == ScalarType.F16:
+            formatted = self._format_float32_hex(float(value))
+            if formatted == "NAN" or formatted.endswith("INFINITY"):
+                return f"(_Float16){formatted}"
+            return f"(_Float16){formatted}f"
+        if dtype == ScalarType.F32:
+            formatted = self._format_float32_hex(float(value))
+            if formatted == "NAN" or formatted.endswith("INFINITY"):
+                return formatted
+            return f"{formatted}f"
+        if dtype == ScalarType.F64:
+            return self._format_float64_hex(float(value))
+        if dtype == ScalarType.BOOL:
+            return "true" if bool(value) else "false"
+        if dtype == ScalarType.U64:
+            return self._format_uint(int(value), 64, "UINT64_MAX")
+        if dtype == ScalarType.U32:
+            return self._format_uint(int(value), 32, "UINT32_MAX")
+        if dtype == ScalarType.U16:
+            return self._format_uint(int(value), 16, "UINT16_MAX")
+        if dtype == ScalarType.U8:
+            return self._format_uint(int(value), 8, "UINT8_MAX")
+        if dtype == ScalarType.I64:
+            return self._format_int64(int(value))
+        if dtype == ScalarType.I32:
+            return self._format_int(int(value), 32, "INT32_MIN")
+        if dtype == ScalarType.I16:
+            return self._format_int(int(value), 16, "INT16_MIN")
+        if dtype == ScalarType.I8:
+            return self._format_int(int(value), 8, "INT8_MIN")
+        raise CodegenError(f"Unsupported dtype {dtype.onnx_name}")
+
+    @staticmethod
+    def _emit_initializer_lines(
+        values: Sequence[str],
+        shape: tuple[int, ...],
+        indent: str = "    ",
+        per_line: int = 8,
+    ) -> list[str]:
+        if len(shape) == 1:
+            lines: list[str] = []
+            for index in range(0, len(values), per_line):
+                chunk = ", ".join(values[index : index + per_line])
+                lines.append(f"{indent}{chunk},")
+            if lines:
+                lines[-1] = lines[-1].rstrip(",")
+            return lines
+        sub_shape = shape[1:]
+        sub_size = prod(sub_shape)
+        lines = []
+        for index in range(shape[0]):
+            start = index * sub_size
+            end = start + sub_size
+            lines.append(f"{indent}{{")
+            lines.extend(
+                CEmitter._emit_initializer_lines(
+                    values[start:end],
+                    sub_shape,
+                    indent + "    ",
+                    per_line,
+                )
+            )
+            lines.append(f"{indent}}},")
+        if lines:
+            lines[-1] = lines[-1].rstrip(",")
+        return lines
+
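Note: `_emit_initializer_lines` nests braces to match the array shape instead of emitting one flat value list. For a (2, 3) constant the emitted initializer looks like this (decimal literals shown for readability; real output uses the hex literals from `_format_weight_value`):

    static const float weight0[2][3] = {
        {
            1.0f, 2.0f, 3.0f
        },
        {
            4.0f, 5.0f, 6.0f
        }
    };

The truncated variant that follows stops after `truncate_after` values and appends a literal `...` marker, so truncated listings are meant for human inspection rather than compilation.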
+    @staticmethod
+    def _emit_initializer_lines_truncated(
+        values: Sequence[str],
+        shape: tuple[int, ...],
+        truncate_after: int,
+        indent: str = "    ",
+        per_line: int = 8,
+        start_index: int = 0,
+        emitted: int = 0,
+    ) -> tuple[list[str], int, int, bool]:
+        if len(shape) == 1:
+            items: list[str] = []
+            truncated = False
+            index = start_index
+            for _ in range(shape[0]):
+                if emitted >= truncate_after:
+                    items.append("...")
+                    truncated = True
+                    break
+                items.append(values[index])
+                index += 1
+                emitted += 1
+            lines: list[str] = []
+            for item_index in range(0, len(items), per_line):
+                chunk = ", ".join(items[item_index : item_index + per_line])
+                lines.append(f"{indent}{chunk},")
+            if lines:
+                lines[-1] = lines[-1].rstrip(",")
+            return lines, index, emitted, truncated
+        sub_shape = shape[1:]
+        lines: list[str] = []
+        index = start_index
+        truncated = False
+        for _ in range(shape[0]):
+            lines.append(f"{indent}{{")
+            sub_lines, index, emitted, sub_truncated = (
+                CEmitter._emit_initializer_lines_truncated(
+                    values,
+                    sub_shape,
+                    truncate_after,
+                    indent + "    ",
+                    per_line,
+                    index,
+                    emitted,
+                )
+            )
+            lines.extend(sub_lines)
+            lines.append(f"{indent}}},")
+            if sub_truncated:
+                truncated = True
+                break
+        if lines:
+            lines[-1] = lines[-1].rstrip(",")
+        return lines, index, emitted, truncated
+
     @staticmethod
     def _print_format(dtype: ScalarType) -> str:
         if dtype == ScalarType.F16:
-            return "%.8g"
+            return "\\\"%a\\\""
         if dtype == ScalarType.F32:
-            return "%.8g"
+            return "\\\"%a\\\""
         if dtype == ScalarType.F64:
-            return "%.17g"
+            return "\\\"%a\\\""
         if dtype == ScalarType.BOOL:
             return "%d"
         if dtype == ScalarType.U64: