emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen has been flagged as possibly problematic.

Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ import re
  import struct
  from typing import Mapping, Sequence

- from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
+ from jinja2 import Environment, FileSystemLoader, PackageLoader, Template, select_autoescape
  import numpy as np

  from ..errors import CodegenError
@@ -37,15 +37,18 @@ from ..ir.ops import (
  AttentionOp,
  AveragePoolOp,
  BatchNormOp,
+ BernoulliOp,
  BinaryOp,
  CastOp,
  ClipOp,
  ConcatOp,
  ConstantOfShapeOp,
  ConvOp,
+ ConvIntegerOp,
  ConvTransposeOp,
  CumSumOp,
  DepthToSpaceOp,
+ DequantizeLinearOp,
  EinsumKind,
  EinsumOp,
  ExpandOp,
@@ -54,8 +57,10 @@ from ..ir.ops import (
  GatherNDOp,
  GatherOp,
  GemmOp,
+ GruOp,
  GridSampleOp,
  GroupNormalizationOp,
+ HammingWindowOp,
  HardmaxOp,
  IdentityOp,
  InstanceNormalizationOp,
@@ -73,8 +78,11 @@ from ..ir.ops import (
  NonMaxSuppressionOp,
  NonZeroOp,
  OneHotOp,
+ OptionalHasElementOp,
  PadOp,
  QuantizeLinearOp,
+ PowOp,
+ QLinearMulOp,
  QLinearMatMulOp,
  RangeOp,
  ReduceOp,
@@ -91,6 +99,7 @@ from ..ir.ops import (
  SpaceToDepthOp,
  SplitOp,
  TensorScatterOp,
+ TfIdfVectorizerOp,
  TileOp,
  TopKOp,
  TransposeOp,
@@ -278,9 +287,11 @@ class ModelHeader:
  class LoweredModel:
  name: str
  input_names: tuple[str, ...]
+ input_optional_names: tuple[str | None, ...]
  input_shapes: tuple[tuple[int, ...], ...]
  input_dtypes: tuple[ScalarType, ...]
  output_names: tuple[str, ...]
+ output_optional_names: tuple[str | None, ...]
  output_shapes: tuple[tuple[int, ...], ...]
  output_dtypes: tuple[ScalarType, ...]
  constants: tuple[ConstTensor, ...]
@@ -304,20 +315,37 @@ class _EmitState:
  class CEmitter:
  def __init__(
  self,
- template_dir: Path,
+ template_dir: Path | None,
  *,
  restrict_arrays: bool = True,
+ fp32_accumulation_strategy: str = "fp64",
+ fp16_accumulation_strategy: str = "fp32",
  truncate_weights_after: int | None = None,
  large_temp_threshold_bytes: int = 1024,
  large_weight_threshold: int = 1024,
  ) -> None:
+ loader = (
+ FileSystemLoader(str(template_dir))
+ if template_dir is not None
+ else PackageLoader("emx_onnx_cgen", "templates")
+ )
  self._env = Environment(
- loader=FileSystemLoader(str(template_dir)),
+ loader=loader,
  autoescape=select_autoescape(enabled_extensions=()),
  trim_blocks=True,
  lstrip_blocks=True,
  )
  self._restrict_arrays = restrict_arrays
+ if fp32_accumulation_strategy not in {"simple", "fp64"}:
+ raise CodegenError(
+ "fp32_accumulation_strategy must be 'simple' or 'fp64'"
+ )
+ self._fp32_accumulation_strategy = fp32_accumulation_strategy
+ if fp16_accumulation_strategy not in {"simple", "fp32"}:
+ raise CodegenError(
+ "fp16_accumulation_strategy must be 'simple' or 'fp32'"
+ )
+ self._fp16_accumulation_strategy = fp16_accumulation_strategy
  if truncate_weights_after is not None and truncate_weights_after < 1:
  raise CodegenError("truncate_weights_after must be >= 1")
  self._truncate_weights_after = truncate_weights_after
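
With template_dir now accepting None, the emitter falls back to the templates bundled in the installed wheel via PackageLoader instead of requiring a filesystem path. A minimal usage sketch of the two loader branches (the import path of CEmitter is assumed from the file list above; only keywords visible in this hunk are used):

    from pathlib import Path

    from emx_onnx_cgen.codegen.c_emitter import CEmitter  # assumed module path

    # PackageLoader branch: render with the templates shipped inside the wheel.
    emitter_bundled = CEmitter(template_dir=None)

    # FileSystemLoader branch: point the emitter at a local template directory.
    emitter_local = CEmitter(template_dir=Path("emx_onnx_cgen/templates"))
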
@@ -390,6 +418,21 @@ class CEmitter:
  mapped[key] = unique
  return mapped

+ def _accumulation_dtype(self, dtype: ScalarType) -> ScalarType:
+ if dtype == ScalarType.F32:
+ return (
+ ScalarType.F32
+ if self._fp32_accumulation_strategy == "simple"
+ else ScalarType.F64
+ )
+ if dtype == ScalarType.F16:
+ return (
+ ScalarType.F16
+ if self._fp16_accumulation_strategy == "simple"
+ else ScalarType.F32
+ )
+ return dtype
+
  def _ctx_name(self, name: str) -> str:
  if self._emit_state is None:
  raise CodegenError("Emitter state not initialized")
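
The new _accumulation_dtype helper pairs with the accumulation-strategy flags above: unless the "simple" strategy is selected, f32 reductions accumulate in f64 and f16 reductions in f32, while every other dtype is left unchanged. A standalone sketch of that mapping, using plain strings in place of the package's ScalarType enum:

    def accumulation_dtype(dtype: str,
                           fp32_strategy: str = "fp64",
                           fp16_strategy: str = "fp32") -> str:
        # Mirrors the strategy table in the hunk above.
        if dtype == "f32":
            return "f32" if fp32_strategy == "simple" else "f64"
        if dtype == "f16":
            return "f16" if fp16_strategy == "simple" else "f32"
        return dtype  # integer and f64 accumulators stay as they are

    assert accumulation_dtype("f32") == "f64"
    assert accumulation_dtype("f16", fp16_strategy="simple") == "f16"
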
@@ -410,6 +453,12 @@
  raise CodegenError("Emitter state not initialized")
  return self._emit_state.op_context.require_derived(op, key)

+ def _maybe_derived(self, op: OpBase, key: str) -> object | None:
+ if self._emit_state is None:
+ raise CodegenError("Emitter state not initialized")
+ value = self._emit_state.op_context.get_derived(op, key, None)
+ return value
+
  @staticmethod
  def _build_param_decls(
  specs: Sequence[tuple[str | None, str, str, bool]]
@@ -447,6 +496,8 @@
  | ClipOp
  | CastOp
  | QuantizeLinearOp
+ | DequantizeLinearOp
+ | QLinearMulOp
  | QLinearMatMulOp
  | MatMulOp
  | EinsumOp
@@ -454,6 +505,7 @@
  | AttentionOp
  | RotaryEmbeddingOp
  | ConvOp
+ | ConvIntegerOp
  | AveragePoolOp
  | BatchNormOp
  | LpNormalizationOp
@@ -463,6 +515,7 @@
  | MeanVarianceNormalizationOp
  | RMSNormalizationOp
  | LrnOp
+ | GruOp
  | LstmOp
  | AdagradOp
  | SoftmaxOp
@@ -480,6 +533,7 @@
  | TransposeOp
  | ReshapeOp
  | IdentityOp
+ | BernoulliOp
  | EyeLikeOp
  | TriluOp
  | TileOp
@@ -495,11 +549,13 @@
  | ConstantOfShapeOp
  | ShapeOp
  | SizeOp
+ | OptionalHasElementOp
  | NonZeroOp
  | NonMaxSuppressionOp
  | ExpandOp
  | CumSumOp
  | RangeOp
+ | HammingWindowOp
  | OneHotOp
  | SplitOp,
  ) -> tuple[str, ...]:
@@ -527,6 +583,24 @@
  names.append(op.zero_point)
  names.append(op.output)
  return tuple(names)
+ if isinstance(op, DequantizeLinearOp):
+ names = [op.input0, op.scale]
+ if op.zero_point is not None:
+ names.append(op.zero_point)
+ names.append(op.output)
+ return tuple(names)
+ if isinstance(op, QLinearMulOp):
+ return (
+ op.input0,
+ op.input0_scale,
+ op.input0_zero_point,
+ op.input1,
+ op.input1_scale,
+ op.input1_zero_point,
+ op.output_scale,
+ op.output_zero_point,
+ op.output,
+ )
  if isinstance(op, QLinearMatMulOp):
  return (
  op.input0,
@@ -579,6 +653,14 @@
  names.append(op.bias)
  names.append(op.output)
  return tuple(names)
+ if isinstance(op, ConvIntegerOp):
+ names = [op.input0, op.weights]
+ if op.x_zero_point is not None:
+ names.append(op.x_zero_point)
+ if op.w_zero_point is not None:
+ names.append(op.w_zero_point)
+ names.append(op.output)
+ return tuple(names)
  if isinstance(op, ConvTransposeOp):
  names = [op.input0, op.weights]
  if op.bias is not None:
@@ -611,6 +693,19 @@
  return (op.input0, op.output)
  if isinstance(op, RMSNormalizationOp):
  return (op.input0, op.scale, op.output)
+ if isinstance(op, GruOp):
+ names = [op.input_x, op.input_w, op.input_r]
+ if op.input_b is not None:
+ names.append(op.input_b)
+ if op.input_sequence_lens is not None:
+ names.append(op.input_sequence_lens)
+ if op.input_initial_h is not None:
+ names.append(op.input_initial_h)
+ if op.output_y is not None:
+ names.append(op.output_y)
+ if op.output_y_h is not None:
+ names.append(op.output_y_h)
+ return tuple(names)
  if isinstance(op, LstmOp):
  names = [op.input_x, op.input_w, op.input_r]
  if op.input_b is not None:
@@ -705,14 +800,20 @@
  return tuple(names)
  if isinstance(op, RangeOp):
  return (op.start, op.limit, op.delta, op.output)
+ if isinstance(op, HammingWindowOp):
+ return (op.size, op.output)
  if isinstance(op, OneHotOp):
  return (op.indices, op.depth, op.values, op.output)
+ if isinstance(op, TfIdfVectorizerOp):
+ return (op.input0, op.output)
  if isinstance(op, SplitOp):
  return (op.input0, *op.outputs)
  if isinstance(op, ReshapeOp):
  return (op.input0, op.output)
  if isinstance(op, IdentityOp):
  return (op.input0, op.output)
+ if isinstance(op, BernoulliOp):
+ return (op.input0, op.output)
  if isinstance(op, EyeLikeOp):
  return (op.input0, op.output)
  if isinstance(op, TriluOp):
@@ -761,7 +862,12 @@
  if isinstance(op, GridSampleOp):
  return (op.input0, op.grid, op.output)
  if isinstance(op, TopKOp):
- return (op.input0, op.output_values, op.output_indices)
+ return (
+ op.input0,
+ op.k_input,
+ op.output_values,
+ op.output_indices,
+ )
  if isinstance(op, ReduceOp):
  names = [op.input0]
  if op.axes_input is not None:
@@ -777,6 +883,12 @@
  names = [model.name]
  names.extend(model.input_names)
  names.extend(model.output_names)
+ names.extend(
+ name for name in model.input_optional_names if name is not None
+ )
+ names.extend(
+ name for name in model.output_optional_names if name is not None
+ )
  for op in model.ops:
  names.extend(
  name for name in self._op_names(op) if name not in constant_names
@@ -809,12 +921,15 @@
  def _map_op_names(
  self,
  op: BinaryOp
+ | PowOp
  | MultiInputBinaryOp
  | WhereOp
  | UnaryOp
  | ClipOp
  | CastOp
  | QuantizeLinearOp
+ | DequantizeLinearOp
+ | QLinearMulOp
  | QLinearMatMulOp
  | MatMulOp
  | EinsumOp
@@ -822,6 +937,7 @@ class CEmitter:
822
937
  | AttentionOp
823
938
  | RotaryEmbeddingOp
824
939
  | ConvOp
940
+ | ConvIntegerOp
825
941
  | ConvTransposeOp
826
942
  | AveragePoolOp
827
943
  | LpPoolOp
@@ -833,6 +949,7 @@ class CEmitter:
833
949
  | MeanVarianceNormalizationOp
834
950
  | RMSNormalizationOp
835
951
  | LrnOp
952
+ | GruOp
836
953
  | LstmOp
837
954
  | AdagradOp
838
955
  | SoftmaxOp
@@ -865,22 +982,28 @@ class CEmitter:
865
982
  | ConstantOfShapeOp
866
983
  | ShapeOp
867
984
  | SizeOp
985
+ | OptionalHasElementOp
868
986
  | NonZeroOp
869
987
  | NonMaxSuppressionOp
870
988
  | ExpandOp
871
989
  | CumSumOp
872
990
  | RangeOp
991
+ | HammingWindowOp
873
992
  | OneHotOp
993
+ | TfIdfVectorizerOp
874
994
  | SplitOp,
875
995
  name_map: dict[str, str],
876
996
  ) -> (
877
997
  BinaryOp
998
+ | PowOp
878
999
  | MultiInputBinaryOp
879
1000
  | WhereOp
880
1001
  | UnaryOp
881
1002
  | ClipOp
882
1003
  | CastOp
883
1004
  | QuantizeLinearOp
1005
+ | DequantizeLinearOp
1006
+ | QLinearMulOp
884
1007
  | QLinearMatMulOp
885
1008
  | MatMulOp
886
1009
  | EinsumOp
@@ -888,6 +1011,7 @@ class CEmitter:
888
1011
  | AttentionOp
889
1012
  | RotaryEmbeddingOp
890
1013
  | ConvOp
1014
+ | ConvIntegerOp
891
1015
  | ConvTransposeOp
892
1016
  | AveragePoolOp
893
1017
  | LpPoolOp
@@ -899,6 +1023,7 @@ class CEmitter:
899
1023
  | MeanVarianceNormalizationOp
900
1024
  | RMSNormalizationOp
901
1025
  | LrnOp
1026
+ | GruOp
902
1027
  | LstmOp
903
1028
  | AdagradOp
904
1029
  | SoftmaxOp
@@ -931,14 +1056,25 @@ class CEmitter:
931
1056
  | ConstantOfShapeOp
932
1057
  | ShapeOp
933
1058
  | SizeOp
1059
+ | OptionalHasElementOp
934
1060
  | NonZeroOp
935
1061
  | NonMaxSuppressionOp
936
1062
  | ExpandOp
937
1063
  | CumSumOp
938
1064
  | RangeOp
1065
+ | HammingWindowOp
939
1066
  | OneHotOp
940
1067
  | SplitOp
1068
+ | TfIdfVectorizerOp
941
1069
  ):
1070
+ if isinstance(op, PowOp):
1071
+ return PowOp(
1072
+ input0=name_map.get(op.input0, op.input0),
1073
+ input1=name_map.get(op.input1, op.input1),
1074
+ output=name_map.get(op.output, op.output),
1075
+ function=op.function,
1076
+ operator_kind=op.operator_kind,
1077
+ )
942
1078
  if isinstance(op, BinaryOp):
943
1079
  return BinaryOp(
944
1080
  input0=name_map.get(op.input0, op.input0),
@@ -946,11 +1082,6 @@ class CEmitter:
946
1082
  output=name_map.get(op.output, op.output),
947
1083
  function=op.function,
948
1084
  operator_kind=op.operator_kind,
949
- input0_shape=op.input0_shape,
950
- input1_shape=op.input1_shape,
951
- shape=op.shape,
952
- dtype=op.dtype,
953
- input_dtype=op.input_dtype,
954
1085
  )
955
1086
  if isinstance(op, MultiInputBinaryOp):
956
1087
  return MultiInputBinaryOp(
@@ -968,20 +1099,12 @@ class CEmitter:
968
1099
  input_x=name_map.get(op.input_x, op.input_x),
969
1100
  input_y=name_map.get(op.input_y, op.input_y),
970
1101
  output=name_map.get(op.output, op.output),
971
- condition_shape=op.condition_shape,
972
- x_shape=op.x_shape,
973
- y_shape=op.y_shape,
974
- output_shape=op.output_shape,
975
- dtype=op.dtype,
976
1102
  )
977
1103
  if isinstance(op, UnaryOp):
978
1104
  return UnaryOp(
979
1105
  input0=name_map.get(op.input0, op.input0),
980
1106
  output=name_map.get(op.output, op.output),
981
1107
  function=op.function,
982
- shape=op.shape,
983
- dtype=op.dtype,
984
- input_dtype=op.input_dtype,
985
1108
  params=op.params,
986
1109
  )
987
1110
  if isinstance(op, ClipOp):
@@ -990,11 +1113,8 @@ class CEmitter:
990
1113
  input_min=self._map_optional_name(name_map, op.input_min),
991
1114
  input_max=self._map_optional_name(name_map, op.input_max),
992
1115
  output=name_map.get(op.output, op.output),
993
- input_shape=op.input_shape,
994
- min_shape=op.min_shape,
995
- max_shape=op.max_shape,
996
- output_shape=op.output_shape,
997
- dtype=op.dtype,
1116
+ min_value=op.min_value,
1117
+ max_value=op.max_value,
998
1118
  )
999
1119
  if isinstance(op, CastOp):
1000
1120
  return CastOp(
@@ -1016,8 +1136,21 @@ class CEmitter:
1016
1136
  input_dtype=op.input_dtype,
1017
1137
  scale_dtype=op.scale_dtype,
1018
1138
  )
1019
- if isinstance(op, QLinearMatMulOp):
1020
- return QLinearMatMulOp(
1139
+ if isinstance(op, DequantizeLinearOp):
1140
+ return DequantizeLinearOp(
1141
+ input0=name_map.get(op.input0, op.input0),
1142
+ scale=name_map.get(op.scale, op.scale),
1143
+ zero_point=self._map_optional_name(name_map, op.zero_point),
1144
+ output=name_map.get(op.output, op.output),
1145
+ input_shape=op.input_shape,
1146
+ axis=op.axis,
1147
+ block_size=op.block_size,
1148
+ dtype=op.dtype,
1149
+ input_dtype=op.input_dtype,
1150
+ scale_dtype=op.scale_dtype,
1151
+ )
1152
+ if isinstance(op, QLinearMulOp):
1153
+ return QLinearMulOp(
1021
1154
  input0=name_map.get(op.input0, op.input0),
1022
1155
  input0_scale=name_map.get(op.input0_scale, op.input0_scale),
1023
1156
  input0_zero_point=name_map.get(
@@ -1036,14 +1169,6 @@ class CEmitter:
1036
1169
  input0_shape=op.input0_shape,
1037
1170
  input1_shape=op.input1_shape,
1038
1171
  output_shape=op.output_shape,
1039
- batch_shape=op.batch_shape,
1040
- input0_batch_shape=op.input0_batch_shape,
1041
- input1_batch_shape=op.input1_batch_shape,
1042
- m=op.m,
1043
- n=op.n,
1044
- k=op.k,
1045
- left_vector=op.left_vector,
1046
- right_vector=op.right_vector,
1047
1172
  input0_dtype=op.input0_dtype,
1048
1173
  input1_dtype=op.input1_dtype,
1049
1174
  dtype=op.dtype,
@@ -1057,10 +1182,22 @@ class CEmitter:
1057
1182
  input1_zero_shape=op.input1_zero_shape,
1058
1183
  output_zero_shape=op.output_zero_shape,
1059
1184
  )
1060
- if isinstance(op, MatMulOp):
1061
- return MatMulOp(
1185
+ if isinstance(op, QLinearMatMulOp):
1186
+ return QLinearMatMulOp(
1062
1187
  input0=name_map.get(op.input0, op.input0),
1188
+ input0_scale=name_map.get(op.input0_scale, op.input0_scale),
1189
+ input0_zero_point=name_map.get(
1190
+ op.input0_zero_point, op.input0_zero_point
1191
+ ),
1063
1192
  input1=name_map.get(op.input1, op.input1),
1193
+ input1_scale=name_map.get(op.input1_scale, op.input1_scale),
1194
+ input1_zero_point=name_map.get(
1195
+ op.input1_zero_point, op.input1_zero_point
1196
+ ),
1197
+ output_scale=name_map.get(op.output_scale, op.output_scale),
1198
+ output_zero_point=name_map.get(
1199
+ op.output_zero_point, op.output_zero_point
1200
+ ),
1064
1201
  output=name_map.get(op.output, op.output),
1065
1202
  input0_shape=op.input0_shape,
1066
1203
  input1_shape=op.input1_shape,
@@ -1073,7 +1210,24 @@ class CEmitter:
1073
1210
  k=op.k,
1074
1211
  left_vector=op.left_vector,
1075
1212
  right_vector=op.right_vector,
1213
+ input0_dtype=op.input0_dtype,
1214
+ input1_dtype=op.input1_dtype,
1076
1215
  dtype=op.dtype,
1216
+ input0_scale_dtype=op.input0_scale_dtype,
1217
+ input1_scale_dtype=op.input1_scale_dtype,
1218
+ output_scale_dtype=op.output_scale_dtype,
1219
+ input0_scale_shape=op.input0_scale_shape,
1220
+ input1_scale_shape=op.input1_scale_shape,
1221
+ output_scale_shape=op.output_scale_shape,
1222
+ input0_zero_shape=op.input0_zero_shape,
1223
+ input1_zero_shape=op.input1_zero_shape,
1224
+ output_zero_shape=op.output_zero_shape,
1225
+ )
1226
+ if isinstance(op, MatMulOp):
1227
+ return MatMulOp(
1228
+ input0=name_map.get(op.input0, op.input0),
1229
+ input1=name_map.get(op.input1, op.input1),
1230
+ output=name_map.get(op.output, op.output),
1077
1231
  )
1078
1232
  if isinstance(op, EinsumOp):
1079
1233
  return EinsumOp(
@@ -1091,15 +1245,10 @@ class CEmitter:
1091
1245
  input_b=name_map.get(op.input_b, op.input_b),
1092
1246
  input_c=self._map_optional_name(name_map, op.input_c),
1093
1247
  output=name_map.get(op.output, op.output),
1094
- m=op.m,
1095
- n=op.n,
1096
- k=op.k,
1097
1248
  trans_a=op.trans_a,
1098
1249
  trans_b=op.trans_b,
1099
1250
  alpha=op.alpha,
1100
1251
  beta=op.beta,
1101
- c_shape=op.c_shape,
1102
- dtype=op.dtype,
1103
1252
  )
1104
1253
  if isinstance(op, AttentionOp):
1105
1254
  return AttentionOp(
@@ -1202,6 +1351,35 @@ class CEmitter:
1202
1351
  group=op.group,
1203
1352
  dtype=op.dtype,
1204
1353
  )
1354
+ if isinstance(op, ConvIntegerOp):
1355
+ return ConvIntegerOp(
1356
+ input0=name_map.get(op.input0, op.input0),
1357
+ weights=name_map.get(op.weights, op.weights),
1358
+ x_zero_point=self._map_optional_name(
1359
+ name_map, op.x_zero_point
1360
+ ),
1361
+ w_zero_point=self._map_optional_name(
1362
+ name_map, op.w_zero_point
1363
+ ),
1364
+ output=name_map.get(op.output, op.output),
1365
+ batch=op.batch,
1366
+ in_channels=op.in_channels,
1367
+ out_channels=op.out_channels,
1368
+ spatial_rank=op.spatial_rank,
1369
+ in_spatial=op.in_spatial,
1370
+ out_spatial=op.out_spatial,
1371
+ kernel_shape=op.kernel_shape,
1372
+ strides=op.strides,
1373
+ pads=op.pads,
1374
+ dilations=op.dilations,
1375
+ group=op.group,
1376
+ input_dtype=op.input_dtype,
1377
+ weight_dtype=op.weight_dtype,
1378
+ dtype=op.dtype,
1379
+ x_zero_point_shape=op.x_zero_point_shape,
1380
+ w_zero_point_shape=op.w_zero_point_shape,
1381
+ w_zero_point_per_channel=op.w_zero_point_per_channel,
1382
+ )
1205
1383
  if isinstance(op, ConvTransposeOp):
1206
1384
  return ConvTransposeOp(
1207
1385
  input0=name_map.get(op.input0, op.input0),
@@ -1228,16 +1406,26 @@ class CEmitter:
1228
1406
  output=name_map.get(op.output, op.output),
1229
1407
  batch=op.batch,
1230
1408
  channels=op.channels,
1409
+ spatial_rank=op.spatial_rank,
1410
+ in_d=op.in_d,
1231
1411
  in_h=op.in_h,
1232
1412
  in_w=op.in_w,
1413
+ out_d=op.out_d,
1233
1414
  out_h=op.out_h,
1234
1415
  out_w=op.out_w,
1416
+ kernel_d=op.kernel_d,
1235
1417
  kernel_h=op.kernel_h,
1236
1418
  kernel_w=op.kernel_w,
1419
+ dilation_d=op.dilation_d,
1420
+ dilation_h=op.dilation_h,
1421
+ dilation_w=op.dilation_w,
1422
+ stride_d=op.stride_d,
1237
1423
  stride_h=op.stride_h,
1238
1424
  stride_w=op.stride_w,
1425
+ pad_front=op.pad_front,
1239
1426
  pad_top=op.pad_top,
1240
1427
  pad_left=op.pad_left,
1428
+ pad_back=op.pad_back,
1241
1429
  pad_bottom=op.pad_bottom,
1242
1430
  pad_right=op.pad_right,
1243
1431
  count_include_pad=op.count_include_pad,
@@ -1255,6 +1443,8 @@ class CEmitter:
1255
1443
  out_w=op.out_w,
1256
1444
  kernel_h=op.kernel_h,
1257
1445
  kernel_w=op.kernel_w,
1446
+ dilation_h=op.dilation_h,
1447
+ dilation_w=op.dilation_w,
1258
1448
  stride_h=op.stride_h,
1259
1449
  stride_w=op.stride_w,
1260
1450
  pad_top=op.pad_top,
@@ -1371,6 +1561,35 @@ class CEmitter:
1371
1561
  bias=op.bias,
1372
1562
  dtype=op.dtype,
1373
1563
  )
1564
+ if isinstance(op, GruOp):
1565
+ return GruOp(
1566
+ input_x=name_map.get(op.input_x, op.input_x),
1567
+ input_w=name_map.get(op.input_w, op.input_w),
1568
+ input_r=name_map.get(op.input_r, op.input_r),
1569
+ input_b=self._map_optional_name(name_map, op.input_b),
1570
+ input_sequence_lens=self._map_optional_name(
1571
+ name_map, op.input_sequence_lens
1572
+ ),
1573
+ input_initial_h=self._map_optional_name(
1574
+ name_map, op.input_initial_h
1575
+ ),
1576
+ output_y=self._map_optional_name(name_map, op.output_y),
1577
+ output_y_h=self._map_optional_name(name_map, op.output_y_h),
1578
+ seq_length=op.seq_length,
1579
+ batch_size=op.batch_size,
1580
+ input_size=op.input_size,
1581
+ hidden_size=op.hidden_size,
1582
+ num_directions=op.num_directions,
1583
+ direction=op.direction,
1584
+ layout=op.layout,
1585
+ linear_before_reset=op.linear_before_reset,
1586
+ clip=op.clip,
1587
+ activation_kinds=op.activation_kinds,
1588
+ activation_alphas=op.activation_alphas,
1589
+ activation_betas=op.activation_betas,
1590
+ dtype=op.dtype,
1591
+ sequence_lens_dtype=op.sequence_lens_dtype,
1592
+ )
1374
1593
  if isinstance(op, LstmOp):
1375
1594
  return LstmOp(
1376
1595
  input_x=name_map.get(op.input_x, op.input_x),
@@ -1436,34 +1655,19 @@ class CEmitter:
1436
1655
  return SoftmaxOp(
1437
1656
  input0=name_map.get(op.input0, op.input0),
1438
1657
  output=name_map.get(op.output, op.output),
1439
- outer=op.outer,
1440
- axis_size=op.axis_size,
1441
- inner=op.inner,
1442
1658
  axis=op.axis,
1443
- shape=op.shape,
1444
- dtype=op.dtype,
1445
1659
  )
1446
1660
  if isinstance(op, LogSoftmaxOp):
1447
1661
  return LogSoftmaxOp(
1448
1662
  input0=name_map.get(op.input0, op.input0),
1449
1663
  output=name_map.get(op.output, op.output),
1450
- outer=op.outer,
1451
- axis_size=op.axis_size,
1452
- inner=op.inner,
1453
1664
  axis=op.axis,
1454
- shape=op.shape,
1455
- dtype=op.dtype,
1456
1665
  )
1457
1666
  if isinstance(op, HardmaxOp):
1458
1667
  return HardmaxOp(
1459
1668
  input0=name_map.get(op.input0, op.input0),
1460
1669
  output=name_map.get(op.output, op.output),
1461
- outer=op.outer,
1462
- axis_size=op.axis_size,
1463
- inner=op.inner,
1464
1670
  axis=op.axis,
1465
- shape=op.shape,
1466
- dtype=op.dtype,
1467
1671
  )
1468
1672
  if isinstance(op, NegativeLogLikelihoodLossOp):
1469
1673
  return NegativeLogLikelihoodLossOp(
@@ -1624,9 +1828,6 @@ class CEmitter:
1624
1828
  return IdentityOp(
1625
1829
  input0=name_map.get(op.input0, op.input0),
1626
1830
  output=name_map.get(op.output, op.output),
1627
- shape=op.shape,
1628
- dtype=op.dtype,
1629
- input_dtype=op.input_dtype,
1630
1831
  )
1631
1832
  if isinstance(op, EyeLikeOp):
1632
1833
  return EyeLikeOp(
@@ -1782,45 +1983,32 @@ class CEmitter:
1782
1983
  return ReduceOp(
1783
1984
  input0=name_map.get(op.input0, op.input0),
1784
1985
  output=name_map.get(op.output, op.output),
1785
- input_shape=op.input_shape,
1786
- output_shape=op.output_shape,
1787
1986
  axes=op.axes,
1788
1987
  axes_input=self._map_optional_name(name_map, op.axes_input),
1789
- axes_input_shape=op.axes_input_shape,
1790
- axes_input_dtype=op.axes_input_dtype,
1791
1988
  keepdims=op.keepdims,
1792
1989
  noop_with_empty_axes=op.noop_with_empty_axes,
1793
1990
  reduce_kind=op.reduce_kind,
1794
1991
  reduce_count=op.reduce_count,
1795
- dtype=op.dtype,
1796
1992
  )
1797
1993
  if isinstance(op, ArgReduceOp):
1798
1994
  return ArgReduceOp(
1799
1995
  input0=name_map.get(op.input0, op.input0),
1800
1996
  output=name_map.get(op.output, op.output),
1801
- input_shape=op.input_shape,
1802
- output_shape=op.output_shape,
1803
1997
  axis=op.axis,
1804
1998
  keepdims=op.keepdims,
1805
1999
  select_last_index=op.select_last_index,
1806
2000
  reduce_kind=op.reduce_kind,
1807
- input_dtype=op.input_dtype,
1808
- output_dtype=op.output_dtype,
1809
2001
  )
1810
2002
  if isinstance(op, TopKOp):
1811
2003
  return TopKOp(
1812
2004
  input0=name_map.get(op.input0, op.input0),
2005
+ k_input=name_map.get(op.k_input, op.k_input),
1813
2006
  output_values=name_map.get(op.output_values, op.output_values),
1814
2007
  output_indices=name_map.get(op.output_indices, op.output_indices),
1815
- input_shape=op.input_shape,
1816
- output_shape=op.output_shape,
1817
2008
  axis=op.axis,
1818
2009
  k=op.k,
1819
2010
  largest=op.largest,
1820
2011
  sorted=op.sorted,
1821
- input_dtype=op.input_dtype,
1822
- output_values_dtype=op.output_values_dtype,
1823
- output_indices_dtype=op.output_indices_dtype,
1824
2012
  )
1825
2013
  if isinstance(op, ConstantOfShapeOp):
1826
2014
  return ConstantOfShapeOp(
@@ -1852,6 +2040,11 @@ class CEmitter:
1852
2040
  dtype=op.dtype,
1853
2041
  input_dtype=op.input_dtype,
1854
2042
  )
2043
+ if isinstance(op, OptionalHasElementOp):
2044
+ return OptionalHasElementOp(
2045
+ input0=name_map.get(op.input0, op.input0),
2046
+ output=name_map.get(op.output, op.output),
2047
+ )
1855
2048
  if isinstance(op, NonZeroOp):
1856
2049
  return NonZeroOp(
1857
2050
  input0=name_map.get(op.input0, op.input0),
@@ -1918,6 +2111,25 @@ class CEmitter:
1918
2111
  dtype=op.dtype,
1919
2112
  input_dtype=op.input_dtype,
1920
2113
  )
2114
+ if isinstance(op, HammingWindowOp):
2115
+ return HammingWindowOp(
2116
+ size=name_map.get(op.size, op.size),
2117
+ output=name_map.get(op.output, op.output),
2118
+ output_shape=op.output_shape,
2119
+ periodic=op.periodic,
2120
+ dtype=op.dtype,
2121
+ input_dtype=op.input_dtype,
2122
+ )
2123
+ if isinstance(op, BernoulliOp):
2124
+ return BernoulliOp(
2125
+ input0=name_map.get(op.input0, op.input0),
2126
+ output=name_map.get(op.output, op.output),
2127
+ input_shape=op.input_shape,
2128
+ output_shape=op.output_shape,
2129
+ input_dtype=op.input_dtype,
2130
+ dtype=op.dtype,
2131
+ seed=op.seed,
2132
+ )
1921
2133
  if isinstance(op, OneHotOp):
1922
2134
  return OneHotOp(
1923
2135
  indices=name_map.get(op.indices, op.indices),
@@ -1933,6 +2145,23 @@ class CEmitter:
1933
2145
  indices_dtype=op.indices_dtype,
1934
2146
  depth_dtype=op.depth_dtype,
1935
2147
  )
2148
+ if isinstance(op, TfIdfVectorizerOp):
2149
+ return TfIdfVectorizerOp(
2150
+ input0=name_map.get(op.input0, op.input0),
2151
+ output=name_map.get(op.output, op.output),
2152
+ input_shape=op.input_shape,
2153
+ output_shape=op.output_shape,
2154
+ input_dtype=op.input_dtype,
2155
+ output_dtype=op.output_dtype,
2156
+ min_gram_length=op.min_gram_length,
2157
+ max_gram_length=op.max_gram_length,
2158
+ max_skip_count=op.max_skip_count,
2159
+ mode=op.mode,
2160
+ ngram_counts=op.ngram_counts,
2161
+ ngram_indexes=op.ngram_indexes,
2162
+ pool_int64s=op.pool_int64s,
2163
+ weights=op.weights,
2164
+ )
1936
2165
  if isinstance(op, SplitOp):
1937
2166
  return SplitOp(
1938
2167
  input0=name_map.get(op.input0, op.input0),
@@ -1973,11 +2202,19 @@ class CEmitter:
1973
2202
  input_names=tuple(
1974
2203
  name_map.get(name, name) for name in model.input_names
1975
2204
  ),
2205
+ input_optional_names=tuple(
2206
+ name_map.get(name, name) if name is not None else None
2207
+ for name in model.input_optional_names
2208
+ ),
1976
2209
  input_shapes=model.input_shapes,
1977
2210
  input_dtypes=model.input_dtypes,
1978
2211
  output_names=tuple(
1979
2212
  name_map.get(name, name) for name in model.output_names
1980
2213
  ),
2214
+ output_optional_names=tuple(
2215
+ name_map.get(name, name) if name is not None else None
2216
+ for name in model.output_optional_names
2217
+ ),
1981
2218
  output_shapes=model.output_shapes,
1982
2219
  output_dtypes=model.output_dtypes,
1983
2220
  constants=constants,
@@ -2024,6 +2261,18 @@ class CEmitter:
2024
2261
  for name, values in testbench_inputs.items()
2025
2262
  }
2026
2263
 
2264
+ @staticmethod
2265
+ def _sanitize_testbench_optional_inputs(
2266
+ testbench_optional_inputs: Mapping[str, bool] | None,
2267
+ name_map: Mapping[str, str],
2268
+ ) -> Mapping[str, bool] | None:
2269
+ if not testbench_optional_inputs:
2270
+ return None
2271
+ return {
2272
+ name_map.get(name, name): value
2273
+ for name, value in testbench_optional_inputs.items()
2274
+ }
2275
+
2027
2276
  def _load_templates(self, emit_testbench: bool) -> dict[str, Template]:
2028
2277
  try:
2029
2278
  templates = {
@@ -2038,6 +2287,10 @@ class CEmitter:
2038
2287
  "quantize_linear": self._env.get_template(
2039
2288
  "quantize_linear_op.c.j2"
2040
2289
  ),
2290
+ "dequantize_linear": self._env.get_template(
2291
+ "dequantize_linear_op.c.j2"
2292
+ ),
2293
+ "qlinear_mul": self._env.get_template("qlinear_mul_op.c.j2"),
2041
2294
  "qlinear_matmul": self._env.get_template(
2042
2295
  "qlinear_matmul_op.c.j2"
2043
2296
  ),
@@ -2049,6 +2302,7 @@ class CEmitter:
2049
2302
  "rotary_embedding_op.c.j2"
2050
2303
  ),
2051
2304
  "conv": self._env.get_template("conv_op.c.j2"),
2305
+ "conv_integer": self._env.get_template("conv_integer_op.c.j2"),
2052
2306
  "conv_transpose": self._env.get_template(
2053
2307
  "conv_transpose_op.c.j2"
2054
2308
  ),
@@ -2070,6 +2324,7 @@ class CEmitter:
2070
2324
  ),
2071
2325
  "rms_norm": self._env.get_template("rms_normalization_op.c.j2"),
2072
2326
  "lrn": self._env.get_template("lrn_op.c.j2"),
2327
+ "gru": self._env.get_template("gru_op.c.j2"),
2073
2328
  "lstm": self._env.get_template("lstm_op.c.j2"),
2074
2329
  "adagrad": self._env.get_template("adagrad_op.c.j2"),
2075
2330
  "softmax": self._env.get_template("softmax_op.c.j2"),
@@ -2093,6 +2348,7 @@ class CEmitter:
2093
2348
  "transpose": self._env.get_template("transpose_op.c.j2"),
2094
2349
  "reshape": self._env.get_template("reshape_op.c.j2"),
2095
2350
  "identity": self._env.get_template("identity_op.c.j2"),
2351
+ "bernoulli": self._env.get_template("bernoulli_op.c.j2"),
2096
2352
  "eye_like": self._env.get_template("eye_like_op.c.j2"),
2097
2353
  "trilu": self._env.get_template("trilu_op.c.j2"),
2098
2354
  "tile": self._env.get_template("tile_op.c.j2"),
@@ -2116,6 +2372,9 @@ class CEmitter:
2116
2372
  ),
2117
2373
  "shape": self._env.get_template("shape_op.c.j2"),
2118
2374
  "size": self._env.get_template("size_op.c.j2"),
2375
+ "optional_has_element": self._env.get_template(
2376
+ "optional_has_element_op.c.j2"
2377
+ ),
2119
2378
  "nonzero": self._env.get_template("nonzero_op.c.j2"),
2120
2379
  "nonmax_suppression": self._env.get_template(
2121
2380
  "nonmax_suppression_op.c.j2"
@@ -2123,7 +2382,13 @@ class CEmitter:
2123
2382
  "expand": self._env.get_template("expand_op.c.j2"),
2124
2383
  "cumsum": self._env.get_template("cumsum_op.c.j2"),
2125
2384
  "range": self._env.get_template("range_op.c.j2"),
2385
+ "hamming_window": self._env.get_template(
2386
+ "hamming_window_op.c.j2"
2387
+ ),
2126
2388
  "one_hot": self._env.get_template("one_hot_op.c.j2"),
2389
+ "tfidf_vectorizer": self._env.get_template(
2390
+ "tfidf_vectorizer_op.c.j2"
2391
+ ),
2127
2392
  "split": self._env.get_template("split_op.c.j2"),
2128
2393
  }
2129
2394
  if emit_testbench:
@@ -2138,6 +2403,7 @@ class CEmitter:
2138
2403
  *,
2139
2404
  emit_testbench: bool = False,
2140
2405
  testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
2406
+ testbench_optional_inputs: Mapping[str, bool] | None = None,
2141
2407
  variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
2142
2408
  variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
2143
2409
  ) -> str:
@@ -2147,6 +2413,9 @@ class CEmitter:
2147
2413
  testbench_inputs = self._sanitize_testbench_inputs(
2148
2414
  testbench_inputs, name_map
2149
2415
  )
2416
+ testbench_optional_inputs = self._sanitize_testbench_optional_inputs(
2417
+ testbench_optional_inputs, name_map
2418
+ )
2150
2419
  inline_constants, large_constants = self._partition_constants(
2151
2420
  model.constants
2152
2421
  )
@@ -2184,6 +2453,8 @@ class CEmitter:
2184
2453
  model.name,
2185
2454
  *model.input_names,
2186
2455
  *model.output_names,
2456
+ *(name for name in model.input_optional_names if name is not None),
2457
+ *(name for name in model.output_optional_names if name is not None),
2187
2458
  *(const.name for const in model.constants),
2188
2459
  }
2189
2460
  temp_buffers = self._temp_buffers(model, reserved_names=reserved_names)
@@ -2235,16 +2506,27 @@ class CEmitter:
2235
2506
  *includes,
2236
2507
  "",
2237
2508
  self._emit_index_type_define(),
2509
+ self._emit_unused_define(),
2238
2510
  ]
2239
2511
  if scalar_preamble:
2240
2512
  sections.extend(("", *scalar_preamble))
2241
2513
  sections.append("")
2242
- constants_section = self._emit_constant_definitions(inline_constants)
2514
+ constants_section = self._emit_constant_declarations(inline_constants)
2243
2515
  if constants_section:
2244
2516
  sections.extend((constants_section.rstrip(), ""))
2245
- large_constants_section = self._emit_constant_storage_definitions(
2517
+ storage_declarations = self._emit_constant_storage_declarations(
2246
2518
  large_constants
2247
2519
  )
2520
+ if storage_declarations:
2521
+ sections.extend((storage_declarations.rstrip(), ""))
2522
+ constants_section = self._emit_constant_definitions(
2523
+ inline_constants, storage_prefix="const"
2524
+ )
2525
+ if constants_section:
2526
+ sections.extend((constants_section.rstrip(), ""))
2527
+ large_constants_section = self._emit_constant_storage_definitions(
2528
+ large_constants, storage_prefix=""
2529
+ )
2248
2530
  if large_constants_section:
2249
2531
  sections.extend((large_constants_section.rstrip(), ""))
2250
2532
  if scalar_functions:
@@ -2267,6 +2549,7 @@ class CEmitter:
2267
2549
  model,
2268
2550
  testbench_template,
2269
2551
  testbench_inputs=testbench_inputs,
2552
+ testbench_optional_inputs=testbench_optional_inputs,
2270
2553
  dim_order=dim_order,
2271
2554
  dim_values=dim_values,
2272
2555
  weight_data_filename=self._weight_data_filename(model),
@@ -2285,6 +2568,7 @@ class CEmitter:
2285
2568
  *,
2286
2569
  emit_testbench: bool = False,
2287
2570
  testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
2571
+ testbench_optional_inputs: Mapping[str, bool] | None = None,
2288
2572
  variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
2289
2573
  variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
2290
2574
  ) -> tuple[str, str]:
@@ -2294,6 +2578,9 @@ class CEmitter:
2294
2578
  testbench_inputs = self._sanitize_testbench_inputs(
2295
2579
  testbench_inputs, name_map
2296
2580
  )
2581
+ testbench_optional_inputs = self._sanitize_testbench_optional_inputs(
2582
+ testbench_optional_inputs, name_map
2583
+ )
2297
2584
  inline_constants, large_constants = self._partition_constants(
2298
2585
  model.constants
2299
2586
  )
@@ -2331,6 +2618,8 @@ class CEmitter:
2331
2618
  model.name,
2332
2619
  *model.input_names,
2333
2620
  *model.output_names,
2621
+ *(name for name in model.input_optional_names if name is not None),
2622
+ *(name for name in model.output_optional_names if name is not None),
2334
2623
  *(const.name for const in model.constants),
2335
2624
  }
2336
2625
  temp_buffers = self._temp_buffers(model, reserved_names=reserved_names)
@@ -2389,9 +2678,14 @@ class CEmitter:
2389
2678
  constants_section = self._emit_constant_declarations(inline_constants)
2390
2679
  if constants_section:
2391
2680
  sections.extend((constants_section.rstrip(), ""))
2392
- large_constants_section = self._emit_constant_storage_definitions(
2681
+ storage_declarations = self._emit_constant_storage_declarations(
2393
2682
  large_constants
2394
2683
  )
2684
+ if storage_declarations:
2685
+ sections.extend((storage_declarations.rstrip(), ""))
2686
+ large_constants_section = self._emit_constant_storage_definitions(
2687
+ large_constants, storage_prefix=""
2688
+ )
2395
2689
  if large_constants_section:
2396
2690
  sections.extend((large_constants_section.rstrip(), ""))
2397
2691
  if scalar_functions:
@@ -2414,6 +2708,7 @@ class CEmitter:
2414
2708
  model,
2415
2709
  testbench_template,
2416
2710
  testbench_inputs=testbench_inputs,
2711
+ testbench_optional_inputs=testbench_optional_inputs,
2417
2712
  dim_order=dim_order,
2418
2713
  dim_values=dim_values,
2419
2714
  weight_data_filename=self._weight_data_filename(model),
@@ -2656,7 +2951,7 @@ class CEmitter:
2656
2951
  except ScalarFunctionError:
2657
2952
  return None
2658
2953
 
2659
- def _lstm_activation_function_name(
2954
+ def _rnn_activation_function_name(
2660
2955
  self,
2661
2956
  kind: int,
2662
2957
  alpha: float,
@@ -2667,7 +2962,7 @@ class CEmitter:
2667
2962
  spec = _LSTM_ACTIVATION_SPECS.get(kind)
2668
2963
  if spec is None:
2669
2964
  raise CodegenError(
2670
- f"Unsupported LSTM activation kind for codegen: {kind}"
2965
+ f"Unsupported RNN activation kind for codegen: {kind}"
2671
2966
  )
2672
2967
  function, param_count = spec
2673
2968
  if param_count == 0:
@@ -2681,7 +2976,7 @@ class CEmitter:
2681
2976
  )
2682
2977
  if name is None:
2683
2978
  raise CodegenError(
2684
- f"Failed to resolve scalar function for LSTM activation kind {kind}"
2979
+ f"Failed to resolve scalar function for RNN activation kind {kind}"
2685
2980
  )
2686
2981
  return name
2687
2982
 
@@ -2695,12 +2990,15 @@ class CEmitter:
2695
2990
  | ClipOp
2696
2991
  | CastOp
2697
2992
  | QuantizeLinearOp
2993
+ | DequantizeLinearOp
2994
+ | QLinearMulOp
2698
2995
  | QLinearMatMulOp
2699
2996
  | MatMulOp
2700
2997
  | EinsumOp
2701
2998
  | GemmOp
2702
2999
  | AttentionOp
2703
3000
  | ConvOp
3001
+ | ConvIntegerOp
2704
3002
  | ConvTransposeOp
2705
3003
  | AveragePoolOp
2706
3004
  | LpPoolOp
@@ -2712,6 +3010,7 @@ class CEmitter:
2712
3010
  | MeanVarianceNormalizationOp
2713
3011
  | RMSNormalizationOp
2714
3012
  | LrnOp
3013
+ | GruOp
2715
3014
  | LstmOp
2716
3015
  | AdagradOp
2717
3016
  | SoftmaxOp
@@ -2743,11 +3042,13 @@ class CEmitter:
2743
3042
  | ConstantOfShapeOp
2744
3043
  | ShapeOp
2745
3044
  | SizeOp
3045
+ | OptionalHasElementOp
2746
3046
  | NonZeroOp
2747
3047
  | NonMaxSuppressionOp
2748
3048
  | ExpandOp
2749
3049
  | CumSumOp
2750
3050
  | RangeOp
3051
+ | HammingWindowOp
2751
3052
  | OneHotOp
2752
3053
  | SplitOp
2753
3054
  ],
@@ -2787,6 +3088,8 @@ class CEmitter:
2787
3088
  return model.op_context.dtype(op.data)
2788
3089
  if isinstance(op, ExpandOp):
2789
3090
  return model.op_context.dtype(op.input0)
3091
+ if hasattr(op, "output") and isinstance(op.output, str):
3092
+ return model.op_context.dtype(op.output)
2790
3093
  return op.dtype
2791
3094
 
2792
3095
  model_dtypes.update(
@@ -2798,7 +3101,10 @@ class CEmitter:
2798
3101
  dtype
2799
3102
  for op in resolved_ops
2800
3103
  if isinstance(op, ArgReduceOp)
2801
- for dtype in (op.input_dtype, op.output_dtype)
3104
+ for dtype in (
3105
+ model.op_context.dtype(op.input0),
3106
+ model.op_context.dtype(op.output),
3107
+ )
2802
3108
  }
2803
3109
  model_dtypes.update(arg_reduce_dtypes)
2804
3110
  topk_dtypes = {
@@ -2806,9 +3112,9 @@ class CEmitter:
2806
3112
  for op in resolved_ops
2807
3113
  if isinstance(op, TopKOp)
2808
3114
  for dtype in (
2809
- op.input_dtype,
2810
- op.output_values_dtype,
2811
- op.output_indices_dtype,
3115
+ model.op_context.dtype(op.input0),
3116
+ model.op_context.dtype(op.output_values),
3117
+ model.op_context.dtype(op.output_indices),
2812
3118
  )
2813
3119
  }
2814
3120
  model_dtypes.update(topk_dtypes)
@@ -2867,15 +3173,18 @@ class CEmitter:
2867
3173
  includes.add("#include <stdbool.h>")
2868
3174
  if any(
2869
3175
  isinstance(op, UnaryOp)
2870
- and unary_op_symbol(op.function, dtype=op.dtype) in {"llabs", "abs"}
3176
+ and unary_op_symbol(
3177
+ op.function, dtype=model.op_context.dtype(op.output)
3178
+ )
3179
+ in {"llabs", "abs"}
2871
3180
  for op in resolved_ops
2872
3181
  ):
2873
3182
  includes.add("#include <stdlib.h>")
2874
3183
  if any(isinstance(op, PadOp) for op in resolved_ops):
2875
3184
  includes.add("#include <stddef.h>")
2876
- if CEmitter._needs_math(resolved_ops):
3185
+ if CEmitter._needs_math(resolved_ops, model.op_context):
2877
3186
  includes.add("#include <math.h>")
2878
- if CEmitter._needs_limits(resolved_ops):
3187
+ if CEmitter._needs_limits(resolved_ops, model.op_context):
2879
3188
  includes.add("#include <limits.h>")
2880
3189
  if any(
2881
3190
  isinstance(op, (ConcatOp, ReshapeOp, SplitOp, IdentityOp))
@@ -2905,6 +3214,20 @@ class CEmitter:
2905
3214
  )
2906
3215
  )
2907
3216
 
3217
+ @staticmethod
3218
+ def _emit_unused_define() -> str:
3219
+ return "\n".join(
3220
+ (
3221
+ "#ifndef EMX_UNUSED",
3222
+ "#if defined(__GNUC__) || defined(__clang__)",
3223
+ "#define EMX_UNUSED __attribute__((unused))",
3224
+ "#else",
3225
+ "#define EMX_UNUSED",
3226
+ "#endif",
3227
+ "#endif",
3228
+ )
3229
+ )
3230
+
2908
3231
  @staticmethod
2909
3232
  def _needs_stdint(
2910
3233
  model_dtypes: set[ScalarType],
@@ -2940,12 +3263,15 @@ class CEmitter:
2940
3263
  | ClipOp
2941
3264
  | CastOp
2942
3265
  | QuantizeLinearOp
3266
+ | DequantizeLinearOp
3267
+ | QLinearMulOp
2943
3268
  | QLinearMatMulOp
2944
3269
  | MatMulOp
2945
3270
  | EinsumOp
2946
3271
  | GemmOp
2947
3272
  | AttentionOp
2948
3273
  | ConvOp
3274
+ | ConvIntegerOp
2949
3275
  | ConvTransposeOp
2950
3276
  | AveragePoolOp
2951
3277
  | LpPoolOp
@@ -2957,6 +3283,7 @@ class CEmitter:
2957
3283
  | MeanVarianceNormalizationOp
2958
3284
  | RMSNormalizationOp
2959
3285
  | LrnOp
3286
+ | GruOp
2960
3287
  | LstmOp
2961
3288
  | AdagradOp
2962
3289
  | SoftmaxOp
@@ -2988,14 +3315,17 @@ class CEmitter:
2988
3315
  | ConstantOfShapeOp
2989
3316
  | ShapeOp
2990
3317
  | SizeOp
3318
+ | OptionalHasElementOp
2991
3319
  | NonZeroOp
2992
3320
  | NonMaxSuppressionOp
2993
3321
  | ExpandOp
2994
3322
  | CumSumOp
2995
3323
  | RangeOp
3324
+ | HammingWindowOp
2996
3325
  | OneHotOp
2997
3326
  | SplitOp
2998
3327
  ],
3328
+ op_context: OpContext,
2999
3329
  ) -> bool:
3000
3330
  math_ops = {
3001
3331
  "atanhf",
@@ -3014,13 +3344,18 @@ class CEmitter:
3014
3344
 
3015
3345
  def is_binary_math_op(op: BinaryOp) -> bool:
3016
3346
  op_spec = binary_op_symbol(
3017
- op.function, dtype=op.input_dtype, validate_attrs=False
3347
+ op.function,
3348
+ dtype=op_context.dtype(op.input0),
3349
+ validate_attrs=False,
3018
3350
  )
3019
3351
  return op_spec is not None and op_spec.operator in binary_math_ops
3020
3352
 
3021
3353
  if any(
3022
3354
  isinstance(op, UnaryOp)
3023
- and unary_op_symbol(op.function, dtype=op.dtype) in math_ops
3355
+ and unary_op_symbol(
3356
+ op.function, dtype=op_context.dtype(op.output)
3357
+ )
3358
+ in math_ops
3024
3359
  for op in resolved_ops
3025
3360
  ):
3026
3361
  return True
@@ -3038,7 +3373,7 @@ class CEmitter:
3038
3373
  return True
3039
3374
  if any(
3040
3375
  isinstance(op, ClipOp)
3041
- and op.dtype.is_float
3376
+ and op_context.dtype(op.output).is_float
3042
3377
  and (op.input_min is None or op.input_max is None)
3043
3378
  for op in resolved_ops
3044
3379
  ):
@@ -3061,6 +3396,7 @@ class CEmitter:
3061
3396
  MeanVarianceNormalizationOp,
3062
3397
  RMSNormalizationOp,
3063
3398
  LrnOp,
3399
+ GruOp,
3064
3400
  LstmOp,
3065
3401
  AdagradOp,
3066
3402
  SoftmaxOp,
@@ -3082,7 +3418,7 @@ class CEmitter:
3082
3418
  if any(
3083
3419
  isinstance(op, ReduceOp)
3084
3420
  and op.reduce_kind in {"min", "max"}
3085
- and op.dtype.is_float
3421
+ and op_context.dtype(op.output).is_float
3086
3422
  for op in resolved_ops
3087
3423
  ):
3088
3424
  return True
@@ -3092,10 +3428,20 @@ class CEmitter:
3092
3428
  ):
3093
3429
  return True
3094
3430
  if any(
3095
- isinstance(op, (LpPoolOp, QuantizeLinearOp, QLinearMatMulOp))
3431
+ isinstance(
3432
+ op,
3433
+ (
3434
+ LpPoolOp,
3435
+ QuantizeLinearOp,
3436
+ QLinearMulOp,
3437
+ QLinearMatMulOp,
3438
+ ),
3439
+ )
3096
3440
  for op in resolved_ops
3097
3441
  ):
3098
3442
  return True
3443
+ if any(isinstance(op, HammingWindowOp) for op in resolved_ops):
3444
+ return True
3099
3445
  return False
3100
3446
 
3101
3447
  @staticmethod
@@ -3106,12 +3452,15 @@ class CEmitter:
3106
3452
  | ClipOp
3107
3453
  | CastOp
3108
3454
  | QuantizeLinearOp
3455
+ | DequantizeLinearOp
3456
+ | QLinearMulOp
3109
3457
  | QLinearMatMulOp
3110
3458
  | MatMulOp
3111
3459
  | EinsumOp
3112
3460
  | GemmOp
3113
3461
  | AttentionOp
3114
3462
  | ConvOp
3463
+ | ConvIntegerOp
3115
3464
  | ConvTransposeOp
3116
3465
  | AveragePoolOp
3117
3466
  | LpPoolOp
@@ -3123,6 +3472,7 @@ class CEmitter:
3123
3472
  | MeanVarianceNormalizationOp
3124
3473
  | RMSNormalizationOp
3125
3474
  | LrnOp
3475
+ | GruOp
3126
3476
  | LstmOp
3127
3477
  | SoftmaxOp
3128
3478
  | LogSoftmaxOp
@@ -3151,19 +3501,23 @@ class CEmitter:
3151
3501
  | ConstantOfShapeOp
3152
3502
  | ShapeOp
3153
3503
  | SizeOp
3504
+ | OptionalHasElementOp
3154
3505
  | NonZeroOp
3155
3506
  | NonMaxSuppressionOp
3156
3507
  | ExpandOp
3157
3508
  | CumSumOp
3158
3509
  | RangeOp
3510
+ | HammingWindowOp
3159
3511
  | OneHotOp
3160
3512
  | SplitOp
3161
3513
  ],
3514
+ op_context: OpContext,
3162
3515
  ) -> bool:
3163
3516
  if any(
3164
3517
  isinstance(op, ReduceOp)
3165
3518
  and op.reduce_kind in {"min", "max"}
3166
- and op.dtype in {
3519
+ and op_context.dtype(op.output)
3520
+ in {
3167
3521
  ScalarType.I64,
3168
3522
  ScalarType.I32,
3169
3523
  ScalarType.I16,
@@ -3174,7 +3528,7 @@ class CEmitter:
3174
3528
  return True
3175
3529
  if any(
3176
3530
  isinstance(op, ClipOp)
3177
- and op.dtype.is_integer
3531
+ and op_context.dtype(op.output).is_integer
3178
3532
  and (op.input_min is None or op.input_max is None)
3179
3533
  for op in resolved_ops
3180
3534
  ):
@@ -3187,7 +3541,7 @@ class CEmitter:
3187
3541
  ):
3188
3542
  return True
3189
3543
  if any(
3190
- isinstance(op, (QuantizeLinearOp, QLinearMatMulOp))
3544
+ isinstance(op, (QuantizeLinearOp, QLinearMulOp, QLinearMatMulOp))
3191
3545
  and op.dtype.is_integer
3192
3546
  for op in resolved_ops
3193
3547
  ):
@@ -3206,12 +3560,15 @@ class CEmitter:
3206
3560
  | ClipOp
3207
3561
  | CastOp
3208
3562
  | QuantizeLinearOp
3563
+ | DequantizeLinearOp
3564
+ | QLinearMulOp
3209
3565
  | QLinearMatMulOp
3210
3566
  | MatMulOp
3211
3567
  | EinsumOp
3212
3568
  | GemmOp
3213
3569
  | AttentionOp
3214
3570
  | ConvOp
3571
+ | ConvIntegerOp
3215
3572
  | ConvTransposeOp
3216
3573
  | AveragePoolOp
3217
3574
  | LpPoolOp
@@ -3223,6 +3580,7 @@ class CEmitter:
3223
3580
  | MeanVarianceNormalizationOp
3224
3581
  | RMSNormalizationOp
3225
3582
  | LrnOp
3583
+ | GruOp
3226
3584
  | LstmOp
3227
3585
  | SoftmaxOp
3228
3586
  | LogSoftmaxOp
@@ -3256,6 +3614,7 @@ class CEmitter:
3256
3614
  | ExpandOp
3257
3615
  | CumSumOp
3258
3616
  | RangeOp
3617
+ | HammingWindowOp
3259
3618
  | OneHotOp
3260
3619
  | SplitOp
3261
3620
  ],
@@ -3266,6 +3625,7 @@ class CEmitter:
3266
3625
  output_dim_names: Mapping[int, Mapping[int, str]],
3267
3626
  ) -> str:
3268
3627
  params = []
3628
+ optional_flags = self._optional_input_flag_map(model)
3269
3629
  if dim_order:
3270
3630
  params.extend(self._format_dim_args(dim_order))
3271
3631
  for index, (name, shape, dtype) in enumerate(
@@ -3273,14 +3633,17 @@ class CEmitter:
3273
3633
  ):
3274
3634
  params.append(
3275
3635
  f"const {dtype.c_type} {name}"
3276
- f"{self._param_array_suffix(shape, input_dim_names.get(index))}"
3636
+ f"{self._param_array_suffix(shape, input_dim_names.get(index), use_restrict=True)}"
3277
3637
  )
3638
+ optional_flag = optional_flags.get(name)
3639
+ if optional_flag is not None:
3640
+ params.append(f"_Bool {optional_flag}")
3278
3641
  for index, (name, shape, dtype) in enumerate(
3279
3642
  zip(model.output_names, model.output_shapes, model.output_dtypes)
3280
3643
  ):
3281
3644
  params.append(
3282
3645
  f"{dtype.c_type} {name}"
3283
- f"{self._param_array_suffix(shape, output_dim_names.get(index))}"
3646
+ f"{self._param_array_suffix(shape, output_dim_names.get(index), use_restrict=True)}"
3284
3647
  )
3285
3648
  signature = ", ".join(params)
3286
3649
  lines = [f"void {model.name}({signature}) {{"]
@@ -3297,7 +3660,7 @@ class CEmitter:
3297
3660
  )
3298
3661
  for index, op in enumerate(resolved_ops):
3299
3662
  op_name = self._op_function_name(model, index)
3300
- call = self._build_op_call(op, dim_order)
3663
+ call = self._build_op_call(op, dim_order, optional_flags)
3301
3664
  lines.append(f" {op_name}({call});")
3302
3665
  lines.append("}")
3303
3666
  return "\n".join(lines)
@@ -3309,14 +3672,16 @@ class CEmitter:
3309
3672
  element_count *= dim
3310
3673
  return element_count * temp.dtype.np_dtype.itemsize
3311
3674
 
3312
- @staticmethod
3313
3675
  def _build_op_call(
3676
+ self,
3314
3677
  op: BinaryOp
3315
3678
  | WhereOp
3316
3679
  | UnaryOp
3317
3680
  | ClipOp
3318
3681
  | CastOp
3319
3682
  | QuantizeLinearOp
3683
+ | DequantizeLinearOp
3684
+ | QLinearMulOp
3320
3685
  | QLinearMatMulOp
3321
3686
  | MatMulOp
3322
3687
  | EinsumOp
@@ -3324,6 +3689,7 @@ class CEmitter:
3324
3689
  | AttentionOp
3325
3690
  | RotaryEmbeddingOp
3326
3691
  | ConvOp
3692
+ | ConvIntegerOp
3327
3693
  | ConvTransposeOp
3328
3694
  | AveragePoolOp
3329
3695
  | LpPoolOp
@@ -3335,6 +3701,7 @@ class CEmitter:
3335
3701
  | MeanVarianceNormalizationOp
3336
3702
  | RMSNormalizationOp
3337
3703
  | LrnOp
3704
+ | GruOp
3338
3705
  | LstmOp
3339
3706
  | AdagradOp
3340
3707
  | SoftmaxOp
@@ -3367,16 +3734,21 @@ class CEmitter:
3367
3734
  | ConstantOfShapeOp
3368
3735
  | ShapeOp
3369
3736
  | SizeOp
3737
+ | OptionalHasElementOp
3370
3738
  | NonZeroOp
3371
3739
  | NonMaxSuppressionOp
3372
3740
  | ExpandOp
3373
3741
  | CumSumOp
3374
3742
  | RangeOp
3743
+ | HammingWindowOp
3375
3744
  | OneHotOp
3376
- | SplitOp,
3745
+ | SplitOp
3746
+ | OptionalHasElementOp,
3377
3747
  dim_order: Sequence[str],
3748
+ optional_flags: Mapping[str, str] | None = None,
3378
3749
  ) -> str:
3379
3750
  args: list[str] = []
3751
+ optional_flags = optional_flags or {}
3380
3752
  if dim_order:
3381
3753
  args.extend(dim_order)
3382
3754
  if isinstance(op, BinaryOp):
@@ -3388,6 +3760,21 @@ class CEmitter:
  if isinstance(op, WhereOp):
  args.extend([op.condition, op.input_x, op.input_y, op.output])
  return ", ".join(args)
+ if isinstance(op, QLinearMulOp):
+ args.extend(
+ [
+ op.input0,
+ op.input0_scale,
+ op.input0_zero_point,
+ op.input1,
+ op.input1_scale,
+ op.input1_zero_point,
+ op.output_scale,
+ op.output_zero_point,
+ op.output,
+ ]
+ )
+ return ", ".join(args)
3391
3778
  if isinstance(op, QLinearMatMulOp):
3392
3779
  args.extend(
3393
3780
  [
@@ -3434,6 +3821,13 @@ class CEmitter:
3434
3821
  call_parts.append(op.output)
3435
3822
  args.extend(call_parts)
3436
3823
  return ", ".join(args)
3824
+ if isinstance(op, DequantizeLinearOp):
3825
+ call_parts = [op.input0, op.scale]
3826
+ if op.zero_point is not None:
3827
+ call_parts.append(op.zero_point)
3828
+ call_parts.append(op.output)
3829
+ args.extend(call_parts)
3830
+ return ", ".join(args)
3437
3831
  if isinstance(op, AttentionOp):
3438
3832
  call_parts = [op.input_q, op.input_k, op.input_v]
3439
3833
  if op.input_attn_mask is not None:
@@ -3453,12 +3847,27 @@ class CEmitter:
3453
3847
  call_parts.append(op.output_qk_matmul)
3454
3848
  args.extend(call_parts)
3455
3849
  return ", ".join(args)
3850
+ if isinstance(op, RotaryEmbeddingOp):
3851
+ call_parts = [op.input0, op.cos_cache, op.sin_cache]
3852
+ if op.position_ids is not None:
3853
+ call_parts.append(op.position_ids)
3854
+ call_parts.append(op.output)
3855
+ args.extend(call_parts)
3856
+ return ", ".join(args)
3456
3857
  if isinstance(op, ConvOp):
3457
3858
  if op.bias is None:
3458
3859
  args.extend([op.input0, op.weights, op.output])
3459
3860
  return ", ".join(args)
3460
3861
  args.extend([op.input0, op.weights, op.bias, op.output])
3461
3862
  return ", ".join(args)
3863
+ if isinstance(op, ConvIntegerOp):
3864
+ args.extend([op.input0, op.weights])
3865
+ if op.x_zero_point is not None:
3866
+ args.append(op.x_zero_point)
3867
+ if op.w_zero_point is not None:
3868
+ args.append(op.w_zero_point)
3869
+ args.append(op.output)
3870
+ return ", ".join(args)
3462
3871
  if isinstance(op, ConvTransposeOp):
3463
3872
  if op.bias is None:
3464
3873
  args.extend([op.input0, op.weights, op.output])
@@ -3502,6 +3911,20 @@ class CEmitter:
3502
3911
  if isinstance(op, RMSNormalizationOp):
3503
3912
  args.extend([op.input0, op.scale, op.output])
3504
3913
  return ", ".join(args)
3914
+ if isinstance(op, GruOp):
3915
+ call_parts = [op.input_x, op.input_w, op.input_r]
3916
+ if op.input_b is not None:
3917
+ call_parts.append(op.input_b)
3918
+ if op.input_sequence_lens is not None:
3919
+ call_parts.append(op.input_sequence_lens)
3920
+ if op.input_initial_h is not None:
3921
+ call_parts.append(op.input_initial_h)
3922
+ if op.output_y is not None:
3923
+ call_parts.append(op.output_y)
3924
+ if op.output_y_h is not None:
3925
+ call_parts.append(op.output_y_h)
3926
+ args.extend(call_parts)
3927
+ return ", ".join(args)
3505
3928
  if isinstance(op, LstmOp):
3506
3929
  call_parts = [op.input_x, op.input_w, op.input_r]
3507
3930
  if op.input_b is not None:
@@ -3523,17 +3946,18 @@ class CEmitter:
  args.extend(call_parts)
  return ", ".join(args)
  if isinstance(op, AdagradOp):
- args.extend(
- [
- op.rate,
- op.timestep,
- *op.inputs,
- *op.gradients,
- *op.accumulators,
- *op.outputs,
- *op.accumulator_outputs,
- ]
- )
+ args.append(op.rate)
+ args.append(op.timestep)
+ for index in range(len(op.inputs)):
+ args.extend(
+ [
+ op.inputs[index],
+ op.gradients[index],
+ op.accumulators[index],
+ op.outputs[index],
+ op.accumulator_outputs[index],
+ ]
+ )
  return ", ".join(args)
3538
3962
  if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
3539
3963
  args.extend([op.input0, op.output])
@@ -3590,6 +4014,14 @@ class CEmitter:
3590
4014
  if isinstance(op, SizeOp):
3591
4015
  args.extend([op.input0, op.output])
3592
4016
  return ", ".join(args)
4017
+ if isinstance(op, OptionalHasElementOp):
4018
+ input_flag = optional_flags.get(op.input0)
4019
+ if input_flag is None:
4020
+ raise CodegenError(
4021
+ "OptionalHasElement expects an optional input flag."
4022
+ )
4023
+ args.extend([op.input0, input_flag, op.output])
4024
+ return ", ".join(args)
3593
4025
  if isinstance(op, NonZeroOp):
3594
4026
  args.extend([op.input0, op.output])
3595
4027
  return ", ".join(args)
@@ -3622,6 +4054,9 @@ class CEmitter:
3622
4054
  if isinstance(op, RangeOp):
3623
4055
  args.extend([op.start, op.limit, op.delta, op.output])
3624
4056
  return ", ".join(args)
4057
+ if isinstance(op, HammingWindowOp):
4058
+ args.extend([op.size, op.output])
4059
+ return ", ".join(args)
3625
4060
  if isinstance(op, OneHotOp):
3626
4061
  args.extend([op.indices, op.depth, op.values, op.output])
3627
4062
  return ", ".join(args)
@@ -3732,12 +4167,15 @@ class CEmitter:
3732
4167
  @staticmethod
3733
4168
  def _resolve_op(
3734
4169
  op: BinaryOp
4170
+ | PowOp
3735
4171
  | MultiInputBinaryOp
3736
4172
  | WhereOp
3737
4173
  | UnaryOp
3738
4174
  | ClipOp
3739
4175
  | CastOp
3740
4176
  | QuantizeLinearOp
4177
+ | DequantizeLinearOp
4178
+ | QLinearMulOp
3741
4179
  | QLinearMatMulOp
3742
4180
  | MatMulOp
3743
4181
  | EinsumOp
@@ -3745,6 +4183,7 @@ class CEmitter:
3745
4183
  | AttentionOp
3746
4184
  | RotaryEmbeddingOp
3747
4185
  | ConvOp
4186
+ | ConvIntegerOp
3748
4187
  | ConvTransposeOp
3749
4188
  | AveragePoolOp
3750
4189
  | LpPoolOp
@@ -3756,6 +4195,7 @@ class CEmitter:
3756
4195
  | MeanVarianceNormalizationOp
3757
4196
  | RMSNormalizationOp
3758
4197
  | LrnOp
4198
+ | GruOp
3759
4199
  | LstmOp
3760
4200
  | AdagradOp
3761
4201
  | SoftmaxOp
@@ -3792,23 +4232,29 @@ class CEmitter:
3792
4232
  | ExpandOp
3793
4233
  | CumSumOp
3794
4234
  | RangeOp
4235
+ | HammingWindowOp
3795
4236
  | OneHotOp
4237
+ | TfIdfVectorizerOp
3796
4238
  | SplitOp,
3797
4239
  temp_map: dict[str, str],
3798
4240
  ) -> (
3799
4241
  BinaryOp
4242
+ | PowOp
3800
4243
  | MultiInputBinaryOp
3801
4244
  | WhereOp
3802
4245
  | UnaryOp
3803
4246
  | ClipOp
3804
4247
  | CastOp
3805
4248
  | QuantizeLinearOp
4249
+ | DequantizeLinearOp
4250
+ | QLinearMulOp
3806
4251
  | QLinearMatMulOp
3807
4252
  | MatMulOp
3808
4253
  | EinsumOp
3809
4254
  | GemmOp
3810
4255
  | AttentionOp
3811
4256
  | ConvOp
4257
+ | ConvIntegerOp
3812
4258
  | ConvTransposeOp
3813
4259
  | AveragePoolOp
3814
4260
  | LpPoolOp
@@ -3820,6 +4266,7 @@ class CEmitter:
3820
4266
  | MeanVarianceNormalizationOp
3821
4267
  | RMSNormalizationOp
3822
4268
  | LrnOp
4269
+ | GruOp
3823
4270
  | LstmOp
3824
4271
  | AdagradOp
3825
4272
  | SoftmaxOp
@@ -3856,9 +4303,19 @@ class CEmitter:
3856
4303
  | ExpandOp
3857
4304
  | CumSumOp
3858
4305
  | RangeOp
4306
+ | HammingWindowOp
3859
4307
  | OneHotOp
3860
4308
  | SplitOp
4309
+ | TfIdfVectorizerOp
3861
4310
  ):
4311
+ if isinstance(op, PowOp):
4312
+ return PowOp(
4313
+ input0=temp_map.get(op.input0, op.input0),
4314
+ input1=temp_map.get(op.input1, op.input1),
4315
+ output=temp_map.get(op.output, op.output),
4316
+ function=op.function,
4317
+ operator_kind=op.operator_kind,
4318
+ )
3862
4319
  if isinstance(op, BinaryOp):
3863
4320
  return BinaryOp(
3864
4321
  input0=temp_map.get(op.input0, op.input0),
@@ -3866,11 +4323,6 @@ class CEmitter:
3866
4323
  output=temp_map.get(op.output, op.output),
3867
4324
  function=op.function,
3868
4325
  operator_kind=op.operator_kind,
3869
- input0_shape=op.input0_shape,
3870
- input1_shape=op.input1_shape,
3871
- shape=op.shape,
3872
- dtype=op.dtype,
3873
- input_dtype=op.input_dtype,
3874
4326
  )
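Every branch of _resolve_op follows the renaming pattern visible above: tensor names are passed through temp_map, so values that (judging from context) were assigned to shared temporary buffers are referenced by the buffer name, while everything else keeps its original name. In isolation the pattern is just:

temp_map = {"relu_out": "tmp0"}

def resolve(name):
    # Rename only when the value has a temporary-buffer alias.
    return temp_map.get(name, name)

print(resolve("relu_out"), resolve("input0"))  # tmp0 input0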
3875
4327
  if isinstance(op, MultiInputBinaryOp):
3876
4328
  return MultiInputBinaryOp(
@@ -3888,20 +4340,12 @@ class CEmitter:
3888
4340
  input_x=temp_map.get(op.input_x, op.input_x),
3889
4341
  input_y=temp_map.get(op.input_y, op.input_y),
3890
4342
  output=temp_map.get(op.output, op.output),
3891
- condition_shape=op.condition_shape,
3892
- x_shape=op.x_shape,
3893
- y_shape=op.y_shape,
3894
- output_shape=op.output_shape,
3895
- dtype=op.dtype,
3896
4343
  )
3897
4344
  if isinstance(op, UnaryOp):
3898
4345
  return UnaryOp(
3899
4346
  input0=temp_map.get(op.input0, op.input0),
3900
4347
  output=temp_map.get(op.output, op.output),
3901
4348
  function=op.function,
3902
- shape=op.shape,
3903
- dtype=op.dtype,
3904
- input_dtype=op.input_dtype,
3905
4349
  params=op.params,
3906
4350
  )
3907
4351
  if isinstance(op, ClipOp):
@@ -3914,29 +4358,14 @@ class CEmitter:
3914
4358
  if op.input_max is not None
3915
4359
  else None,
3916
4360
  output=temp_map.get(op.output, op.output),
3917
- input_shape=op.input_shape,
3918
- min_shape=op.min_shape,
3919
- max_shape=op.max_shape,
3920
- output_shape=op.output_shape,
3921
- dtype=op.dtype,
4361
+ min_value=op.min_value,
4362
+ max_value=op.max_value,
3922
4363
  )
3923
4364
  if isinstance(op, MatMulOp):
3924
4365
  return MatMulOp(
3925
4366
  input0=temp_map.get(op.input0, op.input0),
3926
4367
  input1=temp_map.get(op.input1, op.input1),
3927
4368
  output=temp_map.get(op.output, op.output),
3928
- input0_shape=op.input0_shape,
3929
- input1_shape=op.input1_shape,
3930
- output_shape=op.output_shape,
3931
- batch_shape=op.batch_shape,
3932
- input0_batch_shape=op.input0_batch_shape,
3933
- input1_batch_shape=op.input1_batch_shape,
3934
- m=op.m,
3935
- n=op.n,
3936
- k=op.k,
3937
- left_vector=op.left_vector,
3938
- right_vector=op.right_vector,
3939
- dtype=op.dtype,
3940
4369
  )
3941
4370
  if isinstance(op, EinsumOp):
3942
4371
  return EinsumOp(
@@ -3972,6 +4401,56 @@ class CEmitter:
3972
4401
  input_dtype=op.input_dtype,
3973
4402
  scale_dtype=op.scale_dtype,
3974
4403
  )
4404
+ if isinstance(op, DequantizeLinearOp):
4405
+ return DequantizeLinearOp(
4406
+ input0=temp_map.get(op.input0, op.input0),
4407
+ scale=temp_map.get(op.scale, op.scale),
4408
+ zero_point=(
4409
+ temp_map.get(op.zero_point, op.zero_point)
4410
+ if op.zero_point is not None
4411
+ else None
4412
+ ),
4413
+ output=temp_map.get(op.output, op.output),
4414
+ input_shape=op.input_shape,
4415
+ axis=op.axis,
4416
+ block_size=op.block_size,
4417
+ dtype=op.dtype,
4418
+ input_dtype=op.input_dtype,
4419
+ scale_dtype=op.scale_dtype,
4420
+ )
4421
+ if isinstance(op, QLinearMulOp):
4422
+ return QLinearMulOp(
4423
+ input0=temp_map.get(op.input0, op.input0),
4424
+ input0_scale=temp_map.get(op.input0_scale, op.input0_scale),
4425
+ input0_zero_point=temp_map.get(
4426
+ op.input0_zero_point, op.input0_zero_point
4427
+ ),
4428
+ input1=temp_map.get(op.input1, op.input1),
4429
+ input1_scale=temp_map.get(op.input1_scale, op.input1_scale),
4430
+ input1_zero_point=temp_map.get(
4431
+ op.input1_zero_point, op.input1_zero_point
4432
+ ),
4433
+ output_scale=temp_map.get(op.output_scale, op.output_scale),
4434
+ output_zero_point=temp_map.get(
4435
+ op.output_zero_point, op.output_zero_point
4436
+ ),
4437
+ output=temp_map.get(op.output, op.output),
4438
+ input0_shape=op.input0_shape,
4439
+ input1_shape=op.input1_shape,
4440
+ output_shape=op.output_shape,
4441
+ input0_dtype=op.input0_dtype,
4442
+ input1_dtype=op.input1_dtype,
4443
+ dtype=op.dtype,
4444
+ input0_scale_dtype=op.input0_scale_dtype,
4445
+ input1_scale_dtype=op.input1_scale_dtype,
4446
+ output_scale_dtype=op.output_scale_dtype,
4447
+ input0_scale_shape=op.input0_scale_shape,
4448
+ input1_scale_shape=op.input1_scale_shape,
4449
+ output_scale_shape=op.output_scale_shape,
4450
+ input0_zero_shape=op.input0_zero_shape,
4451
+ input1_zero_shape=op.input1_zero_shape,
4452
+ output_zero_shape=op.output_zero_shape,
4453
+ )
3975
4454
  if isinstance(op, QLinearMatMulOp):
3976
4455
  return QLinearMatMulOp(
3977
4456
  input0=temp_map.get(op.input0, op.input0),
@@ -4023,15 +4502,10 @@ class CEmitter:
4023
4502
  else None
4024
4503
  ),
4025
4504
  output=temp_map.get(op.output, op.output),
4026
- m=op.m,
4027
- n=op.n,
4028
- k=op.k,
4029
4505
  trans_a=op.trans_a,
4030
4506
  trans_b=op.trans_b,
4031
4507
  alpha=op.alpha,
4032
4508
  beta=op.beta,
4033
- c_shape=op.c_shape,
4034
- dtype=op.dtype,
4035
4509
  )
4036
4510
  if isinstance(op, AttentionOp):
4037
4511
  return AttentionOp(
@@ -4133,6 +4607,51 @@ class CEmitter:
4133
4607
  input_rank=op.input_rank,
4134
4608
  interleaved=op.interleaved,
4135
4609
  )
4610
+ if isinstance(op, GruOp):
4611
+ return GruOp(
4612
+ input_x=temp_map.get(op.input_x, op.input_x),
4613
+ input_w=temp_map.get(op.input_w, op.input_w),
4614
+ input_r=temp_map.get(op.input_r, op.input_r),
4615
+ input_b=(
4616
+ temp_map.get(op.input_b, op.input_b)
4617
+ if op.input_b is not None
4618
+ else None
4619
+ ),
4620
+ input_sequence_lens=(
4621
+ temp_map.get(op.input_sequence_lens, op.input_sequence_lens)
4622
+ if op.input_sequence_lens is not None
4623
+ else None
4624
+ ),
4625
+ input_initial_h=(
4626
+ temp_map.get(op.input_initial_h, op.input_initial_h)
4627
+ if op.input_initial_h is not None
4628
+ else None
4629
+ ),
4630
+ output_y=(
4631
+ temp_map.get(op.output_y, op.output_y)
4632
+ if op.output_y is not None
4633
+ else None
4634
+ ),
4635
+ output_y_h=(
4636
+ temp_map.get(op.output_y_h, op.output_y_h)
4637
+ if op.output_y_h is not None
4638
+ else None
4639
+ ),
4640
+ seq_length=op.seq_length,
4641
+ batch_size=op.batch_size,
4642
+ input_size=op.input_size,
4643
+ hidden_size=op.hidden_size,
4644
+ num_directions=op.num_directions,
4645
+ direction=op.direction,
4646
+ layout=op.layout,
4647
+ linear_before_reset=op.linear_before_reset,
4648
+ clip=op.clip,
4649
+ activation_kinds=op.activation_kinds,
4650
+ activation_alphas=op.activation_alphas,
4651
+ activation_betas=op.activation_betas,
4652
+ dtype=op.dtype,
4653
+ sequence_lens_dtype=op.sequence_lens_dtype,
4654
+ )
4136
4655
  if isinstance(op, LstmOp):
4137
4656
  return LstmOp(
4138
4657
  input_x=temp_map.get(op.input_x, op.input_x),
@@ -4239,6 +4758,35 @@ class CEmitter:
4239
4758
  group=op.group,
4240
4759
  dtype=op.dtype,
4241
4760
  )
4761
+ if isinstance(op, ConvIntegerOp):
4762
+ return ConvIntegerOp(
4763
+ input0=temp_map.get(op.input0, op.input0),
4764
+ weights=temp_map.get(op.weights, op.weights),
4765
+ x_zero_point=temp_map.get(op.x_zero_point, op.x_zero_point)
4766
+ if op.x_zero_point
4767
+ else None,
4768
+ w_zero_point=temp_map.get(op.w_zero_point, op.w_zero_point)
4769
+ if op.w_zero_point
4770
+ else None,
4771
+ output=temp_map.get(op.output, op.output),
4772
+ batch=op.batch,
4773
+ in_channels=op.in_channels,
4774
+ out_channels=op.out_channels,
4775
+ spatial_rank=op.spatial_rank,
4776
+ in_spatial=op.in_spatial,
4777
+ out_spatial=op.out_spatial,
4778
+ kernel_shape=op.kernel_shape,
4779
+ strides=op.strides,
4780
+ pads=op.pads,
4781
+ dilations=op.dilations,
4782
+ group=op.group,
4783
+ input_dtype=op.input_dtype,
4784
+ weight_dtype=op.weight_dtype,
4785
+ dtype=op.dtype,
4786
+ x_zero_point_shape=op.x_zero_point_shape,
4787
+ w_zero_point_shape=op.w_zero_point_shape,
4788
+ w_zero_point_per_channel=op.w_zero_point_per_channel,
4789
+ )
4242
4790
  if isinstance(op, ConvTransposeOp):
4243
4791
  return ConvTransposeOp(
4244
4792
  input0=temp_map.get(op.input0, op.input0),
@@ -4265,16 +4813,26 @@ class CEmitter:
4265
4813
  output=temp_map.get(op.output, op.output),
4266
4814
  batch=op.batch,
4267
4815
  channels=op.channels,
4816
+ spatial_rank=op.spatial_rank,
4817
+ in_d=op.in_d,
4268
4818
  in_h=op.in_h,
4269
4819
  in_w=op.in_w,
4820
+ out_d=op.out_d,
4270
4821
  out_h=op.out_h,
4271
4822
  out_w=op.out_w,
4823
+ kernel_d=op.kernel_d,
4272
4824
  kernel_h=op.kernel_h,
4273
4825
  kernel_w=op.kernel_w,
4826
+ dilation_d=op.dilation_d,
4827
+ dilation_h=op.dilation_h,
4828
+ dilation_w=op.dilation_w,
4829
+ stride_d=op.stride_d,
4274
4830
  stride_h=op.stride_h,
4275
4831
  stride_w=op.stride_w,
4832
+ pad_front=op.pad_front,
4276
4833
  pad_top=op.pad_top,
4277
4834
  pad_left=op.pad_left,
4835
+ pad_back=op.pad_back,
4278
4836
  pad_bottom=op.pad_bottom,
4279
4837
  pad_right=op.pad_right,
4280
4838
  count_include_pad=op.count_include_pad,
@@ -4292,6 +4850,8 @@ class CEmitter:
4292
4850
  out_w=op.out_w,
4293
4851
  kernel_h=op.kernel_h,
4294
4852
  kernel_w=op.kernel_w,
4853
+ dilation_h=op.dilation_h,
4854
+ dilation_w=op.dilation_w,
4295
4855
  stride_h=op.stride_h,
4296
4856
  stride_w=op.stride_w,
4297
4857
  pad_top=op.pad_top,
@@ -4420,34 +4980,19 @@ class CEmitter:
4420
4980
  return SoftmaxOp(
4421
4981
  input0=temp_map.get(op.input0, op.input0),
4422
4982
  output=temp_map.get(op.output, op.output),
4423
- outer=op.outer,
4424
- axis_size=op.axis_size,
4425
- inner=op.inner,
4426
4983
  axis=op.axis,
4427
- shape=op.shape,
4428
- dtype=op.dtype,
4429
4984
  )
4430
4985
  if isinstance(op, LogSoftmaxOp):
4431
4986
  return LogSoftmaxOp(
4432
4987
  input0=temp_map.get(op.input0, op.input0),
4433
4988
  output=temp_map.get(op.output, op.output),
4434
- outer=op.outer,
4435
- axis_size=op.axis_size,
4436
- inner=op.inner,
4437
4989
  axis=op.axis,
4438
- shape=op.shape,
4439
- dtype=op.dtype,
4440
4990
  )
4441
4991
  if isinstance(op, HardmaxOp):
4442
4992
  return HardmaxOp(
4443
4993
  input0=temp_map.get(op.input0, op.input0),
4444
4994
  output=temp_map.get(op.output, op.output),
4445
- outer=op.outer,
4446
- axis_size=op.axis_size,
4447
- inner=op.inner,
4448
4995
  axis=op.axis,
4449
- shape=op.shape,
4450
- dtype=op.dtype,
4451
4996
  )
4452
4997
  if isinstance(op, NegativeLogLikelihoodLossOp):
4453
4998
  return NegativeLogLikelihoodLossOp(
@@ -4629,6 +5174,11 @@ class CEmitter:
4629
5174
  dtype=op.dtype,
4630
5175
  input_dtype=op.input_dtype,
4631
5176
  )
5177
+ if isinstance(op, OptionalHasElementOp):
5178
+ return OptionalHasElementOp(
5179
+ input0=temp_map.get(op.input0, op.input0),
5180
+ output=temp_map.get(op.output, op.output),
5181
+ )
4632
5182
  if isinstance(op, NonZeroOp):
4633
5183
  return NonZeroOp(
4634
5184
  input0=temp_map.get(op.input0, op.input0),
@@ -4695,6 +5245,15 @@ class CEmitter:
4695
5245
  dtype=op.dtype,
4696
5246
  input_dtype=op.input_dtype,
4697
5247
  )
5248
+ if isinstance(op, HammingWindowOp):
5249
+ return HammingWindowOp(
5250
+ size=temp_map.get(op.size, op.size),
5251
+ output=temp_map.get(op.output, op.output),
5252
+ output_shape=op.output_shape,
5253
+ periodic=op.periodic,
5254
+ dtype=op.dtype,
5255
+ input_dtype=op.input_dtype,
5256
+ )
4698
5257
  if isinstance(op, OneHotOp):
4699
5258
  return OneHotOp(
4700
5259
  indices=temp_map.get(op.indices, op.indices),
@@ -4710,6 +5269,23 @@ class CEmitter:
4710
5269
  indices_dtype=op.indices_dtype,
4711
5270
  depth_dtype=op.depth_dtype,
4712
5271
  )
5272
+ if isinstance(op, TfIdfVectorizerOp):
5273
+ return TfIdfVectorizerOp(
5274
+ input0=temp_map.get(op.input0, op.input0),
5275
+ output=temp_map.get(op.output, op.output),
5276
+ input_shape=op.input_shape,
5277
+ output_shape=op.output_shape,
5278
+ input_dtype=op.input_dtype,
5279
+ output_dtype=op.output_dtype,
5280
+ min_gram_length=op.min_gram_length,
5281
+ max_gram_length=op.max_gram_length,
5282
+ max_skip_count=op.max_skip_count,
5283
+ mode=op.mode,
5284
+ ngram_counts=op.ngram_counts,
5285
+ ngram_indexes=op.ngram_indexes,
5286
+ pool_int64s=op.pool_int64s,
5287
+ weights=op.weights,
5288
+ )
4713
5289
  if isinstance(op, SplitOp):
4714
5290
  return SplitOp(
4715
5291
  input0=temp_map.get(op.input0, op.input0),
@@ -4746,9 +5322,6 @@ class CEmitter:
4746
5322
  return IdentityOp(
4747
5323
  input0=temp_map.get(op.input0, op.input0),
4748
5324
  output=temp_map.get(op.output, op.output),
4749
- shape=op.shape,
4750
- dtype=op.dtype,
4751
- input_dtype=op.input_dtype,
4752
5325
  )
4753
5326
  if isinstance(op, EyeLikeOp):
4754
5327
  return EyeLikeOp(
@@ -4934,54 +5507,50 @@ class CEmitter:
4934
5507
  return ReduceOp(
4935
5508
  input0=temp_map.get(op.input0, op.input0),
4936
5509
  output=temp_map.get(op.output, op.output),
4937
- input_shape=op.input_shape,
4938
- output_shape=op.output_shape,
4939
5510
  axes=op.axes,
4940
5511
  axes_input=temp_map.get(op.axes_input, op.axes_input)
4941
5512
  if op.axes_input
4942
5513
  else None,
4943
- axes_input_shape=op.axes_input_shape,
4944
- axes_input_dtype=op.axes_input_dtype,
4945
5514
  keepdims=op.keepdims,
4946
5515
  noop_with_empty_axes=op.noop_with_empty_axes,
4947
5516
  reduce_kind=op.reduce_kind,
4948
5517
  reduce_count=op.reduce_count,
4949
- dtype=op.dtype,
4950
5518
  )
4951
5519
  if isinstance(op, ArgReduceOp):
4952
5520
  return ArgReduceOp(
4953
5521
  input0=temp_map.get(op.input0, op.input0),
4954
5522
  output=temp_map.get(op.output, op.output),
4955
- input_shape=op.input_shape,
4956
- output_shape=op.output_shape,
4957
5523
  axis=op.axis,
4958
5524
  keepdims=op.keepdims,
4959
5525
  select_last_index=op.select_last_index,
4960
5526
  reduce_kind=op.reduce_kind,
4961
- input_dtype=op.input_dtype,
4962
- output_dtype=op.output_dtype,
4963
5527
  )
4964
5528
  if isinstance(op, TopKOp):
4965
5529
  return TopKOp(
4966
5530
  input0=temp_map.get(op.input0, op.input0),
5531
+ k_input=temp_map.get(op.k_input, op.k_input),
4967
5532
  output_values=temp_map.get(op.output_values, op.output_values),
4968
5533
  output_indices=temp_map.get(op.output_indices, op.output_indices),
4969
- input_shape=op.input_shape,
4970
- output_shape=op.output_shape,
4971
5534
  axis=op.axis,
4972
5535
  k=op.k,
4973
5536
  largest=op.largest,
4974
5537
  sorted=op.sorted,
5538
+ )
5539
+ if isinstance(op, BernoulliOp):
5540
+ return BernoulliOp(
5541
+ input0=temp_map.get(op.input0, op.input0),
5542
+ output=temp_map.get(op.output, op.output),
5543
+ input_shape=op.input_shape,
5544
+ output_shape=op.output_shape,
4975
5545
  input_dtype=op.input_dtype,
4976
- output_values_dtype=op.output_values_dtype,
4977
- output_indices_dtype=op.output_indices_dtype,
5546
+ dtype=op.dtype,
5547
+ seed=op.seed,
4978
5548
  )
4979
5549
  return UnaryOp(
4980
5550
  input0=temp_map.get(op.input0, op.input0),
4981
5551
  output=temp_map.get(op.output, op.output),
4982
5552
  function=op.function,
4983
- shape=op.shape,
4984
- dtype=op.dtype,
5553
+ params=op.params,
4985
5554
  )
4986
5555
 
4987
5556
  def render_op(self, op: OpBase, ctx: EmitContext) -> str:
@@ -5007,6 +5576,8 @@ class CEmitter:
5007
5576
  clip_template=templates["clip"],
5008
5577
  cast_template=templates["cast"],
5009
5578
  quantize_linear_template=templates["quantize_linear"],
5579
+ dequantize_linear_template=templates["dequantize_linear"],
5580
+ qlinear_mul_template=templates["qlinear_mul"],
5010
5581
  qlinear_matmul_template=templates["qlinear_matmul"],
5011
5582
  matmul_template=templates["matmul"],
5012
5583
  einsum_template=templates["einsum"],
@@ -5014,6 +5585,7 @@ class CEmitter:
5014
5585
  attention_template=templates["attention"],
5015
5586
  rotary_embedding_template=templates["rotary_embedding"],
5016
5587
  conv_template=templates["conv"],
5588
+ conv_integer_template=templates["conv_integer"],
5017
5589
  conv_transpose_template=templates["conv_transpose"],
5018
5590
  avg_pool_template=templates["avg_pool"],
5019
5591
  lp_pool_template=templates["lp_pool"],
@@ -5025,6 +5597,7 @@ class CEmitter:
5025
5597
  mean_variance_norm_template=templates["mean_variance_norm"],
5026
5598
  rms_norm_template=templates["rms_norm"],
5027
5599
  lrn_template=templates["lrn"],
5600
+ gru_template=templates["gru"],
5028
5601
  lstm_template=templates["lstm"],
5029
5602
  adagrad_template=templates["adagrad"],
5030
5603
  softmax_template=templates["softmax"],
@@ -5043,6 +5616,7 @@ class CEmitter:
5043
5616
  transpose_template=templates["transpose"],
5044
5617
  reshape_template=templates["reshape"],
5045
5618
  identity_template=templates["identity"],
5619
+ bernoulli_template=templates["bernoulli"],
5046
5620
  eye_like_template=templates["eye_like"],
5047
5621
  trilu_template=templates["trilu"],
5048
5622
  tile_template=templates["tile"],
@@ -5060,12 +5634,15 @@ class CEmitter:
5060
5634
  constant_of_shape_template=templates["constant_of_shape"],
5061
5635
  shape_template=templates["shape"],
5062
5636
  size_template=templates["size"],
5637
+ optional_has_element_template=templates["optional_has_element"],
5063
5638
  nonzero_template=templates["nonzero"],
5064
5639
  nonmax_suppression_template=templates["nonmax_suppression"],
5065
5640
  expand_template=templates["expand"],
5066
5641
  cumsum_template=templates["cumsum"],
5067
5642
  range_template=templates["range"],
5643
+ hamming_window_template=templates["hamming_window"],
5068
5644
  one_hot_template=templates["one_hot"],
5645
+ tfidf_vectorizer_template=templates["tfidf_vectorizer"],
5069
5646
  split_template=templates["split"],
5070
5647
  scalar_registry=state.scalar_registry,
5071
5648
  dim_args=state.dim_args,
@@ -5091,6 +5668,8 @@ class CEmitter:
5091
5668
  clip_template,
5092
5669
  cast_template,
5093
5670
  quantize_linear_template,
5671
+ dequantize_linear_template,
5672
+ qlinear_mul_template,
5094
5673
  qlinear_matmul_template,
5095
5674
  matmul_template,
5096
5675
  einsum_template,
@@ -5098,6 +5677,7 @@ class CEmitter:
5098
5677
  attention_template,
5099
5678
  rotary_embedding_template,
5100
5679
  conv_template,
5680
+ conv_integer_template,
5101
5681
  conv_transpose_template,
5102
5682
  avg_pool_template,
5103
5683
  lp_pool_template,
@@ -5109,6 +5689,7 @@ class CEmitter:
5109
5689
  mean_variance_norm_template,
5110
5690
  rms_norm_template,
5111
5691
  lrn_template,
5692
+ gru_template,
5112
5693
  lstm_template,
5113
5694
  adagrad_template,
5114
5695
  softmax_template,
@@ -5125,6 +5706,7 @@ class CEmitter:
5125
5706
  transpose_template,
5126
5707
  reshape_template,
5127
5708
  identity_template,
5709
+ bernoulli_template,
5128
5710
  eye_like_template,
5129
5711
  trilu_template,
5130
5712
  tile_template,
@@ -5142,12 +5724,15 @@ class CEmitter:
5142
5724
  constant_of_shape_template,
5143
5725
  shape_template,
5144
5726
  size_template,
5727
+ optional_has_element_template,
5145
5728
  nonzero_template,
5146
5729
  nonmax_suppression_template,
5147
5730
  expand_template,
5148
5731
  cumsum_template,
5149
5732
  range_template,
5733
+ hamming_window_template,
5150
5734
  one_hot_template,
5735
+ tfidf_vectorizer_template,
5151
5736
  split_template,
5152
5737
  scalar_registry: ScalarFunctionRegistry | None = None,
5153
5738
  dim_args: str = "",
@@ -5169,6 +5754,11 @@ class CEmitter:
5169
5754
  input1_shape = self._ctx_shape(op.input1)
5170
5755
  output_shape = self._ctx_shape(op.output)
5171
5756
  input_dtype = self._ctx_dtype(op.input0)
5757
+ input1_dtype = (
5758
+ self._ctx_dtype(op.input1)
5759
+ if isinstance(op, PowOp)
5760
+ else input_dtype
5761
+ )
5172
5762
  output_dtype = self._ctx_dtype(op.output)
5173
5763
  params = self._shared_param_map(
5174
5764
  [
@@ -5207,11 +5797,12 @@ class CEmitter:
5207
5797
  input1_shape, _dim_names_for(op.input1)
5208
5798
  )
5209
5799
  input_c_type = input_dtype.c_type
5800
+ input1_c_type = input1_dtype.c_type
5210
5801
  output_c_type = output_dtype.c_type
5211
5802
  param_decls = self._build_param_decls(
5212
5803
  [
5213
5804
  (params["input0"], input_c_type, input0_suffix, True),
5214
- (params["input1"], input_c_type, input1_suffix, True),
5805
+ (params["input1"], input1_c_type, input1_suffix, True),
5215
5806
  (params["output"], output_c_type, output_suffix, False),
5216
5807
  ]
5217
5808
  )
@@ -5234,12 +5825,20 @@ class CEmitter:
  output_shape,
  loop_vars,
  )
- right_expr = CEmitter._broadcast_index_expr(
- params["input1"],
- input1_shape,
- output_shape,
- loop_vars,
- )
+ prelu_axis = None
+ if op.function == ScalarFunction.PRELU:
+ derived_axis = self._maybe_derived(op, "prelu_slope_axis")
+ if isinstance(derived_axis, int):
+ prelu_axis = derived_axis
+ if prelu_axis is None:
+ right_expr = CEmitter._broadcast_index_expr(
+ params["input1"],
+ input1_shape,
+ output_shape,
+ loop_vars,
+ )
+ else:
+ right_expr = f"{params['input1']}[{loop_vars[prelu_axis]}]"
  operator_expr = None
  operator = op_spec.operator
  operator_kind = op.operator_kind
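When a per-channel slope axis is derived (prelu_slope_axis), the slope is indexed with that one loop variable instead of a full broadcast expression. A NumPy reference for the per-channel PReLU case, with axis 1 assumed as in NCHW layouts:

import numpy as np

def prelu_per_channel(x, slope, axis=1):
    shape = [1] * x.ndim
    shape[axis] = slope.shape[0]
    return np.where(x >= 0, x, x * slope.reshape(shape))

x = np.array([[[-1.0, 2.0], [-3.0, 4.0]]], dtype=np.float32)  # shape (1, 2, 2)
slope = np.array([0.1, 0.5], dtype=np.float32)
print(prelu_per_channel(x, slope))  # [[[-0.1, 2.0], [-1.5, 4.0]]]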
@@ -5263,7 +5862,7 @@ class CEmitter:
5263
5862
  ).rstrip()
5264
5863
  return with_node_comment(rendered)
5265
5864
  if isinstance(op, MultiInputBinaryOp):
5266
- output_shape = self._ctx_shape(op.output)
5865
+ output_shape_raw = self._ctx_shape(op.output)
5267
5866
  input_dtype = self._ctx_dtype(op.inputs[0])
5268
5867
  output_dtype = self._ctx_dtype(op.output)
5269
5868
  params = self._shared_param_map(
@@ -5292,27 +5891,47 @@ class CEmitter:
5292
5891
  f"{op.function.value}"
5293
5892
  )
5294
5893
  output_dim_names = _dim_names_for(op.output)
5295
- shape = CEmitter._shape_dim_exprs(output_shape, output_dim_names)
5296
- loop_vars = CEmitter._loop_vars(output_shape)
5297
- array_suffix = self._param_array_suffix(
5298
- output_shape, output_dim_names
5894
+ shape = CEmitter._shape_dim_exprs(
5895
+ output_shape_raw, output_dim_names
5896
+ )
5897
+ loop_vars = CEmitter._loop_vars(output_shape_raw)
5898
+ output_array_suffix = self._param_array_suffix(
5899
+ output_shape_raw, output_dim_names
5299
5900
  )
5300
5901
  input_c_type = input_dtype.c_type
5301
5902
  output_c_type = output_dtype.c_type
5302
5903
  input_names = [
5303
5904
  params[f"input{idx}"] for idx in range(len(op.inputs))
5304
5905
  ]
5906
+ input_shapes = [self._ctx_shape(name) for name in op.inputs]
5907
+ input_dim_names = [
5908
+ _dim_names_for(name) for name in op.inputs
5909
+ ]
5910
+ input_array_suffixes = [
5911
+ self._param_array_suffix(shape, dim_names)
5912
+ for shape, dim_names in zip(input_shapes, input_dim_names)
5913
+ ]
5305
5914
  param_decls = self._build_param_decls(
5306
5915
  [
5307
- *( (name, input_c_type, array_suffix, True) for name in input_names ),
5308
- (params["output"], output_c_type, array_suffix, False),
5916
+ *(
5917
+ (name, input_c_type, array_suffix, True)
5918
+ for name, array_suffix in zip(
5919
+ input_names, input_array_suffixes
5920
+ )
5921
+ ),
5922
+ (
5923
+ params["output"],
5924
+ output_c_type,
5925
+ output_array_suffix,
5926
+ False,
5927
+ ),
5309
5928
  ]
5310
5929
  )
5311
5930
  common = {
5312
5931
  "model_name": model.name,
5313
5932
  "op_name": op_name,
5314
5933
  "element_count": CEmitter._element_count_expr(shape),
5315
- "array_suffix": array_suffix,
5934
+ "array_suffix": output_array_suffix,
5316
5935
  "shape": shape,
5317
5936
  "loop_vars": loop_vars,
5318
5937
  "input_c_type": input_c_type,
@@ -5322,8 +5941,10 @@ class CEmitter:
5322
5941
  "params": param_decls,
5323
5942
  }
5324
5943
  input_exprs = [
5325
- f"{name}" + "".join(f"[{var}]" for var in loop_vars)
5326
- for name in input_names
5944
+ CEmitter._broadcast_index_expr(
5945
+ name, shape, output_shape_raw, loop_vars
5946
+ )
5947
+ for name, shape in zip(input_names, input_shapes)
5327
5948
  ]
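The inputs of a variadic op are now indexed through the same broadcast rule as binary ops. A small sketch of the rule _broadcast_index_expr appears to implement (drop missing leading axes, index size-1 axes with 0, reuse the output loop variable otherwise); the helper's real signature and edge cases are not shown in this diff:

def broadcast_index_expr(name, in_shape, out_shape, loop_vars):
    offset = len(out_shape) - len(in_shape)
    parts = []
    for axis, dim in enumerate(in_shape):
        parts.append("0" if dim == 1 else loop_vars[axis + offset])
    return name + "".join(f"[{p}]" for p in parts)

print(broadcast_index_expr("a", (1, 4), (2, 3, 4), ("i0", "i1", "i2")))  # a[0][i2]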
5328
5949
  output_expr = f"{params['output']}" + "".join(
5329
5950
  f"[{var}]" for var in loop_vars
@@ -5452,37 +6073,51 @@ class CEmitter:
5452
6073
  ("output", op.output),
5453
6074
  ]
5454
6075
  )
5455
- output_shape = CEmitter._codegen_shape(op.output_shape)
6076
+ output_shape = CEmitter._codegen_shape(self._ctx_shape(op.output))
5456
6077
  output_loop_vars = CEmitter._loop_vars(output_shape)
5457
6078
  output_index_expr = f"{params['output']}" + "".join(
5458
6079
  f"[{var}]" for var in output_loop_vars
5459
6080
  )
5460
- batch_rank = len(op.batch_shape)
6081
+ batch_shape = self._derived(op, "batch_shape")
6082
+ batch_rank = len(batch_shape)
5461
6083
  batch_vars = output_loop_vars[:batch_rank]
5462
- if op.left_vector and op.right_vector:
6084
+ left_vector = bool(self._derived(op, "left_vector"))
6085
+ right_vector = bool(self._derived(op, "right_vector"))
6086
+ if left_vector and right_vector:
5463
6087
  row_var = None
5464
6088
  col_var = None
5465
- elif op.left_vector:
6089
+ elif left_vector:
5466
6090
  row_var = None
5467
6091
  col_var = output_loop_vars[-1]
5468
- elif op.right_vector:
6092
+ elif right_vector:
5469
6093
  row_var = output_loop_vars[-1]
5470
6094
  col_var = None
5471
6095
  else:
5472
6096
  row_var = output_loop_vars[-2]
5473
6097
  col_var = output_loop_vars[-1]
6098
+ input0_shape = self._ctx_shape(op.input0)
6099
+ input1_shape = self._ctx_shape(op.input1)
6100
+ input0_batch_shape = self._derived(op, "input0_batch_shape")
6101
+ input1_batch_shape = self._derived(op, "input1_batch_shape")
5474
6102
  input0_index_expr, input1_index_expr = CEmitter._matmul_index_exprs(
5475
- op,
5476
6103
  batch_vars,
5477
6104
  row_var,
5478
6105
  col_var,
5479
6106
  batch_rank,
5480
6107
  input0=params["input0"],
5481
6108
  input1=params["input1"],
5482
- )
5483
- input0_suffix = self._param_array_suffix(op.input0_shape)
5484
- input1_suffix = self._param_array_suffix(op.input1_shape)
5485
- output_suffix = self._param_array_suffix(op.output_shape)
6109
+ left_vector=left_vector,
6110
+ right_vector=right_vector,
6111
+ input0_shape=input0_shape,
6112
+ input1_shape=input1_shape,
6113
+ input0_batch_shape=input0_batch_shape,
6114
+ input1_batch_shape=input1_batch_shape,
6115
+ )
6116
+ input0_suffix = self._param_array_suffix(input0_shape)
6117
+ input1_suffix = self._param_array_suffix(input1_shape)
6118
+ output_suffix = self._param_array_suffix(self._ctx_shape(op.output))
6119
+ acc_dtype = self._accumulation_dtype(self._ctx_dtype(op.output))
6120
+ acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
5486
6121
  param_decls = self._build_param_decls(
5487
6122
  [
5488
6123
  (params["input0"], c_type, input0_suffix, True),
@@ -5490,6 +6125,9 @@ class CEmitter:
5490
6125
  (params["output"], c_type, output_suffix, False),
5491
6126
  ]
5492
6127
  )
6128
+ m = int(self._derived(op, "m"))
6129
+ n = int(self._derived(op, "n"))
6130
+ k = int(self._derived(op, "k"))
5493
6131
  rendered = matmul_template.render(
5494
6132
  model_name=model.name,
5495
6133
  op_name=op_name,
@@ -5498,8 +6136,8 @@ class CEmitter:
5498
6136
  output=params["output"],
5499
6137
  params=param_decls,
5500
6138
  c_type=c_type,
5501
- acc_type=c_type,
5502
- zero_literal=zero_literal,
6139
+ acc_type=acc_dtype.c_type,
6140
+ zero_literal=acc_zero_literal,
5503
6141
  input0_suffix=input0_suffix,
5504
6142
  input1_suffix=input1_suffix,
5505
6143
  output_suffix=output_suffix,
@@ -5508,9 +6146,9 @@ class CEmitter:
5508
6146
  output_index_expr=output_index_expr,
5509
6147
  input0_index_expr=input0_index_expr,
5510
6148
  input1_index_expr=input1_index_expr,
5511
- m=op.m,
5512
- n=op.n,
5513
- k=op.k,
6149
+ m=m,
6150
+ n=n,
6151
+ k=k,
5514
6152
  ).rstrip()
5515
6153
  return with_node_comment(rendered)
5516
6154
  if isinstance(op, EinsumOp):
@@ -5561,6 +6199,8 @@ class CEmitter:
5561
6199
  ),
5562
6200
  ]
5563
6201
  )
6202
+ acc_dtype = self._accumulation_dtype(self._ctx_dtype(op.output))
6203
+ acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
5564
6204
  input_loop_vars: tuple[str, ...] = ()
5565
6205
  input_loop_bounds: tuple[str | int, ...] = ()
5566
6206
  reduce_loop_var = "k"
@@ -5633,8 +6273,8 @@ class CEmitter:
5633
6273
  output_loop_vars=output_loop_vars,
5634
6274
  output_loop_bounds=output_shape,
5635
6275
  output_expr=output_expr,
5636
- acc_type=op.dtype.c_type,
5637
- zero_literal=zero_literal,
6276
+ acc_type=acc_dtype.c_type,
6277
+ zero_literal=acc_zero_literal,
5638
6278
  input_loop_vars=input_loop_vars,
5639
6279
  input_loop_bounds=input_loop_bounds,
5640
6280
  reduce_loop_var=reduce_loop_var,
@@ -5653,14 +6293,20 @@ class CEmitter:
5653
6293
  ("output", op.output),
5654
6294
  ]
5655
6295
  )
5656
- input_a_shape = (op.k, op.m) if op.trans_a else (op.m, op.k)
5657
- input_b_shape = (op.n, op.k) if op.trans_b else (op.k, op.n)
6296
+ m = int(self._derived(op, "m"))
6297
+ n = int(self._derived(op, "n"))
6298
+ k = int(self._derived(op, "k"))
6299
+ trans_a = bool(self._derived(op, "trans_a"))
6300
+ trans_b = bool(self._derived(op, "trans_b"))
6301
+ c_shape = self._derived(op, "c_shape")
6302
+ input_a_shape = (k, m) if trans_a else (m, k)
6303
+ input_b_shape = (n, k) if trans_b else (k, n)
5658
6304
  input_a_suffix = self._param_array_suffix(input_a_shape)
5659
6305
  input_b_suffix = self._param_array_suffix(input_b_shape)
5660
- output_suffix = self._param_array_suffix((op.m, op.n))
6306
+ output_suffix = self._param_array_suffix((m, n))
5661
6307
  c_suffix = (
5662
- self._param_array_suffix(op.c_shape)
5663
- if op.c_shape is not None
6308
+ self._param_array_suffix(c_shape)
6309
+ if c_shape is not None
5664
6310
  else ""
5665
6311
  )
5666
6312
  param_decls = self._build_param_decls(
@@ -5678,24 +6324,31 @@ class CEmitter:
5678
6324
  (params["output"], c_type, output_suffix, False),
5679
6325
  ]
5680
6326
  )
5681
- alpha_literal = CEmitter._format_literal(op.dtype, op.alpha)
5682
- beta_literal = CEmitter._format_literal(op.dtype, op.beta)
5683
- if op.c_shape is None:
6327
+ dtype = self._ctx_dtype(op.output)
6328
+ alpha_literal = CEmitter._format_literal(
6329
+ dtype, self._derived(op, "alpha")
6330
+ )
6331
+ beta_literal = CEmitter._format_literal(
6332
+ dtype, self._derived(op, "beta")
6333
+ )
6334
+ acc_dtype = self._accumulation_dtype(dtype)
6335
+ acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
6336
+ if c_shape is None:
5684
6337
  c_rank = 0
5685
6338
  c_dim0 = 0
5686
6339
  c_dim1 = 0
5687
- elif len(op.c_shape) == 0:
6340
+ elif len(c_shape) == 0:
5688
6341
  c_rank = 0
5689
6342
  c_dim0 = 0
5690
6343
  c_dim1 = 0
5691
- elif len(op.c_shape) == 1:
6344
+ elif len(c_shape) == 1:
5692
6345
  c_rank = 1
5693
6346
  c_dim0 = 1
5694
- c_dim1 = op.c_shape[0]
6347
+ c_dim1 = c_shape[0]
5695
6348
  else:
5696
6349
  c_rank = 2
5697
- c_dim0 = op.c_shape[0]
5698
- c_dim1 = op.c_shape[1]
6350
+ c_dim0 = c_shape[0]
6351
+ c_dim1 = c_shape[1]
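The c_rank / c_dim handling above mirrors how Gemm's optional C operand broadcasts over the (M, N) result. A NumPy reference of the computation the gemm template unrolls, with a rank-1 bias as the example:

import numpy as np

def gemm_ref(a, b, c=None, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    a = a.T if trans_a else a
    b = b.T if trans_b else b
    y = alpha * (a @ b)
    if c is not None:
        y = y + beta * np.broadcast_to(c, y.shape)  # rank-0/1/2 C all broadcast here
    return y

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.ones((3, 4), dtype=np.float32)
bias = np.arange(4, dtype=np.float32)  # rank-1 C, broadcast across rows
print(gemm_ref(a, b, bias, alpha=2.0, beta=0.5))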
5699
6352
  rendered = gemm_template.render(
5700
6353
  model_name=model.name,
5701
6354
  op_name=op_name,
@@ -5704,21 +6357,21 @@ class CEmitter:
5704
6357
  input_c=params["input_c"],
5705
6358
  output=params["output"],
5706
6359
  params=param_decls,
5707
- c_type=c_type,
5708
- acc_type=c_type,
5709
- zero_literal=zero_literal,
6360
+ c_type=dtype.c_type,
6361
+ acc_type=acc_dtype.c_type,
6362
+ zero_literal=acc_zero_literal,
5710
6363
  alpha_literal=alpha_literal,
5711
6364
  beta_literal=beta_literal,
5712
- trans_a=int(op.trans_a),
5713
- trans_b=int(op.trans_b),
5714
- m=op.m,
5715
- n=op.n,
5716
- k=op.k,
6365
+ trans_a=int(trans_a),
6366
+ trans_b=int(trans_b),
6367
+ m=m,
6368
+ n=n,
6369
+ k=k,
5717
6370
  input_a_suffix=input_a_suffix,
5718
6371
  input_b_suffix=input_b_suffix,
5719
6372
  output_suffix=output_suffix,
5720
6373
  c_suffix=(
5721
- c_suffix if op.c_shape is not None else None
6374
+ c_suffix if c_shape is not None else None
5722
6375
  ),
5723
6376
  c_rank=c_rank,
5724
6377
  c_dim0=c_dim0,
@@ -6034,6 +6687,9 @@ class CEmitter:
6034
6687
  ("output", op.output),
6035
6688
  ]
6036
6689
  )
6690
+ acc_dtype = self._accumulation_dtype(op.dtype)
6691
+ acc_type = acc_dtype.c_type
6692
+ acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
6037
6693
  input_shape = (op.batch, op.in_channels, *op.in_spatial)
6038
6694
  weight_shape = (
6039
6695
  op.out_channels,
@@ -6077,6 +6733,8 @@ class CEmitter:
6077
6733
  output=params["output"],
6078
6734
  params=param_decls,
6079
6735
  c_type=c_type,
6736
+ acc_type=acc_type,
6737
+ acc_zero_literal=acc_zero_literal,
6080
6738
  zero_literal=zero_literal,
6081
6739
  input_suffix=input_suffix,
6082
6740
  weight_suffix=weight_suffix,
@@ -6100,6 +6758,129 @@ class CEmitter:
6100
6758
  in_indices=in_indices,
6101
6759
  ).rstrip()
6102
6760
  return with_node_comment(rendered)
6761
+ if isinstance(op, ConvIntegerOp):
6762
+ params = self._shared_param_map(
6763
+ [
6764
+ ("input0", op.input0),
6765
+ ("weights", op.weights),
6766
+ ("x_zero_point", op.x_zero_point),
6767
+ ("w_zero_point", op.w_zero_point),
6768
+ ("output", op.output),
6769
+ ]
6770
+ )
6771
+ acc_dtype = op.dtype
6772
+ acc_type = acc_dtype.c_type
6773
+ acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
6774
+ input_shape = (op.batch, op.in_channels, *op.in_spatial)
6775
+ weight_shape = (
6776
+ op.out_channels,
6777
+ op.in_channels // op.group,
6778
+ *op.kernel_shape,
6779
+ )
6780
+ output_shape = (op.batch, op.out_channels, *op.out_spatial)
6781
+ out_indices = tuple(f"od{dim}" for dim in range(op.spatial_rank))
6782
+ kernel_indices = tuple(
6783
+ f"kd{dim}" for dim in range(op.spatial_rank)
6784
+ )
6785
+ in_indices = tuple(f"id{dim}" for dim in range(op.spatial_rank))
6786
+ pad_begin = op.pads[: op.spatial_rank]
6787
+ group_in_channels = op.in_channels // op.group
6788
+ group_out_channels = op.out_channels // op.group
6789
+ input_suffix = self._param_array_suffix(input_shape)
6790
+ weight_suffix = self._param_array_suffix(weight_shape)
6791
+ output_suffix = self._param_array_suffix(output_shape)
6792
+ x_zero_suffix = (
6793
+ self._param_array_suffix(op.x_zero_point_shape)
6794
+ if op.x_zero_point_shape is not None
6795
+ else ""
6796
+ )
6797
+ w_zero_suffix = (
6798
+ self._param_array_suffix(op.w_zero_point_shape)
6799
+ if op.w_zero_point_shape is not None
6800
+ else ""
6801
+ )
6802
+ param_decls = self._build_param_decls(
6803
+ [
6804
+ (
6805
+ params["input0"],
6806
+ op.input_dtype.c_type,
6807
+ input_suffix,
6808
+ True,
6809
+ ),
6810
+ (
6811
+ params["weights"],
6812
+ op.weight_dtype.c_type,
6813
+ weight_suffix,
6814
+ True,
6815
+ ),
6816
+ (
6817
+ params["x_zero_point"],
6818
+ op.input_dtype.c_type,
6819
+ x_zero_suffix,
6820
+ True,
6821
+ )
6822
+ if params["x_zero_point"]
6823
+ else (None, "", "", True),
6824
+ (
6825
+ params["w_zero_point"],
6826
+ op.weight_dtype.c_type,
6827
+ w_zero_suffix,
6828
+ True,
6829
+ )
6830
+ if params["w_zero_point"]
6831
+ else (None, "", "", True),
6832
+ (params["output"], c_type, output_suffix, False),
6833
+ ]
6834
+ )
6835
+ x_zero_expr = (
6836
+ f"{params['x_zero_point']}[0]"
6837
+ if params["x_zero_point"]
6838
+ else "0"
6839
+ )
6840
+ if params["w_zero_point"]:
6841
+ if op.w_zero_point_per_channel:
6842
+ w_zero_expr = f"{params['w_zero_point']}[oc_global]"
6843
+ else:
6844
+ w_zero_expr = f"{params['w_zero_point']}[0]"
6845
+ else:
6846
+ w_zero_expr = "0"
6847
+ rendered = conv_integer_template.render(
6848
+ model_name=model.name,
6849
+ op_name=op_name,
6850
+ input0=params["input0"],
6851
+ weights=params["weights"],
6852
+ x_zero_point=params["x_zero_point"],
6853
+ w_zero_point=params["w_zero_point"],
6854
+ output=params["output"],
6855
+ params=param_decls,
6856
+ c_type=c_type,
6857
+ acc_type=acc_type,
6858
+ acc_zero_literal=acc_zero_literal,
6859
+ input_suffix=input_suffix,
6860
+ weight_suffix=weight_suffix,
6861
+ x_zero_suffix=x_zero_suffix,
6862
+ w_zero_suffix=w_zero_suffix,
6863
+ output_suffix=output_suffix,
6864
+ batch=op.batch,
6865
+ in_channels=op.in_channels,
6866
+ out_channels=op.out_channels,
6867
+ spatial_rank=op.spatial_rank,
6868
+ in_spatial=op.in_spatial,
6869
+ out_spatial=op.out_spatial,
6870
+ kernel_shape=op.kernel_shape,
6871
+ strides=op.strides,
6872
+ pads_begin=pad_begin,
6873
+ dilations=op.dilations,
6874
+ group=op.group,
6875
+ group_in_channels=group_in_channels,
6876
+ group_out_channels=group_out_channels,
6877
+ out_indices=out_indices,
6878
+ kernel_indices=kernel_indices,
6879
+ in_indices=in_indices,
6880
+ x_zero_expr=x_zero_expr,
6881
+ w_zero_expr=w_zero_expr,
6882
+ ).rstrip()
6883
+ return with_node_comment(rendered)
6103
6884
  if isinstance(op, ConvTransposeOp):
6104
6885
  params = self._shared_param_map(
6105
6886
  [
@@ -6179,8 +6960,27 @@ class CEmitter:
6179
6960
  params = self._shared_param_map(
6180
6961
  [("input0", op.input0), ("output", op.output)]
6181
6962
  )
6182
- input_shape = (op.batch, op.channels, op.in_h, op.in_w)
6183
- output_shape = (op.batch, op.channels, op.out_h, op.out_w)
6963
+ if op.spatial_rank == 3:
6964
+ input_shape = (
6965
+ op.batch,
6966
+ op.channels,
6967
+ op.in_d,
6968
+ op.in_h,
6969
+ op.in_w,
6970
+ )
6971
+ output_shape = (
6972
+ op.batch,
6973
+ op.channels,
6974
+ op.out_d,
6975
+ op.out_h,
6976
+ op.out_w,
6977
+ )
6978
+ elif op.spatial_rank == 1:
6979
+ input_shape = (op.batch, op.channels, op.in_w)
6980
+ output_shape = (op.batch, op.channels, op.out_w)
6981
+ else:
6982
+ input_shape = (op.batch, op.channels, op.in_h, op.in_w)
6983
+ output_shape = (op.batch, op.channels, op.out_h, op.out_w)
6184
6984
  input_suffix = self._param_array_suffix(input_shape)
6185
6985
  output_suffix = self._param_array_suffix(output_shape)
6186
6986
  param_decls = self._build_param_decls(
@@ -6201,16 +7001,26 @@ class CEmitter:
6201
7001
  output_suffix=output_suffix,
6202
7002
  batch=op.batch,
6203
7003
  channels=op.channels,
7004
+ spatial_rank=op.spatial_rank,
7005
+ in_d=op.in_d,
6204
7006
  in_h=op.in_h,
6205
7007
  in_w=op.in_w,
7008
+ out_d=op.out_d,
6206
7009
  out_h=op.out_h,
6207
7010
  out_w=op.out_w,
7011
+ kernel_d=op.kernel_d,
6208
7012
  kernel_h=op.kernel_h,
6209
7013
  kernel_w=op.kernel_w,
7014
+ dilation_d=op.dilation_d,
7015
+ dilation_h=op.dilation_h,
7016
+ dilation_w=op.dilation_w,
7017
+ stride_d=op.stride_d,
6210
7018
  stride_h=op.stride_h,
6211
7019
  stride_w=op.stride_w,
7020
+ pad_front=op.pad_front,
6212
7021
  pad_top=op.pad_top,
6213
7022
  pad_left=op.pad_left,
7023
+ pad_back=op.pad_back,
6214
7024
  pad_bottom=op.pad_bottom,
6215
7025
  pad_right=op.pad_right,
6216
7026
  count_include_pad=int(op.count_include_pad),
@@ -6247,6 +7057,8 @@ class CEmitter:
6247
7057
  out_w=op.out_w,
6248
7058
  kernel_h=op.kernel_h,
6249
7059
  kernel_w=op.kernel_w,
7060
+ dilation_h=op.dilation_h,
7061
+ dilation_w=op.dilation_w,
6250
7062
  stride_h=op.stride_h,
6251
7063
  stride_w=op.stride_w,
6252
7064
  pad_top=op.pad_top,
@@ -6431,11 +7243,7 @@ class CEmitter:
6431
7243
  ).rstrip()
6432
7244
  return with_node_comment(rendered)
6433
7245
  if isinstance(op, LayerNormalizationOp):
6434
- acc_dtype = (
6435
- ScalarType.F32
6436
- if op.dtype in {ScalarType.F16, ScalarType.F32}
6437
- else op.dtype
6438
- )
7246
+ acc_dtype = self._accumulation_dtype(op.dtype)
6439
7247
  acc_type = acc_dtype.c_type
6440
7248
  acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
6441
7249
  acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
@@ -6443,7 +7251,7 @@ class CEmitter:
6443
7251
  op.epsilon, acc_dtype
6444
7252
  )
6445
7253
  acc_sqrt_fn = CEmitter._math_fn(acc_dtype, "sqrtf", "sqrt")
6446
- use_kahan = op.dtype in {ScalarType.F16, ScalarType.F32}
7254
+ use_kahan = False
6447
7255
  params = self._shared_param_map(
6448
7256
  [
6449
7257
  ("input0", op.input0),
@@ -6678,7 +7486,7 @@ class CEmitter:
6678
7486
  pow_fn=CEmitter._math_fn(op.dtype, "powf", "pow"),
6679
7487
  ).rstrip()
6680
7488
  return with_node_comment(rendered)
6681
- if isinstance(op, LstmOp):
7489
+ if isinstance(op, GruOp):
6682
7490
  params = self._shared_param_map(
6683
7491
  [
6684
7492
  ("input_x", op.input_x),
@@ -6687,11 +7495,8 @@ class CEmitter:
6687
7495
  ("input_b", op.input_b),
6688
7496
  ("input_sequence_lens", op.input_sequence_lens),
6689
7497
  ("input_initial_h", op.input_initial_h),
6690
- ("input_initial_c", op.input_initial_c),
6691
- ("input_p", op.input_p),
6692
7498
  ("output_y", op.output_y),
6693
7499
  ("output_y_h", op.output_y_h),
6694
- ("output_y_c", op.output_y_c),
6695
7500
  ]
6696
7501
  )
6697
7502
  input_x_shape = (
@@ -6699,14 +7504,16 @@ class CEmitter:
6699
7504
  if op.layout == 0
6700
7505
  else (op.batch_size, op.seq_length, op.input_size)
6701
7506
  )
6702
- w_shape = (op.num_directions, 4 * op.hidden_size, op.input_size)
6703
- r_shape = (op.num_directions, 4 * op.hidden_size, op.hidden_size)
7507
+ w_shape = (op.num_directions, 3 * op.hidden_size, op.input_size)
7508
+ r_shape = (op.num_directions, 3 * op.hidden_size, op.hidden_size)
6704
7509
  b_shape = (
6705
- (op.num_directions, 8 * op.hidden_size)
7510
+ (op.num_directions, 6 * op.hidden_size)
6706
7511
  if op.input_b is not None
6707
7512
  else None
6708
7513
  )
6709
- seq_shape = (op.batch_size,) if op.input_sequence_lens is not None else None
7514
+ seq_shape = (
7515
+ (op.batch_size,) if op.input_sequence_lens is not None else None
7516
+ )
6710
7517
  state_shape = (
6711
7518
  (op.num_directions, op.batch_size, op.hidden_size)
6712
7519
  if op.layout == 0
@@ -6717,16 +7524,6 @@ class CEmitter:
6717
7524
  if op.input_initial_h is not None or op.output_y_h is not None
6718
7525
  else None
6719
7526
  )
6720
- c_shape = (
6721
- state_shape
6722
- if op.input_initial_c is not None or op.output_y_c is not None
6723
- else None
6724
- )
6725
- p_shape = (
6726
- (op.num_directions, 3 * op.hidden_size)
6727
- if op.input_p is not None
6728
- else None
6729
- )
6730
7527
  y_shape = (
6731
7528
  (op.seq_length, op.num_directions, op.batch_size, op.hidden_size)
6732
7529
  if op.layout == 0
@@ -6776,22 +7573,6 @@ class CEmitter:
6776
7573
  )
6777
7574
  if params["input_initial_h"]
6778
7575
  else (None, "", "", True),
6779
- (
6780
- params["input_initial_c"],
6781
- c_type,
6782
- self._param_array_suffix(c_shape),
6783
- True,
6784
- )
6785
- if params["input_initial_c"]
6786
- else (None, "", "", True),
6787
- (
6788
- params["input_p"],
6789
- c_type,
6790
- self._param_array_suffix(p_shape),
6791
- True,
6792
- )
6793
- if params["input_p"]
6794
- else (None, "", "", True),
6795
7576
  (
6796
7577
  params["output_y"],
6797
7578
  c_type,
@@ -6808,22 +7589,14 @@ class CEmitter:
6808
7589
  )
6809
7590
  if params["output_y_h"]
6810
7591
  else (None, "", "", False),
6811
- (
6812
- params["output_y_c"],
6813
- c_type,
6814
- self._param_array_suffix(c_shape),
6815
- False,
6816
- )
6817
- if params["output_y_c"]
6818
- else (None, "", "", False),
6819
7592
  ]
6820
7593
  )
6821
7594
  if scalar_registry is None:
6822
7595
  raise CodegenError(
6823
- "Scalar function registry is required for LSTM codegen."
7596
+ "Scalar function registry is required for GRU codegen."
6824
7597
  )
6825
7598
  activation_functions = tuple(
6826
- self._lstm_activation_function_name(
7599
+ self._rnn_activation_function_name(
6827
7600
  kind,
6828
7601
  alpha,
6829
7602
  beta,
@@ -6836,7 +7609,7 @@ class CEmitter:
6836
7609
  op.activation_betas,
6837
7610
  )
6838
7611
  )
6839
- rendered = lstm_template.render(
7612
+ rendered = gru_template.render(
6840
7613
  model_name=model.name,
6841
7614
  op_name=op_name,
6842
7615
  input_x=params["input_x"],
@@ -6845,11 +7618,8 @@ class CEmitter:
6845
7618
  input_b=params["input_b"],
6846
7619
  input_sequence_lens=params["input_sequence_lens"],
6847
7620
  input_initial_h=params["input_initial_h"],
6848
- input_initial_c=params["input_initial_c"],
6849
- input_p=params["input_p"],
6850
7621
  output_y=params["output_y"],
6851
7622
  output_y_h=params["output_y_h"],
6852
- output_y_c=params["output_y_c"],
6853
7623
  params=param_decls,
6854
7624
  c_type=c_type,
6855
7625
  seq_c_type=(op.sequence_lens_dtype or ScalarType.I64).c_type,
@@ -6868,38 +7638,232 @@ class CEmitter:
6868
7638
  num_directions=op.num_directions,
6869
7639
  layout=op.layout,
6870
7640
  direction=op.direction,
6871
- input_forget=op.input_forget,
7641
+ linear_before_reset=op.linear_before_reset,
6872
7642
  activation_functions=activation_functions,
6873
7643
  ).rstrip()
6874
7644
  return with_node_comment(rendered)
6875
- if isinstance(op, AdagradOp):
7645
+ if isinstance(op, LstmOp):
6876
7646
  params = self._shared_param_map(
6877
7647
  [
6878
- ("rate", op.rate),
6879
- ("timestep", op.timestep),
6880
- *(
6881
- (f"input{idx}", name)
6882
- for idx, name in enumerate(op.inputs)
6883
- ),
6884
- *(
6885
- (f"grad{idx}", name)
6886
- for idx, name in enumerate(op.gradients)
6887
- ),
6888
- *(
6889
- (f"acc{idx}", name)
6890
- for idx, name in enumerate(op.accumulators)
6891
- ),
6892
- *(
6893
- (f"output{idx}", name)
6894
- for idx, name in enumerate(op.outputs)
6895
- ),
6896
- *(
6897
- (f"acc_output{idx}", name)
6898
- for idx, name in enumerate(op.accumulator_outputs)
6899
- ),
6900
- ]
6901
- )
6902
- rate_suffix = self._param_array_suffix(
7648
+ ("input_x", op.input_x),
7649
+ ("input_w", op.input_w),
7650
+ ("input_r", op.input_r),
7651
+ ("input_b", op.input_b),
7652
+ ("input_sequence_lens", op.input_sequence_lens),
7653
+ ("input_initial_h", op.input_initial_h),
7654
+ ("input_initial_c", op.input_initial_c),
7655
+ ("input_p", op.input_p),
7656
+ ("output_y", op.output_y),
7657
+ ("output_y_h", op.output_y_h),
7658
+ ("output_y_c", op.output_y_c),
7659
+ ]
7660
+ )
7661
+ input_x_shape = (
7662
+ (op.seq_length, op.batch_size, op.input_size)
7663
+ if op.layout == 0
7664
+ else (op.batch_size, op.seq_length, op.input_size)
7665
+ )
7666
+ w_shape = (op.num_directions, 4 * op.hidden_size, op.input_size)
7667
+ r_shape = (op.num_directions, 4 * op.hidden_size, op.hidden_size)
7668
+ b_shape = (
7669
+ (op.num_directions, 8 * op.hidden_size)
7670
+ if op.input_b is not None
7671
+ else None
7672
+ )
7673
+ seq_shape = (op.batch_size,) if op.input_sequence_lens is not None else None
7674
+ state_shape = (
7675
+ (op.num_directions, op.batch_size, op.hidden_size)
7676
+ if op.layout == 0
7677
+ else (op.batch_size, op.num_directions, op.hidden_size)
7678
+ )
7679
+ h_shape = (
7680
+ state_shape
7681
+ if op.input_initial_h is not None or op.output_y_h is not None
7682
+ else None
7683
+ )
7684
+ c_shape = (
7685
+ state_shape
7686
+ if op.input_initial_c is not None or op.output_y_c is not None
7687
+ else None
7688
+ )
7689
+ p_shape = (
7690
+ (op.num_directions, 3 * op.hidden_size)
7691
+ if op.input_p is not None
7692
+ else None
7693
+ )
7694
+ y_shape = (
7695
+ (op.seq_length, op.num_directions, op.batch_size, op.hidden_size)
7696
+ if op.layout == 0
7697
+ else (op.batch_size, op.seq_length, op.num_directions, op.hidden_size)
7698
+ )
7699
+ param_decls = self._build_param_decls(
7700
+ [
7701
+ (
7702
+ params["input_x"],
7703
+ c_type,
7704
+ self._param_array_suffix(input_x_shape),
7705
+ True,
7706
+ ),
7707
+ (
7708
+ params["input_w"],
7709
+ c_type,
7710
+ self._param_array_suffix(w_shape),
7711
+ True,
7712
+ ),
7713
+ (
7714
+ params["input_r"],
7715
+ c_type,
7716
+ self._param_array_suffix(r_shape),
7717
+ True,
7718
+ ),
7719
+ (
7720
+ params["input_b"],
7721
+ c_type,
7722
+ self._param_array_suffix(b_shape),
7723
+ True,
7724
+ )
7725
+ if params["input_b"]
7726
+ else (None, "", "", True),
7727
+ (
7728
+ params["input_sequence_lens"],
7729
+ (op.sequence_lens_dtype or ScalarType.I64).c_type,
7730
+ self._param_array_suffix(seq_shape),
7731
+ True,
7732
+ )
7733
+ if params["input_sequence_lens"]
7734
+ else (None, "", "", True),
7735
+ (
7736
+ params["input_initial_h"],
7737
+ c_type,
7738
+ self._param_array_suffix(h_shape),
7739
+ True,
7740
+ )
7741
+ if params["input_initial_h"]
7742
+ else (None, "", "", True),
7743
+ (
7744
+ params["input_initial_c"],
7745
+ c_type,
7746
+ self._param_array_suffix(c_shape),
7747
+ True,
7748
+ )
7749
+ if params["input_initial_c"]
7750
+ else (None, "", "", True),
7751
+ (
7752
+ params["input_p"],
7753
+ c_type,
7754
+ self._param_array_suffix(p_shape),
7755
+ True,
7756
+ )
7757
+ if params["input_p"]
7758
+ else (None, "", "", True),
7759
+ (
7760
+ params["output_y"],
7761
+ c_type,
7762
+ self._param_array_suffix(y_shape),
7763
+ False,
7764
+ )
7765
+ if params["output_y"]
7766
+ else (None, "", "", False),
7767
+ (
7768
+ params["output_y_h"],
7769
+ c_type,
7770
+ self._param_array_suffix(h_shape),
7771
+ False,
7772
+ )
7773
+ if params["output_y_h"]
7774
+ else (None, "", "", False),
7775
+ (
7776
+ params["output_y_c"],
7777
+ c_type,
7778
+ self._param_array_suffix(c_shape),
7779
+ False,
7780
+ )
7781
+ if params["output_y_c"]
7782
+ else (None, "", "", False),
7783
+ ]
7784
+ )
7785
+ if scalar_registry is None:
7786
+ raise CodegenError(
7787
+ "Scalar function registry is required for LSTM codegen."
7788
+ )
7789
+ activation_functions = tuple(
7790
+ self._rnn_activation_function_name(
7791
+ kind,
7792
+ alpha,
7793
+ beta,
7794
+ op.dtype,
7795
+ scalar_registry,
7796
+ )
7797
+ for kind, alpha, beta in zip(
7798
+ op.activation_kinds,
7799
+ op.activation_alphas,
7800
+ op.activation_betas,
7801
+ )
7802
+ )
7803
+ rendered = lstm_template.render(
7804
+ model_name=model.name,
7805
+ op_name=op_name,
7806
+ input_x=params["input_x"],
7807
+ input_w=params["input_w"],
7808
+ input_r=params["input_r"],
7809
+ input_b=params["input_b"],
7810
+ input_sequence_lens=params["input_sequence_lens"],
7811
+ input_initial_h=params["input_initial_h"],
7812
+ input_initial_c=params["input_initial_c"],
7813
+ input_p=params["input_p"],
7814
+ output_y=params["output_y"],
7815
+ output_y_h=params["output_y_h"],
7816
+ output_y_c=params["output_y_c"],
7817
+ params=param_decls,
7818
+ c_type=c_type,
7819
+ seq_c_type=(op.sequence_lens_dtype or ScalarType.I64).c_type,
7820
+ zero_literal=zero_literal,
7821
+ one_literal=CEmitter._format_literal(op.dtype, 1),
7822
+ clip_literal=(
7823
+ CEmitter._format_floating(op.clip, op.dtype)
7824
+ if op.clip is not None
7825
+ else CEmitter._format_literal(op.dtype, 0)
7826
+ ),
7827
+ use_clip=int(op.clip is not None and op.clip > 0),
7828
+ seq_length=op.seq_length,
7829
+ batch_size=op.batch_size,
7830
+ input_size=op.input_size,
7831
+ hidden_size=op.hidden_size,
7832
+ num_directions=op.num_directions,
7833
+ layout=op.layout,
7834
+ direction=op.direction,
7835
+ input_forget=op.input_forget,
7836
+ activation_functions=activation_functions,
7837
+ ).rstrip()
7838
+ return with_node_comment(rendered)
7839
+ if isinstance(op, AdagradOp):
7840
+ params = self._shared_param_map(
7841
+ [
7842
+ ("rate", op.rate),
7843
+ ("timestep", op.timestep),
7844
+ *(
7845
+ (f"input{idx}", name)
7846
+ for idx, name in enumerate(op.inputs)
7847
+ ),
7848
+ *(
7849
+ (f"grad{idx}", name)
7850
+ for idx, name in enumerate(op.gradients)
7851
+ ),
7852
+ *(
7853
+ (f"acc{idx}", name)
7854
+ for idx, name in enumerate(op.accumulators)
7855
+ ),
7856
+ *(
7857
+ (f"output{idx}", name)
7858
+ for idx, name in enumerate(op.outputs)
7859
+ ),
7860
+ *(
7861
+ (f"acc_output{idx}", name)
7862
+ for idx, name in enumerate(op.accumulator_outputs)
7863
+ ),
7864
+ ]
7865
+ )
7866
+ rate_suffix = self._param_array_suffix(
6903
7867
  op.rate_shape, _dim_names_for(op.rate)
6904
7868
  )
6905
7869
  timestep_suffix = self._param_array_suffix(
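Note: the LstmOp branch above only assembles shape tuples and parameter declarations; the actual kernel comes from the lstm_op.c.j2 template added elsewhere in this release. As a rough orientation aid, for layout == 0 and float data the parameter declarations it builds correspond to a C entry point shaped like the sketch below (the function name, parameter names, dimension constants and the unidirectional setup are illustrative assumptions, not the emitted code):

/* Hypothetical sketch of the entry point the LSTM emitter produces for
 * layout == 0. Shapes follow the tuples computed above:
 *   X [seq][batch][input], W [dirs][4*hidden][input], R [dirs][4*hidden][hidden],
 *   B [dirs][8*hidden], Y [seq][dirs][batch][hidden], Y_h / Y_c [dirs][batch][hidden].
 * The real signature is rendered from lstm_op.c.j2 and may differ. */
enum { SEQ = 4, BATCH = 1, INPUT = 8, HIDDEN = 16, DIRS = 1 };

void model_node0_lstm(
    const float X[SEQ][BATCH][INPUT],
    const float W[DIRS][4 * HIDDEN][INPUT],
    const float R[DIRS][4 * HIDDEN][HIDDEN],
    const float B[DIRS][8 * HIDDEN],
    float Y[SEQ][DIRS][BATCH][HIDDEN],
    float Y_h[DIRS][BATCH][HIDDEN],
    float Y_c[DIRS][BATCH][HIDDEN]);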
@@ -7112,11 +8076,7 @@ class CEmitter:
7112
8076
  ).rstrip()
7113
8077
  return with_node_comment(rendered)
7114
8078
  if isinstance(op, NegativeLogLikelihoodLossOp):
7115
- acc_dtype = (
7116
- ScalarType.F64
7117
- if op.dtype in {ScalarType.F16, ScalarType.F32}
7118
- else op.dtype
7119
- )
8079
+ acc_dtype = self._accumulation_dtype(op.dtype)
7120
8080
  acc_type = acc_dtype.c_type
7121
8081
  acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
7122
8082
  acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
@@ -7173,11 +8133,7 @@ class CEmitter:
7173
8133
  ).rstrip()
7174
8134
  return with_node_comment(rendered)
7175
8135
  if isinstance(op, SoftmaxCrossEntropyLossOp):
7176
- acc_dtype = (
7177
- ScalarType.F64
7178
- if op.dtype in {ScalarType.F16, ScalarType.F32}
7179
- else op.dtype
7180
- )
8136
+ acc_dtype = self._accumulation_dtype(op.dtype)
7181
8137
  if scalar_registry is None:
7182
8138
  raise CodegenError(
7183
8139
  "Scalar function registry is required for SoftmaxCrossEntropyLoss."
@@ -7873,7 +8829,58 @@ class CEmitter:
7873
8829
  loop_vars=loop_vars,
7874
8830
  ).rstrip()
7875
8831
  return with_node_comment(rendered)
8832
+ if isinstance(op, BernoulliOp):
8833
+ output_dim_names = _dim_names_for(op.output)
8834
+ shape = CEmitter._shape_dim_exprs(op.output_shape, output_dim_names)
8835
+ loop_vars = CEmitter._loop_vars(op.output_shape)
8836
+ output_suffix = self._param_array_suffix(
8837
+ op.output_shape, output_dim_names
8838
+ )
8839
+ input_suffix = self._param_array_suffix(
8840
+ op.input_shape, _dim_names_for(op.input0)
8841
+ )
8842
+ params = self._shared_param_map(
8843
+ [("input0", op.input0), ("output", op.output)]
8844
+ )
8845
+ output_dtype = op.dtype
8846
+ param_decls = self._build_param_decls(
8847
+ [
8848
+ (params["input0"], op.input_dtype.c_type, input_suffix, True),
8849
+ (params["output"], output_dtype.c_type, output_suffix, False),
8850
+ ]
8851
+ )
8852
+ one_literal = (
8853
+ "true"
8854
+ if output_dtype == ScalarType.BOOL
8855
+ else f"({output_dtype.c_type})1"
8856
+ )
8857
+ zero_literal = (
8858
+ "false"
8859
+ if output_dtype == ScalarType.BOOL
8860
+ else output_dtype.zero_literal
8861
+ )
8862
+ rendered = bernoulli_template.render(
8863
+ model_name=model.name,
8864
+ op_name=op_name,
8865
+ input0=params["input0"],
8866
+ output=params["output"],
8867
+ input_index_expr="".join(
8868
+ f"[{var}]" for var in loop_vars
8869
+ ),
8870
+ output_index_expr="".join(
8871
+ f"[{var}]" for var in loop_vars
8872
+ ),
8873
+ shape=shape,
8874
+ loop_vars=loop_vars,
8875
+ seed=op.seed if op.seed is not None else 0,
8876
+ one_literal=one_literal,
8877
+ zero_literal=zero_literal,
8878
+ dim_args=dim_args,
8879
+ params=param_decls,
8880
+ ).rstrip()
8881
+ return with_node_comment(rendered)
7876
8882
  if isinstance(op, EyeLikeOp):
8883
+ input_c_type = op.input_dtype.c_type
7877
8884
  params = self._shared_param_map(
7878
8885
  [("input0", op.input0), ("output", op.output)]
7879
8886
  )
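Note: bernoulli_template receives the seed, the one/zero literals and matching input/output array suffixes; the sampling itself lives in bernoulli_op.c.j2, which is not part of this hunk. The sketch below only restates the ONNX Bernoulli semantics the template is expected to implement, with rng_uniform01 standing in for whatever seeded generator the template actually emits:

#include <stdint.h>

/* Illustrative only: each output element is 1 with probability p[i].
 * The LCG below is a stand-in for the template's own generator; the emitter
 * passes `seed` (defaulting to 0) through to it. */
static double rng_uniform01(uint64_t *state) {
    *state = *state * 6364136223846793005ULL + 1442695040888963407ULL; /* LCG step */
    return (double)(*state >> 11) / 9007199254740992.0;                /* 53-bit fraction in [0, 1) */
}

static void bernoulli_kernel(const float *p, float *out, int n, uint64_t seed) {
    uint64_t state = seed ? seed : 1;
    for (int i = 0; i < n; ++i) {
        out[i] = (rng_uniform01(&state) < (double)p[i]) ? 1.0f : 0.0f;  /* 1 with prob p[i] */
    }
}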
@@ -7887,7 +8894,7 @@ class CEmitter:
7887
8894
  batch_size = CEmitter._element_count(batch_dims or (1,))
7888
8895
  param_decls = self._build_param_decls(
7889
8896
  [
7890
- (params["input0"], c_type, input_suffix, True),
8897
+ (params["input0"], input_c_type, input_suffix, True),
7891
8898
  (params["output"], c_type, output_suffix, False),
7892
8899
  ]
7893
8900
  )
@@ -8499,8 +9506,6 @@ class CEmitter:
8499
9506
  update_expr = None
8500
9507
  init_literal = None
8501
9508
  final_expr = "acc"
8502
- use_kahan = False
8503
- kahan_value_expr = None
8504
9509
  fabs_fn = CEmitter._math_fn(output_dtype, "fabsf", "fabs")
8505
9510
  exp_fn = CEmitter._math_fn(output_dtype, "expf", "exp")
8506
9511
  log_fn = CEmitter._math_fn(output_dtype, "logf", "log")
@@ -8546,24 +9551,6 @@ class CEmitter:
8546
9551
  raise CodegenError(
8547
9552
  f"Unsupported reduce kind {op.reduce_kind}"
8548
9553
  )
8549
- if output_dtype in {ScalarType.F16, ScalarType.F32} and op.reduce_kind in {
8550
- "sum",
8551
- "mean",
8552
- "logsum",
8553
- "logsumexp",
8554
- "l1",
8555
- "l2",
8556
- "sumsquare",
8557
- }:
8558
- use_kahan = True
8559
- if op.reduce_kind == "logsumexp":
8560
- kahan_value_expr = f"{exp_fn}({value_expr})"
8561
- elif op.reduce_kind == "l1":
8562
- kahan_value_expr = f"{fabs_fn}({value_expr})"
8563
- elif op.reduce_kind in {"l2", "sumsquare"}:
8564
- kahan_value_expr = f"{value_expr} * {value_expr}"
8565
- else:
8566
- kahan_value_expr = value_expr
8567
9554
  input_suffix = self._param_array_suffix(input_shape)
8568
9555
  output_suffix = self._param_array_suffix(output_shape_raw)
8569
9556
  param_decls = self._build_param_decls(
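Note: the deletion above drops the use_kahan / kahan_value_expr plumbing, so F16/F32 sum-style reductions in the generated code fall back to a plain accumulator (use_kahan=False is now passed unconditionally a few lines below). For reference, the compensated-summation technique being removed is the textbook Kahan loop; this is a generic sketch, not code from the package:

/* Kahan (compensated) summation: `c` carries the low-order bits lost when a
 * small term is added to a large accumulator, keeping the rounding error
 * bounded independently of the number of terms. */
static float kahan_sum(const float *x, int n) {
    float acc = 0.0f;
    float c = 0.0f;               /* running compensation */
    for (int i = 0; i < n; ++i) {
        float y = x[i] - c;       /* corrected next term */
        float t = acc + y;        /* low bits of y are lost here */
        c = (t - acc) - y;        /* recover what was lost */
        acc = t;
    }
    return acc;
}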
@@ -8590,8 +9577,8 @@ class CEmitter:
8590
9577
  zero_literal=zero_literal,
8591
9578
  update_expr=update_expr,
8592
9579
  final_expr=final_expr,
8593
- use_kahan=use_kahan,
8594
- kahan_value_expr=kahan_value_expr,
9580
+ use_kahan=False,
9581
+ kahan_value_expr=None,
8595
9582
  ).rstrip()
8596
9583
  return with_node_comment(rendered)
8597
9584
  if isinstance(op, ArgReduceOp):
@@ -8736,9 +9723,9 @@ class CEmitter:
8736
9723
  output_values=params["output_values"],
8737
9724
  output_indices=params["output_indices"],
8738
9725
  params=param_decls,
8739
- input_c_type=op.input_dtype.c_type,
8740
- output_values_c_type=op.output_values_dtype.c_type,
8741
- output_indices_c_type=op.output_indices_dtype.c_type,
9726
+ input_c_type=input_dtype.c_type,
9727
+ output_values_c_type=output_values_dtype.c_type,
9728
+ output_indices_c_type=output_indices_dtype.c_type,
8742
9729
  input_suffix=input_suffix,
8743
9730
  output_suffix=output_suffix,
8744
9731
  output_shape=output_shape,
@@ -8746,7 +9733,7 @@ class CEmitter:
8746
9733
  outer_loop_vars=outer_loop_vars,
8747
9734
  reduce_var=reduce_var,
8748
9735
  k_var=k_var,
8749
- axis_dim=op.input_shape[op.axis],
9736
+ axis_dim=input_shape[op.axis],
8750
9737
  k=op.k,
8751
9738
  input_index_expr=input_index_expr,
8752
9739
  output_index_expr=output_index_expr,
@@ -8762,11 +9749,15 @@ class CEmitter:
8762
9749
  ("output", op.output),
8763
9750
  ]
8764
9751
  )
8765
- output_shape = CEmitter._codegen_shape(op.output_shape)
9752
+ input_shape_raw = self._ctx_shape(op.input0)
9753
+ output_shape_raw = self._ctx_shape(op.output)
9754
+ output_shape = CEmitter._codegen_shape(output_shape_raw)
8766
9755
  output_loop_vars = CEmitter._loop_vars(output_shape)
8767
- input_shape = CEmitter._codegen_shape(op.input_shape)
9756
+ input_shape = CEmitter._codegen_shape(input_shape_raw)
8768
9757
  input_loop_vars = CEmitter._loop_vars(input_shape)
8769
- axes_shape = op.axes_input_shape or ()
9758
+ axes_shape = (
9759
+ self._ctx_shape(op.axes_input) if op.axes_input is not None else ()
9760
+ )
8770
9761
  axes_count = 1
8771
9762
  for dim in axes_shape:
8772
9763
  if dim == 0:
@@ -8774,8 +9765,8 @@ class CEmitter:
8774
9765
  break
8775
9766
  axes_count *= dim
8776
9767
  axes_c_type = (
8777
- op.axes_input_dtype.c_type
8778
- if op.axes_input_dtype
9768
+ self._ctx_dtype(op.axes_input).c_type
9769
+ if op.axes_input is not None
8779
9770
  else ScalarType.I64.c_type
8780
9771
  )
8781
9772
  input_indices = "".join(f"[{var}]" for var in input_loop_vars)
@@ -8789,10 +9780,11 @@ class CEmitter:
8789
9780
  update_expr = None
8790
9781
  init_literal = None
8791
9782
  post_expr = None
8792
- fabs_fn = CEmitter._math_fn(op.dtype, "fabsf", "fabs")
8793
- exp_fn = CEmitter._math_fn(op.dtype, "expf", "exp")
8794
- log_fn = CEmitter._math_fn(op.dtype, "logf", "log")
8795
- sqrt_fn = CEmitter._math_fn(op.dtype, "sqrtf", "sqrt")
9783
+ reduce_dtype = self._ctx_dtype(op.output)
9784
+ fabs_fn = CEmitter._math_fn(reduce_dtype, "fabsf", "fabs")
9785
+ exp_fn = CEmitter._math_fn(reduce_dtype, "expf", "exp")
9786
+ log_fn = CEmitter._math_fn(reduce_dtype, "logf", "log")
9787
+ sqrt_fn = CEmitter._math_fn(reduce_dtype, "sqrtf", "sqrt")
8796
9788
  if op.reduce_kind == "sum":
8797
9789
  init_literal = zero_literal
8798
9790
  update_expr = f"*out_ptr += {value_expr};"
@@ -8807,7 +9799,7 @@ class CEmitter:
8807
9799
  init_literal = max_literal
8808
9800
  update_expr = f"if ({value_expr} < *out_ptr) *out_ptr = {value_expr};"
8809
9801
  elif op.reduce_kind == "prod":
8810
- init_literal = CEmitter._format_literal(op.dtype, 1)
9802
+ init_literal = CEmitter._format_literal(reduce_dtype, 1)
8811
9803
  update_expr = f"*out_ptr *= {value_expr};"
8812
9804
  elif op.reduce_kind == "l1":
8813
9805
  init_literal = zero_literal
@@ -8831,11 +9823,11 @@ class CEmitter:
8831
9823
  raise CodegenError(
8832
9824
  f"Unsupported reduce kind {op.reduce_kind}"
8833
9825
  )
8834
- input_suffix = self._param_array_suffix(op.input_shape)
8835
- output_suffix = self._param_array_suffix(op.output_shape)
9826
+ input_suffix = self._param_array_suffix(input_shape_raw)
9827
+ output_suffix = self._param_array_suffix(output_shape_raw)
8836
9828
  axes_suffix = (
8837
- self._param_array_suffix(op.axes_input_shape)
8838
- if op.axes_input_shape
9829
+ self._param_array_suffix(axes_shape)
9830
+ if axes_shape
8839
9831
  else ""
8840
9832
  )
8841
9833
  params = self._build_param_decls(
@@ -8963,6 +9955,44 @@ class CEmitter:
8963
9955
  value=CEmitter._format_literal(op.dtype, op.value),
8964
9956
  ).rstrip()
8965
9957
  return with_node_comment(rendered)
9958
+ if isinstance(op, OptionalHasElementOp):
9959
+ params = self._shared_param_map(
9960
+ [("input0", op.input0), ("output", op.output)]
9961
+ )
9962
+ input_shape = self._ctx_shape(op.input0)
9963
+ output_shape = self._ctx_shape(op.output)
9964
+ input_dim_names = _dim_names_for(op.input0)
9965
+ output_dim_names = _dim_names_for(op.output)
9966
+ input_suffix = self._param_array_suffix(input_shape, input_dim_names)
9967
+ output_suffix = self._param_array_suffix(output_shape, output_dim_names)
9968
+ input_dtype = self._ctx_dtype(op.input0)
9969
+ output_dtype = self._ctx_dtype(op.output)
9970
+ optional_flags = self._optional_input_flag_map(model)
9971
+ input_flag = optional_flags.get(op.input0)
9972
+ if input_flag is None:
9973
+ raise CodegenError(
9974
+ "OptionalHasElement expects an optional input flag."
9975
+ )
9976
+ param_decls = self._build_param_decls(
9977
+ [
9978
+ (params["input0"], input_dtype.c_type, input_suffix, True),
9979
+ (input_flag, "_Bool", "", True),
9980
+ (params["output"], output_dtype.c_type, output_suffix, False),
9981
+ ]
9982
+ )
9983
+ rendered = optional_has_element_template.render(
9984
+ model_name=model.name,
9985
+ op_name=op_name,
9986
+ input0=params["input0"],
9987
+ input_present=input_flag,
9988
+ output=params["output"],
9989
+ params=param_decls,
9990
+ input_c_type=input_dtype.c_type,
9991
+ output_c_type=output_dtype.c_type,
9992
+ input_suffix=input_suffix,
9993
+ output_suffix=output_suffix,
9994
+ ).rstrip()
9995
+ return with_node_comment(rendered)
8966
9996
  if isinstance(op, NonZeroOp):
8967
9997
  params = self._shared_param_map(
8968
9998
  [("input0", op.input0), ("output", op.output)]
@@ -9247,6 +10277,38 @@ class CEmitter:
9247
10277
  length=op.length,
9248
10278
  ).rstrip()
9249
10279
  return with_node_comment(rendered)
10280
+ if isinstance(op, HammingWindowOp):
10281
+ params = self._shared_param_map(
10282
+ [
10283
+ ("size", op.size),
10284
+ ("output", op.output),
10285
+ ]
10286
+ )
10287
+ scalar_suffix = self._param_array_suffix(())
10288
+ output_suffix = self._param_array_suffix(op.output_shape)
10289
+ param_decls = self._build_param_decls(
10290
+ [
10291
+ (
10292
+ params["size"],
10293
+ op.input_dtype.c_type,
10294
+ scalar_suffix,
10295
+ True,
10296
+ ),
10297
+ (params["output"], c_type, output_suffix, False),
10298
+ ]
10299
+ )
10300
+ rendered = hamming_window_template.render(
10301
+ model_name=model.name,
10302
+ op_name=op_name,
10303
+ size=params["size"],
10304
+ output=params["output"],
10305
+ params=param_decls,
10306
+ c_type=c_type,
10307
+ output_suffix=output_suffix,
10308
+ length=op.output_shape[0],
10309
+ periodic_literal="1" if op.periodic else "0",
10310
+ ).rstrip()
10311
+ return with_node_comment(rendered)
9250
10312
  if isinstance(op, OneHotOp):
9251
10313
  params = self._shared_param_map(
9252
10314
  [
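Note: the HammingWindow branch forwards length (output_shape[0]) and periodic_literal to hamming_window_op.c.j2. The window it is expected to evaluate is the standard Hamming window below; the 25/46 coefficient matches the ONNX definition but should be read as an assumption here, since the template is not shown in this hunk:

#include <math.h>

/* Generic Hamming window: w[n] = alpha - (1 - alpha) * cos(2*pi*n / N),
 * with N = size for a periodic window and size - 1 for a symmetric one
 * (symmetric case assumes size > 1). Reference sketch only. */
static void hamming_window(float *out, int size, int periodic) {
    const double pi = 3.14159265358979323846;
    const double alpha = 25.0 / 46.0;  /* assumed ONNX coefficient */
    const double denom = periodic ? (double)size : (double)(size - 1);
    for (int n = 0; n < size; ++n) {
        out[n] = (float)(alpha - (1.0 - alpha) * cos(2.0 * pi * (double)n / denom));
    }
}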
@@ -9315,6 +10377,85 @@ class CEmitter:
9315
10377
  c_type=c_type,
9316
10378
  ).rstrip()
9317
10379
  return with_node_comment(rendered)
10380
+ if isinstance(op, TfIdfVectorizerOp):
10381
+ params = self._shared_param_map(
10382
+ [("input0", op.input0), ("output", op.output)]
10383
+ )
10384
+ input_dim_names = _dim_names_for(op.input0)
10385
+ output_dim_names = _dim_names_for(op.output)
10386
+ input_suffix = self._param_array_suffix(
10387
+ op.input_shape, input_dim_names
10388
+ )
10389
+ output_suffix = self._param_array_suffix(
10390
+ op.output_shape, output_dim_names
10391
+ )
10392
+ param_decls = self._build_param_decls(
10393
+ [
10394
+ (
10395
+ params["input0"],
10396
+ op.input_dtype.c_type,
10397
+ input_suffix,
10398
+ True,
10399
+ ),
10400
+ (
10401
+ params["output"],
10402
+ op.output_dtype.c_type,
10403
+ output_suffix,
10404
+ False,
10405
+ ),
10406
+ ]
10407
+ )
10408
+ output_dim = op.output_shape[-1] if op.output_shape else 0
10409
+ mode_id = {"TF": 0, "IDF": 1, "TFIDF": 2}[op.mode]
10410
+ pool_values = [
10411
+ CEmitter._format_literal(ScalarType.I64, value)
10412
+ for value in op.pool_int64s
10413
+ ]
10414
+ ngram_counts_values = [
10415
+ CEmitter._format_literal(ScalarType.I64, value)
10416
+ for value in op.ngram_counts
10417
+ ]
10418
+ ngram_indexes_values = [
10419
+ CEmitter._format_literal(ScalarType.I64, value)
10420
+ for value in op.ngram_indexes
10421
+ ]
10422
+ weights_values = (
10423
+ [
10424
+ CEmitter._format_literal(op.output_dtype, value)
10425
+ for value in op.weights
10426
+ ]
10427
+ if op.weights is not None
10428
+ else None
10429
+ )
10430
+ rendered = tfidf_vectorizer_template.render(
10431
+ model_name=model.name,
10432
+ op_name=op_name,
10433
+ input0=params["input0"],
10434
+ output=params["output"],
10435
+ params=param_decls,
10436
+ input_suffix=input_suffix,
10437
+ output_suffix=output_suffix,
10438
+ input_shape=op.input_shape,
10439
+ output_shape=op.output_shape,
10440
+ input_rank=len(op.input_shape),
10441
+ output_dim=output_dim,
10442
+ min_gram_length=op.min_gram_length,
10443
+ max_gram_length=op.max_gram_length,
10444
+ max_skip_count=op.max_skip_count,
10445
+ mode_id=mode_id,
10446
+ ngram_counts_len=len(op.ngram_counts),
10447
+ pool_size=len(op.pool_int64s),
10448
+ ngram_index_len=len(op.ngram_indexes),
10449
+ pool_values=pool_values,
10450
+ ngram_counts_values=ngram_counts_values,
10451
+ ngram_indexes_values=ngram_indexes_values,
10452
+ weights_values=weights_values,
10453
+ zero_literal=op.output_dtype.zero_literal,
10454
+ one_literal=CEmitter._format_literal(op.output_dtype, 1.0),
10455
+ c_type=op.output_dtype.c_type,
10456
+ input_c_type=op.input_dtype.c_type,
10457
+ ).rstrip()
10458
+ return with_node_comment(rendered)
9318
10459
  if isinstance(op, SplitOp):
9319
10460
  output_params = [
9320
10461
  (f"output_{index}", name)
@@ -9429,30 +10570,302 @@ class CEmitter:
9429
10570
  scale_suffix = self._param_array_suffix(
9430
10571
  scale_shape, _dim_names_for(op.scale)
9431
10572
  )
9432
- zero_point_suffix = self._param_array_suffix(
9433
- scale_shape, _dim_names_for(op.zero_point or "")
10573
+ zero_point_suffix = self._param_array_suffix(
10574
+ scale_shape, _dim_names_for(op.zero_point or "")
10575
+ )
10576
+ param_decls = self._build_param_decls(
10577
+ [
10578
+ (params["input0"], op.input_dtype.c_type, input_suffix, True),
10579
+ (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
10580
+ (
10581
+ params["zero_point"],
10582
+ op.dtype.c_type,
10583
+ zero_point_suffix,
10584
+ True,
10585
+ )
10586
+ if params["zero_point"]
10587
+ else (None, "", "", True),
10588
+ (params["output"], op.dtype.c_type, input_suffix, False),
10589
+ ]
10590
+ )
10591
+ compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
10592
+ compute_dtype = (
10593
+ ScalarType.F64
10594
+ if compute_type == "double"
10595
+ else ScalarType.F32
10596
+ )
10597
+ max_fn = self._scalar_function_name(
10598
+ ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
10599
+ )
10600
+ min_fn = self._scalar_function_name(
10601
+ ScalarFunction.MINIMUM, compute_dtype, scalar_registry
10602
+ )
10603
+ if max_fn is None or min_fn is None:
10604
+ raise CodegenError(
10605
+ "Failed to resolve scalar min/max functions for QuantizeLinear."
10606
+ )
10607
+ round_fn = CEmitter._math_fn(
10608
+ op.input_dtype, "nearbyintf", "nearbyint"
10609
+ )
10610
+ scale_index = "0" if op.axis is None else loop_vars[op.axis]
10611
+ input_expr = f"{params['input0']}" + "".join(
10612
+ f"[{var}]" for var in loop_vars
10613
+ )
10614
+ output_expr = f"{params['output']}" + "".join(
10615
+ f"[{var}]" for var in loop_vars
10616
+ )
10617
+ scale_expr = f"{params['scale']}[{scale_index}]"
10618
+ if params["zero_point"]:
10619
+ zero_expr = f"{params['zero_point']}[{scale_index}]"
10620
+ else:
10621
+ zero_expr = "0"
10622
+ rendered = quantize_linear_template.render(
10623
+ model_name=model.name,
10624
+ op_name=op_name,
10625
+ input0=params["input0"],
10626
+ scale=params["scale"],
10627
+ zero_point=params["zero_point"],
10628
+ output=params["output"],
10629
+ params=param_decls,
10630
+ compute_type=compute_type,
10631
+ input_c_type=op.input_dtype.c_type,
10632
+ output_c_type=op.dtype.c_type,
10633
+ shape=shape,
10634
+ loop_vars=loop_vars,
10635
+ input_expr=input_expr,
10636
+ scale_expr=scale_expr,
10637
+ zero_expr=zero_expr,
10638
+ output_expr=output_expr,
10639
+ round_fn=round_fn,
10640
+ min_literal=op.dtype.min_literal,
10641
+ max_literal=op.dtype.max_literal,
10642
+ min_fn=min_fn,
10643
+ max_fn=max_fn,
10644
+ dim_args=dim_args,
10645
+ ).rstrip()
10646
+ return with_node_comment(rendered)
10647
+ if isinstance(op, DequantizeLinearOp):
10648
+ params = self._shared_param_map(
10649
+ [
10650
+ ("input0", op.input0),
10651
+ ("scale", op.scale),
10652
+ ("zero_point", op.zero_point),
10653
+ ("output", op.output),
10654
+ ]
10655
+ )
10656
+ output_dim_names = _dim_names_for(op.output)
10657
+ shape = CEmitter._shape_dim_exprs(op.input_shape, output_dim_names)
10658
+ loop_vars = CEmitter._loop_vars(op.input_shape)
10659
+ input_suffix = self._param_array_suffix(
10660
+ op.input_shape, _dim_names_for(op.input0)
10661
+ )
10662
+ if op.axis is None:
10663
+ scale_shape = ()
10664
+ elif op.block_size:
10665
+ scale_shape_list = list(op.input_shape)
10666
+ scale_shape_list[op.axis] = (
10667
+ op.input_shape[op.axis] // op.block_size
10668
+ )
10669
+ scale_shape = tuple(scale_shape_list)
10670
+ else:
10671
+ scale_shape = (op.input_shape[op.axis],)
10672
+ scale_suffix = self._param_array_suffix(
10673
+ scale_shape, _dim_names_for(op.scale)
10674
+ )
10675
+ zero_point_suffix = self._param_array_suffix(
10676
+ scale_shape, _dim_names_for(op.zero_point or "")
10677
+ )
10678
+ param_decls = self._build_param_decls(
10679
+ [
10680
+ (params["input0"], op.input_dtype.c_type, input_suffix, True),
10681
+ (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
10682
+ (
10683
+ params["zero_point"],
10684
+ op.input_dtype.c_type,
10685
+ zero_point_suffix,
10686
+ True,
10687
+ )
10688
+ if params["zero_point"]
10689
+ else (None, "", "", True),
10690
+ (params["output"], op.dtype.c_type, input_suffix, False),
10691
+ ]
10692
+ )
10693
+ compute_type = "double" if op.dtype == ScalarType.F64 else "float"
10694
+ input_expr = f"{params['input0']}" + "".join(
10695
+ f"[{var}]" for var in loop_vars
10696
+ )
10697
+ output_expr = f"{params['output']}" + "".join(
10698
+ f"[{var}]" for var in loop_vars
10699
+ )
10700
+ if op.axis is None:
10701
+ scale_expr = f"{params['scale']}[0]"
10702
+ elif op.block_size:
10703
+ scale_indices = list(loop_vars)
10704
+ scale_indices[op.axis] = (
10705
+ f"({loop_vars[op.axis]}) / {op.block_size}"
10706
+ )
10707
+ scale_expr = f"{params['scale']}" + "".join(
10708
+ f"[{index}]" for index in scale_indices
10709
+ )
10710
+ else:
10711
+ scale_index = loop_vars[op.axis]
10712
+ scale_expr = f"{params['scale']}[{scale_index}]"
10713
+ if params["zero_point"]:
10714
+ if op.axis is None:
10715
+ zero_expr = f"{params['zero_point']}[0]"
10716
+ elif op.block_size:
10717
+ scale_indices = list(loop_vars)
10718
+ scale_indices[op.axis] = (
10719
+ f"({loop_vars[op.axis]}) / {op.block_size}"
10720
+ )
10721
+ zero_expr = f"{params['zero_point']}" + "".join(
10722
+ f"[{index}]" for index in scale_indices
10723
+ )
10724
+ else:
10725
+ zero_expr = f"{params['zero_point']}[{scale_index}]"
10726
+ else:
10727
+ zero_expr = "0"
10728
+ rendered = dequantize_linear_template.render(
10729
+ model_name=model.name,
10730
+ op_name=op_name,
10731
+ input0=params["input0"],
10732
+ scale=params["scale"],
10733
+ zero_point=params["zero_point"],
10734
+ output=params["output"],
10735
+ params=param_decls,
10736
+ compute_type=compute_type,
10737
+ input_c_type=op.input_dtype.c_type,
10738
+ output_c_type=op.dtype.c_type,
10739
+ shape=shape,
10740
+ loop_vars=loop_vars,
10741
+ input_expr=input_expr,
10742
+ scale_expr=scale_expr,
10743
+ zero_expr=zero_expr,
10744
+ output_expr=output_expr,
10745
+ dim_args=dim_args,
10746
+ ).rstrip()
10747
+ return with_node_comment(rendered)
10748
+ if isinstance(op, QLinearMulOp):
10749
+ if scalar_registry is None:
10750
+ raise CodegenError(
10751
+ "Scalar function registry is required for QLinearMul."
10752
+ )
10753
+ params = self._shared_param_map(
10754
+ [
10755
+ ("input0", op.input0),
10756
+ ("input0_scale", op.input0_scale),
10757
+ ("input0_zero_point", op.input0_zero_point),
10758
+ ("input1", op.input1),
10759
+ ("input1_scale", op.input1_scale),
10760
+ ("input1_zero_point", op.input1_zero_point),
10761
+ ("output_scale", op.output_scale),
10762
+ ("output_zero_point", op.output_zero_point),
10763
+ ("output", op.output),
10764
+ ]
10765
+ )
10766
+ output_shape = CEmitter._codegen_shape(op.output_shape)
10767
+ output_loop_vars = CEmitter._loop_vars(op.output_shape)
10768
+ output_index_expr = f"{params['output']}" + "".join(
10769
+ f"[{var}]" for var in output_loop_vars
10770
+ )
10771
+ input0_index_expr = CEmitter._broadcast_index_expr(
10772
+ params["input0"],
10773
+ op.input0_shape,
10774
+ op.output_shape,
10775
+ output_loop_vars,
10776
+ )
10777
+ input1_index_expr = CEmitter._broadcast_index_expr(
10778
+ params["input1"],
10779
+ op.input1_shape,
10780
+ op.output_shape,
10781
+ output_loop_vars,
10782
+ )
10783
+ input0_suffix = self._param_array_suffix(op.input0_shape)
10784
+ input1_suffix = self._param_array_suffix(op.input1_shape)
10785
+ input0_scale_suffix = self._param_array_suffix(
10786
+ op.input0_scale_shape
10787
+ )
10788
+ input1_scale_suffix = self._param_array_suffix(
10789
+ op.input1_scale_shape
9434
10790
  )
10791
+ output_scale_suffix = self._param_array_suffix(
10792
+ op.output_scale_shape
10793
+ )
10794
+ input0_zero_suffix = self._param_array_suffix(op.input0_zero_shape)
10795
+ input1_zero_suffix = self._param_array_suffix(op.input1_zero_shape)
10796
+ output_zero_suffix = self._param_array_suffix(op.output_zero_shape)
10797
+ output_suffix = self._param_array_suffix(op.output_shape)
9435
10798
  param_decls = self._build_param_decls(
9436
10799
  [
9437
- (params["input0"], op.input_dtype.c_type, input_suffix, True),
9438
- (params["scale"], op.scale_dtype.c_type, scale_suffix, True),
9439
10800
  (
9440
- params["zero_point"],
10801
+ params["input0"],
10802
+ op.input0_dtype.c_type,
10803
+ input0_suffix,
10804
+ True,
10805
+ ),
10806
+ (
10807
+ params["input0_scale"],
10808
+ op.input0_scale_dtype.c_type,
10809
+ input0_scale_suffix,
10810
+ True,
10811
+ ),
10812
+ (
10813
+ params["input0_zero_point"],
10814
+ op.input0_dtype.c_type,
10815
+ input0_zero_suffix,
10816
+ True,
10817
+ ),
10818
+ (
10819
+ params["input1"],
10820
+ op.input1_dtype.c_type,
10821
+ input1_suffix,
10822
+ True,
10823
+ ),
10824
+ (
10825
+ params["input1_scale"],
10826
+ op.input1_scale_dtype.c_type,
10827
+ input1_scale_suffix,
10828
+ True,
10829
+ ),
10830
+ (
10831
+ params["input1_zero_point"],
10832
+ op.input1_dtype.c_type,
10833
+ input1_zero_suffix,
10834
+ True,
10835
+ ),
10836
+ (
10837
+ params["output_scale"],
10838
+ op.output_scale_dtype.c_type,
10839
+ output_scale_suffix,
10840
+ True,
10841
+ ),
10842
+ (
10843
+ params["output_zero_point"],
9441
10844
  op.dtype.c_type,
9442
- zero_point_suffix,
10845
+ output_zero_suffix,
9443
10846
  True,
9444
- )
9445
- if params["zero_point"]
9446
- else (None, "", "", True),
9447
- (params["output"], op.dtype.c_type, input_suffix, False),
10847
+ ),
10848
+ (
10849
+ params["output"],
10850
+ op.dtype.c_type,
10851
+ output_suffix,
10852
+ False,
10853
+ ),
9448
10854
  ]
9449
10855
  )
9450
- compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
9451
10856
  compute_dtype = (
9452
10857
  ScalarType.F64
9453
- if compute_type == "double"
10858
+ if ScalarType.F64
10859
+ in {
10860
+ op.input0_scale_dtype,
10861
+ op.input1_scale_dtype,
10862
+ op.output_scale_dtype,
10863
+ }
9454
10864
  else ScalarType.F32
9455
10865
  )
10866
+ compute_type = (
10867
+ "double" if compute_dtype == ScalarType.F64 else "float"
10868
+ )
9456
10869
  max_fn = self._scalar_function_name(
9457
10870
  ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
9458
10871
  )
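Note: the QuantizeLinear, DequantizeLinear and QLinearMul branches above all generate variants of the same affine quantization arithmetic; the emitter parameterises the element types, the rounding function (nearbyintf/nearbyint) and the clamp helpers from the scalar registry rather than hard-coding them. Specialised to uint8 values and float scales, the arithmetic is:

#include <math.h>
#include <stdint.h>

/* Reference arithmetic only; the generated code substitutes the op's actual
 * types, per-axis/block scale indexing, and registry min/max helpers. */
static uint8_t quantize(float x, float scale, uint8_t zp) {
    float q = nearbyintf(x / scale) + (float)zp;        /* QuantizeLinear */
    if (q < 0.0f) q = 0.0f;
    if (q > 255.0f) q = 255.0f;                         /* clamp to the output range */
    return (uint8_t)q;
}

static float dequantize(uint8_t q, float scale, uint8_t zp) {
    return ((float)q - (float)zp) * scale;              /* DequantizeLinear */
}

static uint8_t qlinear_mul(uint8_t a, float a_scale, uint8_t a_zp,
                           uint8_t b, float b_scale, uint8_t b_zp,
                           float out_scale, uint8_t out_zp) {
    /* QLinearMul: dequantize both operands, multiply, requantize. */
    float prod = dequantize(a, a_scale, a_zp) * dequantize(b, b_scale, b_zp);
    return quantize(prod, out_scale, out_zp);
}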
@@ -9461,40 +10874,38 @@ class CEmitter:
9461
10874
  )
9462
10875
  if max_fn is None or min_fn is None:
9463
10876
  raise CodegenError(
9464
- "Failed to resolve scalar min/max functions for QuantizeLinear."
10877
+ "Failed to resolve scalar min/max functions for QLinearMul."
9465
10878
  )
9466
10879
  round_fn = CEmitter._math_fn(
9467
- op.input_dtype, "nearbyintf", "nearbyint"
9468
- )
9469
- scale_index = "0" if op.axis is None else loop_vars[op.axis]
9470
- input_expr = f"{params['input0']}" + "".join(
9471
- f"[{var}]" for var in loop_vars
9472
- )
9473
- output_expr = f"{params['output']}" + "".join(
9474
- f"[{var}]" for var in loop_vars
10880
+ compute_dtype, "nearbyintf", "nearbyint"
9475
10881
  )
9476
- scale_expr = f"{params['scale']}[{scale_index}]"
9477
- if params["zero_point"]:
9478
- zero_expr = f"{params['zero_point']}[{scale_index}]"
9479
- else:
9480
- zero_expr = "0"
9481
- rendered = quantize_linear_template.render(
10882
+ scale_index = "0"
10883
+ rendered = qlinear_mul_template.render(
9482
10884
  model_name=model.name,
9483
10885
  op_name=op_name,
9484
10886
  input0=params["input0"],
9485
- scale=params["scale"],
9486
- zero_point=params["zero_point"],
10887
+ input1=params["input1"],
10888
+ input0_scale=params["input0_scale"],
10889
+ input0_zero_point=params["input0_zero_point"],
10890
+ input1_scale=params["input1_scale"],
10891
+ input1_zero_point=params["input1_zero_point"],
10892
+ output_scale=params["output_scale"],
10893
+ output_zero_point=params["output_zero_point"],
9487
10894
  output=params["output"],
9488
10895
  params=param_decls,
9489
10896
  compute_type=compute_type,
9490
- input_c_type=op.input_dtype.c_type,
9491
10897
  output_c_type=op.dtype.c_type,
9492
- shape=shape,
9493
- loop_vars=loop_vars,
9494
- input_expr=input_expr,
9495
- scale_expr=scale_expr,
9496
- zero_expr=zero_expr,
9497
- output_expr=output_expr,
10898
+ input0_index_expr=input0_index_expr,
10899
+ input1_index_expr=input1_index_expr,
10900
+ input0_scale_expr=f"{params['input0_scale']}[{scale_index}]",
10901
+ input1_scale_expr=f"{params['input1_scale']}[{scale_index}]",
10902
+ output_scale_expr=f"{params['output_scale']}[{scale_index}]",
10903
+ input0_zero_expr=f"{params['input0_zero_point']}[{scale_index}]",
10904
+ input1_zero_expr=f"{params['input1_zero_point']}[{scale_index}]",
10905
+ output_zero_expr=f"{params['output_zero_point']}[{scale_index}]",
10906
+ output_loop_vars=output_loop_vars,
10907
+ output_loop_bounds=output_shape,
10908
+ output_index_expr=output_index_expr,
9498
10909
  round_fn=round_fn,
9499
10910
  min_literal=op.dtype.min_literal,
9500
10911
  max_literal=op.dtype.max_literal,
@@ -9504,10 +10915,6 @@ class CEmitter:
9504
10915
  ).rstrip()
9505
10916
  return with_node_comment(rendered)
9506
10917
  if isinstance(op, QLinearMatMulOp):
9507
- if scalar_registry is None:
9508
- raise CodegenError(
9509
- "Scalar function registry is required for QLinearMatMul."
9510
- )
9511
10918
  params = self._shared_param_map(
9512
10919
  [
9513
10920
  ("input0", op.input0),
@@ -9541,13 +10948,18 @@ class CEmitter:
9541
10948
  row_var = output_loop_vars[-2]
9542
10949
  col_var = output_loop_vars[-1]
9543
10950
  input0_index_expr, input1_index_expr = CEmitter._matmul_index_exprs(
9544
- op,
9545
10951
  batch_vars,
9546
10952
  row_var,
9547
10953
  col_var,
9548
10954
  batch_rank,
9549
10955
  input0=params["input0"],
9550
10956
  input1=params["input1"],
10957
+ left_vector=op.left_vector,
10958
+ right_vector=op.right_vector,
10959
+ input0_shape=op.input0_shape,
10960
+ input1_shape=op.input1_shape,
10961
+ input0_batch_shape=op.input0_batch_shape,
10962
+ input1_batch_shape=op.input1_batch_shape,
9551
10963
  )
9552
10964
  input0_suffix = self._param_array_suffix(op.input0_shape)
9553
10965
  input1_suffix = self._param_array_suffix(op.input1_shape)
@@ -9622,32 +11034,28 @@ class CEmitter:
9622
11034
  ),
9623
11035
  ]
9624
11036
  )
9625
- compute_dtype = (
9626
- ScalarType.F64
9627
- if ScalarType.F64
9628
- in {
9629
- op.input0_scale_dtype,
9630
- op.input1_scale_dtype,
9631
- op.output_scale_dtype,
9632
- }
9633
- else ScalarType.F32
9634
- )
11037
+ if ScalarType.F64 in {
11038
+ op.input0_scale_dtype,
11039
+ op.input1_scale_dtype,
11040
+ op.output_scale_dtype,
11041
+ }:
11042
+ scale_dtype = ScalarType.F64
11043
+ elif ScalarType.F32 in {
11044
+ op.input0_scale_dtype,
11045
+ op.input1_scale_dtype,
11046
+ op.output_scale_dtype,
11047
+ }:
11048
+ scale_dtype = ScalarType.F32
11049
+ else:
11050
+ scale_dtype = ScalarType.F16
11051
+ compute_dtype = ScalarType.F64
9635
11052
  compute_type = (
9636
11053
  "double" if compute_dtype == ScalarType.F64 else "float"
9637
11054
  )
9638
- max_fn = self._scalar_function_name(
9639
- ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
9640
- )
9641
- min_fn = self._scalar_function_name(
9642
- ScalarFunction.MINIMUM, compute_dtype, scalar_registry
9643
- )
9644
- if max_fn is None or min_fn is None:
9645
- raise CodegenError(
9646
- "Failed to resolve scalar min/max functions for QLinearMatMul."
9647
- )
9648
11055
  round_fn = CEmitter._math_fn(
9649
11056
  compute_dtype, "nearbyintf", "nearbyint"
9650
11057
  )
11058
+ mod_fn = CEmitter._math_fn(compute_dtype, "fmodf", "fmod")
9651
11059
  scale_index = "0"
9652
11060
  rendered = qlinear_matmul_template.render(
9653
11061
  model_name=model.name,
@@ -9662,6 +11070,8 @@ class CEmitter:
9662
11070
  output_zero_point=params["output_zero_point"],
9663
11071
  output=params["output"],
9664
11072
  params=param_decls,
11073
+ scale_type=scale_dtype.c_type,
11074
+ scale_is_float16=scale_dtype == ScalarType.F16,
9665
11075
  compute_type=compute_type,
9666
11076
  output_c_type=op.dtype.c_type,
9667
11077
  input0_index_expr=input0_index_expr,
@@ -9677,10 +11087,8 @@ class CEmitter:
9677
11087
  output_index_expr=output_index_expr,
9678
11088
  k=op.k,
9679
11089
  round_fn=round_fn,
9680
- min_literal=op.dtype.min_literal,
9681
- max_literal=op.dtype.max_literal,
9682
- min_fn=min_fn,
9683
- max_fn=max_fn,
11090
+ mod_fn=mod_fn,
11091
+ output_is_signed=op.dtype.is_signed,
9684
11092
  dim_args=dim_args,
9685
11093
  ).rstrip()
9686
11094
  return with_node_comment(rendered)
@@ -9740,7 +11148,11 @@ class CEmitter:
9740
11148
  loop_vars,
9741
11149
  )
9742
11150
  if op.input_min is not None
9743
- else output_dtype.min_literal
11151
+ else (
11152
+ CEmitter._format_literal(output_dtype, op.min_value)
11153
+ if op.min_value is not None
11154
+ else output_dtype.min_literal
11155
+ )
9744
11156
  )
9745
11157
  max_expr = (
9746
11158
  CEmitter._broadcast_index_expr(
@@ -9750,7 +11162,11 @@ class CEmitter:
9750
11162
  loop_vars,
9751
11163
  )
9752
11164
  if op.input_max is not None
9753
- else output_dtype.max_literal
11165
+ else (
11166
+ CEmitter._format_literal(output_dtype, op.max_value)
11167
+ if op.max_value is not None
11168
+ else output_dtype.max_literal
11169
+ )
9754
11170
  )
9755
11171
  input_suffix = self._param_array_suffix(
9756
11172
  input_shape, _dim_names_for(op.input0)
@@ -9896,11 +11312,14 @@ class CEmitter:
9896
11312
  | ClipOp
9897
11313
  | CastOp
9898
11314
  | QuantizeLinearOp
11315
+ | QLinearMulOp
11316
+ | QLinearMatMulOp
9899
11317
  | MatMulOp
9900
11318
  | EinsumOp
9901
11319
  | GemmOp
9902
11320
  | AttentionOp
9903
11321
  | ConvOp
11322
+ | ConvIntegerOp
9904
11323
  | ConvTransposeOp
9905
11324
  | AveragePoolOp
9906
11325
  | LpPoolOp
@@ -9912,6 +11331,7 @@ class CEmitter:
9912
11331
  | MeanVarianceNormalizationOp
9913
11332
  | RMSNormalizationOp
9914
11333
  | LrnOp
11334
+ | GruOp
9915
11335
  | LstmOp
9916
11336
  | SoftmaxOp
9917
11337
  | LogSoftmaxOp
@@ -9942,9 +11362,11 @@ class CEmitter:
9942
11362
  | ConstantOfShapeOp
9943
11363
  | ShapeOp
9944
11364
  | SizeOp
11365
+ | OptionalHasElementOp
9945
11366
  | ExpandOp
9946
11367
  | CumSumOp
9947
11368
  | RangeOp
11369
+ | HammingWindowOp
9948
11370
  | OneHotOp
9949
11371
  | SplitOp,
9950
11372
  ) -> str:
@@ -9963,11 +11385,13 @@ class CEmitter:
9963
11385
  | ClipOp
9964
11386
  | CastOp
9965
11387
  | QuantizeLinearOp
11388
+ | DequantizeLinearOp
9966
11389
  | MatMulOp
9967
11390
  | EinsumOp
9968
11391
  | GemmOp
9969
11392
  | AttentionOp
9970
11393
  | ConvOp
11394
+ | ConvIntegerOp
9971
11395
  | ConvTransposeOp
9972
11396
  | AveragePoolOp
9973
11397
  | LpPoolOp
@@ -9979,6 +11403,7 @@ class CEmitter:
9979
11403
  | MeanVarianceNormalizationOp
9980
11404
  | RMSNormalizationOp
9981
11405
  | LrnOp
11406
+ | GruOp
9982
11407
  | LstmOp
9983
11408
  | SoftmaxOp
9984
11409
  | LogSoftmaxOp
@@ -10009,9 +11434,11 @@ class CEmitter:
10009
11434
  | ConstantOfShapeOp
10010
11435
  | ShapeOp
10011
11436
  | SizeOp
11437
+ | OptionalHasElementOp
10012
11438
  | ExpandOp
10013
11439
  | CumSumOp
10014
11440
  | RangeOp
11441
+ | HammingWindowOp
10015
11442
  | OneHotOp
10016
11443
  | SplitOp,
10017
11444
  ) -> tuple[tuple[str, tuple[int, ...]], ...]:
@@ -10069,6 +11496,8 @@ class CEmitter:
10069
11496
  return ((op.input0, self._ctx_shape(op.input0)),)
10070
11497
  if isinstance(op, NonZeroOp):
10071
11498
  return ((op.input0, op.input_shape),)
11499
+ if isinstance(op, OptionalHasElementOp):
11500
+ return ((op.input0, self._ctx_shape(op.input0)),)
10072
11501
  if isinstance(op, NonMaxSuppressionOp):
10073
11502
  inputs = [
10074
11503
  (op.boxes, op.boxes_shape),
@@ -10104,6 +11533,20 @@ class CEmitter:
10104
11533
  if op.zero_point is not None:
10105
11534
  inputs.append((op.zero_point, scale_shape))
10106
11535
  return tuple(inputs)
11536
+ if isinstance(op, DequantizeLinearOp):
11537
+ if op.axis is None:
11538
+ scale_shape = ()
11539
+ elif op.block_size:
11540
+ input_shape = self._ctx_shape(op.input0)
11541
+ scale_shape_list = list(input_shape)
11542
+ scale_shape_list[op.axis] = input_shape[op.axis] // op.block_size
11543
+ scale_shape = tuple(scale_shape_list)
11544
+ else:
11545
+ scale_shape = (self._ctx_shape(op.input0)[op.axis],)
11546
+ inputs = [(op.input0, self._ctx_shape(op.input0)), (op.scale, scale_shape)]
11547
+ if op.zero_point is not None:
11548
+ inputs.append((op.zero_point, scale_shape))
11549
+ return tuple(inputs)
10107
11550
  if isinstance(op, IdentityOp):
10108
11551
  return ((op.input0, self._ctx_shape(op.input0)),)
10109
11552
  if isinstance(op, EyeLikeOp):
@@ -10138,6 +11581,8 @@ class CEmitter:
10138
11581
  return ((op.input0, op.input_shape),)
10139
11582
  if isinstance(op, RangeOp):
10140
11583
  return ((op.start, ()), (op.limit, ()), (op.delta, ()))
11584
+ if isinstance(op, HammingWindowOp):
11585
+ return ((op.size, ()),)
10141
11586
  if isinstance(op, OneHotOp):
10142
11587
  return (
10143
11588
  (op.indices, op.indices_shape),
@@ -10147,7 +11592,10 @@ class CEmitter:
10147
11592
  if isinstance(op, SplitOp):
10148
11593
  return ((op.input0, op.input_shape),)
10149
11594
  if isinstance(op, TopKOp):
10150
- return ((op.input0, self._ctx_shape(op.input0)),)
11595
+ return (
11596
+ (op.input0, self._ctx_shape(op.input0)),
11597
+ (op.k_input, self._ctx_shape(op.k_input)),
11598
+ )
10151
11599
  if isinstance(op, (TransposeOp, ReshapeOp, ReduceOp, ArgReduceOp)):
10152
11600
  return ((op.input0, self._ctx_shape(op.input0)),)
10153
11601
  return ()
@@ -10162,6 +11610,7 @@ class CEmitter:
10162
11610
  | ClipOp
10163
11611
  | CastOp
10164
11612
  | QuantizeLinearOp
11613
+ | DequantizeLinearOp
10165
11614
  | MatMulOp
10166
11615
  | EinsumOp
10167
11616
  | GemmOp
@@ -10178,6 +11627,7 @@ class CEmitter:
10178
11627
  | MeanVarianceNormalizationOp
10179
11628
  | RMSNormalizationOp
10180
11629
  | LrnOp
11630
+ | GruOp
10181
11631
  | LstmOp
10182
11632
  | SoftmaxOp
10183
11633
  | LogSoftmaxOp
@@ -10210,6 +11660,7 @@ class CEmitter:
10210
11660
  | NonMaxSuppressionOp
10211
11661
  | ExpandOp
10212
11662
  | RangeOp
11663
+ | HammingWindowOp
10213
11664
  | OneHotOp
10214
11665
  | SplitOp
10215
11666
  ],
@@ -10234,11 +11685,13 @@ class CEmitter:
10234
11685
  | ClipOp
10235
11686
  | CastOp
10236
11687
  | QuantizeLinearOp
11688
+ | DequantizeLinearOp
10237
11689
  | MatMulOp
10238
11690
  | EinsumOp
10239
11691
  | GemmOp
10240
11692
  | AttentionOp
10241
11693
  | ConvOp
11694
+ | ConvIntegerOp
10242
11695
  | ConvTransposeOp
10243
11696
  | AveragePoolOp
10244
11697
  | LpPoolOp
@@ -10250,6 +11703,7 @@ class CEmitter:
10250
11703
  | MeanVarianceNormalizationOp
10251
11704
  | RMSNormalizationOp
10252
11705
  | LrnOp
11706
+ | GruOp
10253
11707
  | LstmOp
10254
11708
  | SoftmaxOp
10255
11709
  | LogSoftmaxOp
@@ -10284,9 +11738,18 @@ class CEmitter:
10284
11738
  | NonMaxSuppressionOp
10285
11739
  | ExpandOp
10286
11740
  | RangeOp
11741
+ | HammingWindowOp
10287
11742
  | OneHotOp
10288
11743
  | SplitOp,
10289
11744
  ) -> tuple[tuple[str, tuple[int, ...], ScalarType], ...]:
11745
+ if isinstance(op, OptionalHasElementOp):
11746
+ return (
11747
+ (
11748
+ op.output,
11749
+ self._op_output_shape(op),
11750
+ self._op_output_dtype(op),
11751
+ ),
11752
+ )
10290
11753
  if isinstance(
10291
11754
  op,
10292
11755
  (
@@ -10341,6 +11804,39 @@ class CEmitter:
10341
11804
  )
10342
11805
  )
10343
11806
  return tuple(outputs)
11807
+ if isinstance(op, GruOp):
11808
+ outputs: list[tuple[str, tuple[int, ...], ScalarType]] = []
11809
+ if op.output_y is not None:
11810
+ if op.layout == 0:
11811
+ y_shape = (
11812
+ op.seq_length,
11813
+ op.num_directions,
11814
+ op.batch_size,
11815
+ op.hidden_size,
11816
+ )
11817
+ else:
11818
+ y_shape = (
11819
+ op.batch_size,
11820
+ op.seq_length,
11821
+ op.num_directions,
11822
+ op.hidden_size,
11823
+ )
11824
+ outputs.append((op.output_y, y_shape, op.dtype))
11825
+ if op.output_y_h is not None:
11826
+ if op.layout == 0:
11827
+ state_shape = (
11828
+ op.num_directions,
11829
+ op.batch_size,
11830
+ op.hidden_size,
11831
+ )
11832
+ else:
11833
+ state_shape = (
11834
+ op.batch_size,
11835
+ op.num_directions,
11836
+ op.hidden_size,
11837
+ )
11838
+ outputs.append((op.output_y_h, state_shape, op.dtype))
11839
+ return tuple(outputs)
10344
11840
  if isinstance(op, LstmOp):
10345
11841
  outputs: list[tuple[str, tuple[int, ...], ScalarType]] = []
10346
11842
  if op.output_y is not None:
@@ -10456,12 +11952,14 @@ class CEmitter:
10456
11952
  | ClipOp
10457
11953
  | CastOp
10458
11954
  | QuantizeLinearOp
11955
+ | DequantizeLinearOp
10459
11956
  | QLinearMatMulOp
10460
11957
  | MatMulOp
10461
11958
  | EinsumOp
10462
11959
  | GemmOp
10463
11960
  | AttentionOp
10464
11961
  | ConvOp
11962
+ | ConvIntegerOp
10465
11963
  | AveragePoolOp
10466
11964
  | BatchNormOp
10467
11965
  | LpNormalizationOp
@@ -10471,6 +11969,7 @@ class CEmitter:
10471
11969
  | MeanVarianceNormalizationOp
10472
11970
  | RMSNormalizationOp
10473
11971
  | LrnOp
11972
+ | GruOp
10474
11973
  | LstmOp
10475
11974
  | SoftmaxOp
10476
11975
  | LogSoftmaxOp
@@ -10485,6 +11984,7 @@ class CEmitter:
10485
11984
  | TransposeOp
10486
11985
  | ReshapeOp
10487
11986
  | IdentityOp
11987
+ | BernoulliOp
10488
11988
  | EyeLikeOp
10489
11989
  | TriluOp
10490
11990
  | TileOp
@@ -10502,7 +12002,10 @@ class CEmitter:
10502
12002
  | ExpandOp
10503
12003
  | CumSumOp
10504
12004
  | RangeOp
12005
+ | HammingWindowOp
10505
12006
  | OneHotOp
12007
+ | TfIdfVectorizerOp
12008
+ | RotaryEmbeddingOp
10506
12009
  | SplitOp
10507
12010
  | PadOp,
10508
12011
  ) -> tuple[int, ...]:
@@ -10518,21 +12021,29 @@ class CEmitter:
10518
12021
  return self._ctx_shape(op.output)
10519
12022
  if isinstance(op, QuantizeLinearOp):
10520
12023
  return op.input_shape
12024
+ if isinstance(op, DequantizeLinearOp):
12025
+ return op.input_shape
10521
12026
  if isinstance(op, CastOp):
10522
12027
  return self._ctx_shape(op.output)
12028
+ if isinstance(op, QLinearMulOp):
12029
+ return op.output_shape
10523
12030
  if isinstance(op, QLinearMatMulOp):
10524
12031
  return op.output_shape
10525
12032
  if isinstance(op, MatMulOp):
10526
- return op.output_shape
12033
+ return self._ctx_shape(op.output)
10527
12034
  if isinstance(op, EinsumOp):
10528
12035
  return op.output_shape
10529
12036
  if isinstance(op, GemmOp):
10530
- return (op.m, op.n)
12037
+ return self._ctx_shape(op.output)
10531
12038
  if isinstance(op, ConvOp):
10532
12039
  return (op.batch, op.out_channels, *op.out_spatial)
12040
+ if isinstance(op, ConvIntegerOp):
12041
+ return (op.batch, op.out_channels, *op.out_spatial)
10533
12042
  if isinstance(op, ConvTransposeOp):
10534
12043
  return (op.batch, op.out_channels, *op.out_spatial)
10535
12044
  if isinstance(op, AveragePoolOp):
12045
+ if op.spatial_rank == 3:
12046
+ return (op.batch, op.channels, op.out_d, op.out_h, op.out_w)
10536
12047
  return (op.batch, op.channels, op.out_h, op.out_w)
10537
12048
  if isinstance(op, LpPoolOp):
10538
12049
  return (op.batch, op.channels, op.out_h, op.out_w)
@@ -10582,6 +12093,8 @@ class CEmitter:
10582
12093
  return self._ctx_shape(op.output)
10583
12094
  if isinstance(op, IdentityOp):
10584
12095
  return self._ctx_shape(op.output)
12096
+ if isinstance(op, BernoulliOp):
12097
+ return op.output_shape
10585
12098
  if isinstance(op, EyeLikeOp):
10586
12099
  return op.output_shape
10587
12100
  if isinstance(op, TriluOp):
@@ -10612,6 +12125,8 @@ class CEmitter:
10612
12125
  return op.output_shape
10613
12126
  if isinstance(op, SizeOp):
10614
12127
  return op.output_shape
12128
+ if isinstance(op, OptionalHasElementOp):
12129
+ return self._ctx_shape(op.output)
10615
12130
  if isinstance(op, NonZeroOp):
10616
12131
  return op.output_shape
10617
12132
  if isinstance(op, NonMaxSuppressionOp):
@@ -10622,8 +12137,14 @@ class CEmitter:
10622
12137
  return op.input_shape
10623
12138
  if isinstance(op, RangeOp):
10624
12139
  return op.output_shape
12140
+ if isinstance(op, HammingWindowOp):
12141
+ return op.output_shape
10625
12142
  if isinstance(op, OneHotOp):
10626
12143
  return op.output_shape
12144
+ if isinstance(op, TfIdfVectorizerOp):
12145
+ return op.output_shape
12146
+ if isinstance(op, RotaryEmbeddingOp):
12147
+ return op.input_shape
10627
12148
  if op.output_rank == 3:
10628
12149
  return (op.batch, op.q_seq, op.q_heads * op.v_head_size)
10629
12150
  return (op.batch, op.q_heads, op.q_seq, op.v_head_size)
@@ -10637,11 +12158,13 @@ class CEmitter:
10637
12158
  | ClipOp
10638
12159
  | CastOp
10639
12160
  | QuantizeLinearOp
12161
+ | DequantizeLinearOp
10640
12162
  | MatMulOp
10641
12163
  | EinsumOp
10642
12164
  | GemmOp
10643
12165
  | AttentionOp
10644
12166
  | ConvOp
12167
+ | ConvIntegerOp
10645
12168
  | ConvTransposeOp
10646
12169
  | AveragePoolOp
10647
12170
  | LpPoolOp
@@ -10666,6 +12189,7 @@ class CEmitter:
10666
12189
  | TransposeOp
10667
12190
  | ReshapeOp
10668
12191
  | IdentityOp
12192
+ | BernoulliOp
10669
12193
  | EyeLikeOp
10670
12194
  | TriluOp
10671
12195
  | TileOp
@@ -10681,7 +12205,9 @@ class CEmitter:
10681
12205
  | ExpandOp
10682
12206
  | CumSumOp
10683
12207
  | RangeOp
12208
+ | HammingWindowOp
10684
12209
  | OneHotOp
12210
+ | TfIdfVectorizerOp
10685
12211
  | SplitOp
10686
12212
  | PadOp,
10687
12213
  ) -> ScalarType:
@@ -10689,8 +12215,12 @@ class CEmitter:
10689
12215
  return self._ctx_dtype(op.output)
10690
12216
  if isinstance(op, TopKOp):
10691
12217
  return self._ctx_dtype(op.output_values)
12218
+ if isinstance(op, OptionalHasElementOp):
12219
+ return self._ctx_dtype(op.output)
10692
12220
  if isinstance(op, NonMaxSuppressionOp):
10693
12221
  return op.output_dtype
12222
+ if isinstance(op, TfIdfVectorizerOp):
12223
+ return op.output_dtype
10694
12224
  if isinstance(
10695
12225
  op,
10696
12226
  (
@@ -10703,6 +12233,8 @@ class CEmitter:
10703
12233
  SoftmaxOp,
10704
12234
  LogSoftmaxOp,
10705
12235
  HardmaxOp,
12236
+ MatMulOp,
12237
+ GemmOp,
10706
12238
  GatherOp,
10707
12239
  TransposeOp,
10708
12240
  ReshapeOp,
@@ -10729,10 +12261,12 @@ class CEmitter:
10729
12261
  self,
10730
12262
  shape: tuple[int, ...],
10731
12263
  dim_names: Mapping[int, str] | None = None,
12264
+ *,
12265
+ use_restrict: bool = False,
10732
12266
  ) -> str:
10733
12267
  shape = CEmitter._codegen_shape(shape)
10734
12268
  dim_names = dim_names or {}
10735
- if not self._restrict_arrays:
12269
+ if not (self._restrict_arrays and use_restrict):
10736
12270
  return "".join(
10737
12271
  f"[{dim_names.get(index, dim)}]"
10738
12272
  for index, dim in enumerate(shape)
@@ -10755,6 +12289,16 @@ class CEmitter:
10755
12289
  return ""
10756
12290
  return ", ".join(f"int {dim_name}" for dim_name in dim_order) + ", "
10757
12291
 
12292
+ @staticmethod
12293
+ def _optional_input_flag_map(model: LoweredModel) -> dict[str, str]:
12294
+ return {
12295
+ name: flag
12296
+ for name, flag in zip(
12297
+ model.input_names, model.input_optional_names
12298
+ )
12299
+ if flag is not None
12300
+ }
12301
+
10758
12302
  def _build_variable_dim_names(
10759
12303
  self,
10760
12304
  model: LoweredModel,
@@ -10772,6 +12316,12 @@ class CEmitter:
10772
12316
  dim_vars: dict[tuple[str, int, int], str] = {}
10773
12317
  dim_values: dict[str, int] = {}
10774
12318
  reserved_names = set(model.input_names) | set(model.output_names)
12319
+ reserved_names.update(
12320
+ name for name in model.input_optional_names if name is not None
12321
+ )
12322
+ reserved_names.update(
12323
+ name for name in model.output_optional_names if name is not None
12324
+ )
10775
12325
  used_names = set(reserved_names)
10776
12326
  dim_aliases: dict[str, str] = {}
10777
12327
 
@@ -10926,14 +12476,19 @@ class CEmitter:
10926
12476
 
10927
12477
  @staticmethod
10928
12478
  def _matmul_index_exprs(
10929
- op: MatMulOp,
10930
12479
  batch_vars: tuple[str, ...],
10931
12480
  row_var: str | None,
10932
12481
  col_var: str | None,
10933
12482
  batch_rank: int,
10934
12483
  *,
10935
- input0: str | None = None,
10936
- input1: str | None = None,
12484
+ input0: str,
12485
+ input1: str,
12486
+ left_vector: bool,
12487
+ right_vector: bool,
12488
+ input0_shape: tuple[int, ...],
12489
+ input1_shape: tuple[int, ...],
12490
+ input0_batch_shape: tuple[int, ...],
12491
+ input1_batch_shape: tuple[int, ...],
10937
12492
  ) -> tuple[str, str]:
10938
12493
  def batch_indices(
10939
12494
  batch_shape: tuple[int, ...], actual_rank: int
@@ -10948,28 +12503,28 @@ class CEmitter:
10948
12503
  indices.append("0" if dim == 1 else var)
10949
12504
  return indices
10950
12505
 
10951
- if op.left_vector:
12506
+ if left_vector:
10952
12507
  input0_indices = ["k"]
10953
12508
  else:
10954
- input0_batch_rank = len(op.input0_shape) - 2
12509
+ input0_batch_rank = len(input0_shape) - 2
10955
12510
  input0_indices = batch_indices(
10956
- op.input0_batch_shape, input0_batch_rank
12511
+ input0_batch_shape, input0_batch_rank
10957
12512
  )
10958
12513
  input0_indices.append(row_var if row_var is not None else "0")
10959
12514
  input0_indices.append("k")
10960
- if op.right_vector:
12515
+ if right_vector:
10961
12516
  input1_indices = ["k"]
10962
12517
  else:
10963
- input1_batch_rank = len(op.input1_shape) - 2
12518
+ input1_batch_rank = len(input1_shape) - 2
10964
12519
  input1_indices = batch_indices(
10965
- op.input1_batch_shape, input1_batch_rank
12520
+ input1_batch_shape, input1_batch_rank
10966
12521
  )
10967
12522
  input1_indices.append("k")
10968
12523
  input1_indices.append(col_var if col_var is not None else "0")
10969
- input0_index_expr = f"{input0 or op.input0}" + "".join(
12524
+ input0_index_expr = f"{input0}" + "".join(
10970
12525
  f"[{index}]" for index in input0_indices
10971
12526
  )
10972
- input1_index_expr = f"{input1 or op.input1}" + "".join(
12527
+ input1_index_expr = f"{input1}" + "".join(
10973
12528
  f"[{index}]" for index in input1_indices
10974
12529
  )
10975
12530
  return input0_index_expr, input1_index_expr
@@ -10980,6 +12535,7 @@ class CEmitter:
10980
12535
  testbench_template,
10981
12536
  *,
10982
12537
  testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None = None,
12538
+ testbench_optional_inputs: Mapping[str, bool] | None = None,
10983
12539
  dim_order: Sequence[str],
10984
12540
  dim_values: Mapping[str, int],
10985
12541
  weight_data_filename: str,
@@ -10988,13 +12544,34 @@ class CEmitter:
10988
12544
  self._element_count(shape) for shape in model.input_shapes
10989
12545
  )
10990
12546
  testbench_inputs = testbench_inputs or {}
12547
+ testbench_optional_inputs = testbench_optional_inputs or {}
12548
+ rng_requires_u64 = False
12549
+ rng_requires_float = False
12550
+ rng_requires_double = False
12551
+ rng_requires_i64 = False
10991
12552
  inputs = []
10992
- for name, shape, count, dtype in zip(
10993
- model.input_names, model.input_shapes, input_counts, model.input_dtypes
12553
+ for name, shape, count, dtype, optional_flag in zip(
12554
+ model.input_names,
12555
+ model.input_shapes,
12556
+ input_counts,
12557
+ model.input_dtypes,
12558
+ model.input_optional_names,
10994
12559
  ):
12560
+ json_name = self._ctx_name(name)
10995
12561
  codegen_shape = self._codegen_shape(shape)
10996
12562
  loop_shape = (1,) if not shape else shape
10997
12563
  loop_vars = self._loop_vars(loop_shape)
12564
+ constant_values = testbench_inputs.get(name)
12565
+ if constant_values is None:
12566
+ rng_requires_u64 = True
12567
+ if dtype in {ScalarType.F16, ScalarType.F32}:
12568
+ rng_requires_float = True
12569
+ elif dtype == ScalarType.F64:
12570
+ rng_requires_double = True
12571
+ elif dtype == ScalarType.BOOL:
12572
+ pass
12573
+ else:
12574
+ rng_requires_i64 = True
10998
12575
  if dtype in {ScalarType.F16, ScalarType.F32}:
10999
12576
  random_expr = "rng_next_float()"
11000
12577
  elif dtype == ScalarType.F64:
@@ -11003,7 +12580,6 @@ class CEmitter:
11003
12580
  random_expr = "((rng_next_u64() & 1ull) != 0)"
11004
12581
  else:
11005
12582
  random_expr = f"({dtype.c_type})rng_next_i64()"
11006
- constant_values = testbench_inputs.get(name)
11007
12583
  constant_name = None
11008
12584
  constant_lines = None
11009
12585
  if constant_values is not None:
@@ -11015,6 +12591,11 @@ class CEmitter:
11015
12591
  ]
11016
12592
  else:
11017
12593
  constant_lines = [self._format_value(0, dtype)]
12594
+ optional_present = (
12595
+ testbench_optional_inputs.get(name, True)
12596
+ if optional_flag is not None
12597
+ else None
12598
+ )
11018
12599
  inputs.append(
11019
12600
  {
11020
12601
  "name": name,
@@ -11035,12 +12616,16 @@ class CEmitter:
11035
12616
  "print_cast": self._print_cast(dtype),
11036
12617
  "constant_name": constant_name,
11037
12618
  "constant_lines": constant_lines,
12619
+ "json_name": json_name,
12620
+ "optional_flag_name": optional_flag,
12621
+ "optional_present": optional_present,
11038
12622
  }
11039
12623
  )
11040
12624
  outputs = []
11041
12625
  for name, shape, dtype in zip(
11042
12626
  model.output_names, model.output_shapes, model.output_dtypes
11043
12627
  ):
12628
+ json_name = self._ctx_name(name)
11044
12629
  codegen_shape = self._codegen_shape(shape)
11045
12630
  loop_shape = (1,) if not shape else shape
11046
12631
  output_loop_vars = self._loop_vars(loop_shape)
@@ -11061,10 +12646,15 @@ class CEmitter:
  "c_type": dtype.c_type,
  "print_format": self._print_format(dtype),
  "print_cast": self._print_cast(dtype),
+ "json_name": json_name,
  }
  )
  rendered = testbench_template.render(
  model_name=model.name,
+ rng_requires_u64=rng_requires_u64,
+ rng_requires_float=rng_requires_float,
+ rng_requires_double=rng_requires_double,
+ rng_requires_i64=rng_requires_i64,
  dim_args=[
  {"name": dim_name, "value": dim_values[dim_name]}
  for dim_name in dim_order
@@ -11097,13 +12687,30 @@ class CEmitter:
  ) -> tuple[tuple[ConstTensor, ...], tuple[ConstTensor, ...]]:
  if self._large_weight_threshold <= 0:
  return constants, ()
+ sorted_constants = sorted(
+ enumerate(constants),
+ key=lambda item: (
+ self._element_count(item[1].shape)
+ * item[1].dtype.np_dtype.itemsize,
+ item[0],
+ ),
+ )
+ inline_set: set[ConstTensor] = set()
+ total_bytes = 0
+ for _, const in sorted_constants:
+ const_bytes = (
+ self._element_count(const.shape) * const.dtype.np_dtype.itemsize
+ )
+ if total_bytes + const_bytes <= self._large_weight_threshold:
+ inline_set.add(const)
+ total_bytes += const_bytes
  inline: list[ConstTensor] = []
  large: list[ConstTensor] = []
  for const in constants:
- if self._element_count(const.shape) > self._large_weight_threshold:
- large.append(const)
- else:
+ if const in inline_set:
  inline.append(const)
+ else:
+ large.append(const)
  return tuple(inline), tuple(large)

  @staticmethod
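Note: the partitioning above replaces the old per-tensor element-count cutoff with a byte budget: constants are ranked by size in bytes (ties broken by original position) and inlined greedily until the budget is exhausted, so _large_weight_threshold now bounds the total inline footprint rather than any single tensor. A self-contained sketch of that policy, with a hypothetical lightweight tensor record standing in for ConstTensor:

    import math
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Tensor:
        """Hypothetical constant tensor: name, shape, bytes per element."""
        name: str
        shape: tuple[int, ...]
        itemsize: int

        @property
        def nbytes(self) -> int:
            return math.prod(self.shape) * self.itemsize

    def partition(constants, budget):
        """Greedily inline the smallest constants until `budget` bytes are used."""
        if budget <= 0:
            return list(constants), []   # feature disabled: everything stays inline
        ranked = sorted(enumerate(constants), key=lambda item: (item[1].nbytes, item[0]))
        chosen, total = set(), 0
        for _, const in ranked:
            if total + const.nbytes <= budget:
                chosen.add(const.name)
                total += const.nbytes
        inline = [c for c in constants if c.name in chosen]
        large = [c for c in constants if c.name not in chosen]
        return inline, large

    weights = [Tensor("w0", (1024, 1024), 4), Tensor("b0", (1024,), 4), Tensor("b1", (10,), 4)]
    inline, large = partition(weights, budget=8192)
    print([t.name for t in inline], [t.name for t in large])   # ['b0', 'b1'] ['w0']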
@@ -11113,12 +12720,16 @@ class CEmitter:
  def _emit_weight_loader(
  self, model: LoweredModel, large_constants: tuple[ConstTensor, ...]
  ) -> str:
- lines = [f"_Bool {model.name}_load(const char *path) {{"]
+ lines = []
  if not large_constants:
+ lines.append(f"_Bool {model.name}_load(const char *path) {{")
  lines.append(" (void)path;")
  lines.append(" return 1;")
  lines.append("}")
  return _format_c_indentation("\n".join(lines))
+ lines.append(f"static _Bool {model.name}_load_file(FILE *file);")
+ lines.append("")
+ lines.append(f"_Bool {model.name}_load(const char *path) {{")
  lines.append(" FILE *file = fopen(path, \"rb\");")
  lines.append(" if (!file) {")
  lines.append(" return 0;")
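Note: when large constants exist, the loader is now split into a public <model>_load wrapper and an internal <model>_load_file(FILE *) helper that is forward-declared before the wrapper. Reconstructing just the prefix visible in this hunk for a hypothetically named model (the rest of the emitted function is elided here):

    model_name = "net"  # hypothetical model name
    lines = [
        f"static _Bool {model_name}_load_file(FILE *file);",   # forward declaration
        "",
        f"_Bool {model_name}_load(const char *path) {{",
        " FILE *file = fopen(path, \"rb\");",
        " if (!file) {",
        " return 0;",
    ]
    print("\n".join(lines))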
@@ -11171,7 +12782,7 @@ class CEmitter:
  for value in const.data
  ]
  lines.append(
- f"{storage_prefix} {c_type} {const.name}{array_suffix} = {{"
+ f"{storage_prefix} EMX_UNUSED {c_type} {const.name}{array_suffix} = {{"
  )
  if values:
  if (
@@ -11199,12 +12810,23 @@ class CEmitter:
  return ""
  lines = []
  for index, const in enumerate(constants, start=1):
- lines.append(self._emit_constant_comment(const, index))
  c_type = const.dtype.c_type
  array_suffix = self._array_suffix(const.shape)
  lines.append(f"extern const {c_type} {const.name}{array_suffix};")
  return "\n".join(lines)

+ def _emit_constant_storage_declarations(
+ self, constants: tuple[ConstTensor, ...]
+ ) -> str:
+ if not constants:
+ return ""
+ lines = []
+ for index, const in enumerate(constants, start=1):
+ c_type = const.dtype.c_type
+ array_suffix = self._array_suffix(const.shape)
+ lines.append(f"extern {c_type} {const.name}{array_suffix};")
+ return "\n".join(lines)
+
  def _emit_constant_storage_definitions(
  self,
  constants: tuple[ConstTensor, ...],
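Note: the new _emit_constant_storage_declarations mirrors the existing const declaration helper but drops both the per-constant comment and the const qualifier, presumably so the weight loader can fill file-backed arrays at runtime. For a hypothetical weight the two helpers would emit:

    # c_type "float", name "net_w0", array suffix "[16]" (all hypothetical):
    read_only = "extern const float net_w0[16];"   # _emit_constant_declarations
    writable  = "extern float net_w0[16];"         # _emit_constant_storage_declarations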
@@ -11214,11 +12836,12 @@ class CEmitter:
  if not constants:
  return ""
  lines: list[str] = []
+ prefix = f"{storage_prefix} " if storage_prefix else ""
  for index, const in enumerate(constants, start=1):
  lines.append(self._emit_constant_comment(const, index))
  c_type = const.dtype.c_type
  array_suffix = self._array_suffix(const.shape)
- lines.append(f"{storage_prefix} {c_type} {const.name}{array_suffix};")
+ lines.append(f"{prefix}{c_type} {const.name}{array_suffix};")
  lines.append("")
  if lines and not lines[-1]:
  lines.pop()
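Note: computing prefix once means an empty storage_prefix no longer leaves a stray leading space in the emitted definition, while a non-empty prefix such as "static" still gets its separating space. A tiny illustration with assumed values:

    def storage_line(storage_prefix: str, c_type: str, name: str, suffix: str) -> str:
        """Build one storage-definition line; empty prefixes add no leading space."""
        prefix = f"{storage_prefix} " if storage_prefix else ""
        return f"{prefix}{c_type} {name}{suffix};"

    print(storage_line("static", "float", "net_w0", "[16]"))  # static float net_w0[16];
    print(storage_line("", "float", "net_w0", "[16]"))        # float net_w0[16];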