emx-onnx-cgen 0.3.0-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic.

Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,6 @@
  from __future__ import annotations

  from dataclasses import dataclass
- from enum import Enum
  import itertools
  import math
  from math import prod
@@ -20,6 +19,85 @@ from ..ops (
  binary_op_symbol,
  unary_op_symbol,
  )
+ from ..ir.op_base import (
+ BroadcastingOpBase,
+ ConvLikeOpBase,
+ ElementwiseOpBase,
+ GemmLikeOpBase,
+ MatMulLikeOpBase,
+ ReduceOpBase,
+ RenderableOpBase,
+ OpBase,
+ EmitContext,
+ )
+ from ..ir.op_context import OpContext
+ from ..ir.ops import (
+ AdagradOp,
+ ArgReduceOp,
+ AttentionOp,
+ AveragePoolOp,
+ BatchNormOp,
+ BinaryOp,
+ CastOp,
+ ClipOp,
+ ConcatOp,
+ ConstantOfShapeOp,
+ ConvOp,
+ ConvTransposeOp,
+ CumSumOp,
+ DepthToSpaceOp,
+ EinsumKind,
+ EinsumOp,
+ ExpandOp,
+ EyeLikeOp,
+ GatherElementsOp,
+ GatherNDOp,
+ GatherOp,
+ GemmOp,
+ GridSampleOp,
+ GroupNormalizationOp,
+ HardmaxOp,
+ IdentityOp,
+ InstanceNormalizationOp,
+ LayerNormalizationOp,
+ LogSoftmaxOp,
+ LpNormalizationOp,
+ LpPoolOp,
+ LrnOp,
+ LstmOp,
+ MatMulOp,
+ MaxPoolOp,
+ MeanVarianceNormalizationOp,
+ MultiInputBinaryOp,
+ NegativeLogLikelihoodLossOp,
+ NonMaxSuppressionOp,
+ NonZeroOp,
+ OneHotOp,
+ PadOp,
+ QuantizeLinearOp,
+ QLinearMatMulOp,
+ RangeOp,
+ ReduceOp,
+ ReshapeOp,
+ ResizeOp,
+ RMSNormalizationOp,
+ RotaryEmbeddingOp,
+ ScatterNDOp,
+ ShapeOp,
+ SizeOp,
+ SliceOp,
+ SoftmaxCrossEntropyLossOp,
+ SoftmaxOp,
+ SpaceToDepthOp,
+ SplitOp,
+ TensorScatterOp,
+ TileOp,
+ TopKOp,
+ TransposeOp,
+ TriluOp,
+ UnaryOp,
+ WhereOp,
+ )
  from shared.scalar_functions import (
  ScalarFunction,
  ScalarFunctionKey,
@@ -150,44 +228,6 @@ _C_KEYWORDS = {
  "while",
  }

- @dataclass(frozen=True)
- class BinaryOp:
- input0: str
- input1: str
- output: str
- function: ScalarFunction
- operator_kind: OperatorKind
- input0_shape: tuple[int, ...]
- input1_shape: tuple[int, ...]
- shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class MultiInputBinaryOp:
- inputs: tuple[str, ...]
- output: str
- function: ScalarFunction
- operator_kind: OperatorKind
- shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class WhereOp:
- condition: str
- input_x: str
- input_y: str
- output: str
- condition_shape: tuple[int, ...]
- x_shape: tuple[int, ...]
- y_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
-
-
  @dataclass(frozen=True)
  class NodeInfo:
  op_type: str
@@ -197,905 +237,6 @@ class NodeInfo:
  attrs: dict[str, object]


- @dataclass(frozen=True)
- class UnaryOp:
- input0: str
- output: str
- function: ScalarFunction
- shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
- params: tuple[float, ...] = ()
-
-
- @dataclass(frozen=True)
- class ClipOp:
- input0: str
- input_min: str | None
- input_max: str | None
- output: str
- input_shape: tuple[int, ...]
- min_shape: tuple[int, ...] | None
- max_shape: tuple[int, ...] | None
- output_shape: tuple[int, ...]
- dtype: ScalarType
-
-
-
-
- @dataclass(frozen=True)
- class CastOp:
- input0: str
- output: str
- shape: tuple[int, ...]
- input_dtype: ScalarType
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class MatMulOp:
- input0: str
- input1: str
- output: str
- input0_shape: tuple[int, ...]
- input1_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- batch_shape: tuple[int, ...]
- input0_batch_shape: tuple[int, ...]
- input1_batch_shape: tuple[int, ...]
- m: int
- n: int
- k: int
- left_vector: bool
- right_vector: bool
- dtype: ScalarType
-
-
- class EinsumKind(str, Enum):
- REDUCE_ALL = "reduce_all"
- SUM_J = "sum_j"
- TRANSPOSE = "transpose"
- DOT = "dot"
- BATCH_MATMUL = "batch_matmul"
- BATCH_DIAGONAL = "batch_diagonal"
-
-
- @dataclass(frozen=True)
- class EinsumOp:
- inputs: tuple[str, ...]
- output: str
- kind: EinsumKind
- input_shapes: tuple[tuple[int, ...], ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class GemmOp:
- input_a: str
- input_b: str
- input_c: str | None
- output: str
- m: int
- n: int
- k: int
- trans_a: bool
- trans_b: bool
- alpha: float | int
- beta: float | int
- c_shape: tuple[int, ...] | None
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class AttentionOp:
- input_q: str
- input_k: str
- input_v: str
- input_attn_mask: str | None
- input_past_key: str | None
- input_past_value: str | None
- input_nonpad_kv_seqlen: str | None
- output: str
- output_present_key: str | None
- output_present_value: str | None
- output_qk_matmul: str | None
- batch: int
- q_heads: int
- kv_heads: int
- q_seq: int
- kv_seq: int
- total_seq: int
- past_seq: int
- qk_head_size: int
- v_head_size: int
- q_hidden_size: int | None
- k_hidden_size: int | None
- v_hidden_size: int | None
- scale: float
- is_causal: bool
- softcap: float
- qk_matmul_output_mode: int
- q_rank: int
- k_rank: int
- v_rank: int
- output_rank: int
- mask_shape: tuple[int, ...] | None
- mask_is_bool: bool
- mask_rank: int | None
- mask_broadcast_batch: bool
- mask_broadcast_heads: bool
- mask_broadcast_q_seq: bool
- mask_q_seq: int | None
- mask_kv_seq: int | None
- head_group_size: int
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ConvOp:
- input0: str
- weights: str
- bias: str | None
- output: str
- batch: int
- in_channels: int
- out_channels: int
- spatial_rank: int
- in_spatial: tuple[int, ...]
- out_spatial: tuple[int, ...]
- kernel_shape: tuple[int, ...]
- strides: tuple[int, ...]
- pads: tuple[int, ...]
- dilations: tuple[int, ...]
- group: int
- dtype: ScalarType
-
- @property
- def out_h(self) -> int:
- if self.spatial_rank < 1:
- raise ValueError("Conv output height is undefined for spatial_rank < 1")
- return self.out_spatial[0]
-
- @property
- def out_w(self) -> int:
- if self.spatial_rank < 2:
- raise ValueError("Conv output width is undefined for spatial_rank < 2")
- return self.out_spatial[1]
-
-
- @dataclass(frozen=True)
- class ConvTransposeOp:
- input0: str
- weights: str
- bias: str | None
- output: str
- batch: int
- in_channels: int
- out_channels: int
- spatial_rank: int
- in_spatial: tuple[int, ...]
- out_spatial: tuple[int, ...]
- kernel_shape: tuple[int, ...]
- strides: tuple[int, ...]
- pads: tuple[int, ...]
- dilations: tuple[int, ...]
- output_padding: tuple[int, ...]
- group: int
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class AveragePoolOp:
- input0: str
- output: str
- batch: int
- channels: int
- in_h: int
- in_w: int
- out_h: int
- out_w: int
- kernel_h: int
- kernel_w: int
- stride_h: int
- stride_w: int
- pad_top: int
- pad_left: int
- pad_bottom: int
- pad_right: int
- count_include_pad: bool
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class LpPoolOp:
- input0: str
- output: str
- batch: int
- channels: int
- in_h: int
- in_w: int
- out_h: int
- out_w: int
- kernel_h: int
- kernel_w: int
- stride_h: int
- stride_w: int
- pad_top: int
- pad_left: int
- pad_bottom: int
- pad_right: int
- p: int
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class QuantizeLinearOp:
- input0: str
- scale: str
- zero_point: str | None
- output: str
- input_shape: tuple[int, ...]
- axis: int | None
- dtype: ScalarType
- input_dtype: ScalarType
- scale_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class SoftmaxOp:
- input0: str
- output: str
- outer: int
- axis_size: int
- inner: int
- axis: int
- shape: tuple[int, ...]
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class LogSoftmaxOp:
- input0: str
- output: str
- outer: int
- axis_size: int
- inner: int
- axis: int
- shape: tuple[int, ...]
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class HardmaxOp:
- input0: str
- output: str
- outer: int
- axis_size: int
- inner: int
- axis: int
- shape: tuple[int, ...]
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class NegativeLogLikelihoodLossOp:
- input0: str
- target: str
- weight: str | None
- output: str
- input_shape: tuple[int, ...]
- target_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- n: int
- c: int
- d: int
- reduction: str
- ignore_index: int
- input_dtype: ScalarType
- weight_dtype: ScalarType | None
- weight_shape: tuple[int, ...] | None
- dtype: ScalarType
- target_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class SoftmaxCrossEntropyLossOp:
- input0: str
- target: str
- weight: str | None
- output: str
- log_prob: str | None
- input_shape: tuple[int, ...]
- target_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- log_prob_shape: tuple[int, ...] | None
- n: int
- c: int
- d: int
- reduction: str
- ignore_index: int | None
- input_dtype: ScalarType
- weight_dtype: ScalarType | None
- weight_shape: tuple[int, ...] | None
- dtype: ScalarType
- target_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class BatchNormOp:
- input0: str
- scale: str
- bias: str
- mean: str
- variance: str
- output: str
- shape: tuple[int, ...]
- channels: int
- epsilon: float
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class LpNormalizationOp:
- input0: str
- output: str
- shape: tuple[int, ...]
- axis: int
- p: int
- outer: int
- axis_size: int
- inner: int
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class InstanceNormalizationOp:
- input0: str
- scale: str
- bias: str
- output: str
- shape: tuple[int, ...]
- channels: int
- spatial_size: int
- epsilon: float
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class GroupNormalizationOp:
- input0: str
- scale: str
- bias: str
- output: str
- shape: tuple[int, ...]
- channels: int
- num_groups: int
- group_size: int
- spatial_size: int
- epsilon: float
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class LayerNormalizationOp:
- input0: str
- scale: str
- bias: str | None
- output: str
- mean_output: str | None
- invstd_output: str | None
- shape: tuple[int, ...]
- normalized_shape: tuple[int, ...]
- scale_shape: tuple[int, ...]
- bias_shape: tuple[int, ...] | None
- outer: int
- inner: int
- axis: int
- epsilon: float
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class MeanVarianceNormalizationOp:
- input0: str
- output: str
- shape: tuple[int, ...]
- axes: tuple[int, ...]
- non_axes: tuple[int, ...]
- reduce_count: int
- epsilon: float
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class RMSNormalizationOp:
- input0: str
- scale: str
- output: str
- shape: tuple[int, ...]
- normalized_shape: tuple[int, ...]
- scale_shape: tuple[int, ...]
- outer: int
- inner: int
- axis: int
- epsilon: float
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class LrnOp:
- input0: str
- output: str
- shape: tuple[int, ...]
- channels: int
- size: int
- half: int
- alpha: float
- beta: float
- bias: float
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class LstmOp:
- input_x: str
- input_w: str
- input_r: str
- input_b: str | None
- input_sequence_lens: str | None
- input_initial_h: str | None
- input_initial_c: str | None
- input_p: str | None
- output_y: str | None
- output_y_h: str | None
- output_y_c: str | None
- seq_length: int
- batch_size: int
- input_size: int
- hidden_size: int
- num_directions: int
- direction: str
- layout: int
- input_forget: int
- clip: float | None
- activation_kinds: tuple[int, ...]
- activation_alphas: tuple[float, ...]
- activation_betas: tuple[float, ...]
- dtype: ScalarType
- sequence_lens_dtype: ScalarType | None
-
-
- @dataclass(frozen=True)
- class MaxPoolOp:
- input0: str
- output: str
- indices: str | None
- batch: int
- channels: int
- spatial_rank: int
- in_spatial: tuple[int, ...]
- out_spatial: tuple[int, ...]
- kernel_shape: tuple[int, ...]
- strides: tuple[int, ...]
- pads: tuple[int, ...]
- dilations: tuple[int, ...]
- ceil_mode: bool
- storage_order: int
- dtype: ScalarType
- indices_dtype: ScalarType | None
-
-
- @dataclass(frozen=True)
- class ConcatOp:
- inputs: tuple[str, ...]
- output: str
- axis: int
- input_shapes: tuple[tuple[int, ...], ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class GatherElementsOp:
- data: str
- indices: str
- output: str
- axis: int
- data_shape: tuple[int, ...]
- indices_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
- indices_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class GatherOp:
- data: str
- indices: str
- output: str
- axis: int
- data_shape: tuple[int, ...]
- indices_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
- indices_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class GatherNDOp:
- data: str
- indices: str
- output: str
- batch_dims: int
- data_shape: tuple[int, ...]
- indices_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
- indices_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ScatterNDOp:
- data: str
- indices: str
- updates: str
- output: str
- data_shape: tuple[int, ...]
- indices_shape: tuple[int, ...]
- updates_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- reduction: str
- dtype: ScalarType
- indices_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class TransposeOp:
- input0: str
- output: str
- perm: tuple[int, ...]
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ReshapeOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class IdentityOp:
- input0: str
- output: str
- shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class EyeLikeOp:
- input0: str
- output: str
- output_shape: tuple[int, ...]
- k: int
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class TriluOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- upper: bool
- k_value: int
- k_input: str | None
- k_input_shape: tuple[int, ...] | None
- k_input_dtype: ScalarType | None
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class TileOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- repeats: tuple[int, ...]
- input_strides: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class PadOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- pads_begin: tuple[int, ...] | None
- pads_end: tuple[int, ...] | None
- pads_input: str | None
- pads_shape: tuple[int, ...] | None
- pads_dtype: ScalarType | None
- pads_axis_map: tuple[int | None, ...] | None
- pads_values: tuple[int, ...] | None
- axes_input: str | None
- axes_shape: tuple[int, ...] | None
- axes_dtype: ScalarType | None
- mode: str
- value: float | int | bool
- value_input: str | None
- value_shape: tuple[int, ...] | None
- dtype: ScalarType
- input_dtype: ScalarType
- input_strides: tuple[int, ...]
-
-
- @dataclass(frozen=True)
- class DepthToSpaceOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- blocksize: int
- mode: str
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class SpaceToDepthOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- blocksize: int
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class SliceOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- starts: tuple[int, ...] | None
- steps: tuple[int, ...] | None
- axes: tuple[int, ...] | None
- starts_input: str | None
- ends_input: str | None
- axes_input: str | None
- steps_input: str | None
- starts_shape: tuple[int, ...] | None
- ends_shape: tuple[int, ...] | None
- axes_shape: tuple[int, ...] | None
- steps_shape: tuple[int, ...] | None
- starts_dtype: ScalarType | None
- ends_dtype: ScalarType | None
- axes_dtype: ScalarType | None
- steps_dtype: ScalarType | None
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ResizeOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- scales: tuple[float, ...]
- scales_input: str | None
- sizes_input: str | None
- roi_input: str | None
- axes: tuple[int, ...]
- scales_shape: tuple[int, ...] | None
- sizes_shape: tuple[int, ...] | None
- roi_shape: tuple[int, ...] | None
- scales_dtype: ScalarType | None
- sizes_dtype: ScalarType | None
- roi_dtype: ScalarType | None
- scales_axes: tuple[int, ...] | None
- sizes_axes: tuple[int, ...] | None
- roi_axes: tuple[int, ...] | None
- mode: str
- coordinate_transformation_mode: str
- nearest_mode: str
- cubic_coeff_a: float
- exclude_outside: bool
- extrapolation_value: float
- antialias: bool
- keep_aspect_ratio_policy: str
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class GridSampleOp:
- input0: str
- grid: str
- output: str
- input_shape: tuple[int, ...]
- grid_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- spatial_rank: int
- input_spatial: tuple[int, ...]
- output_spatial: tuple[int, ...]
- mode: str
- padding_mode: str
- align_corners: bool
- dtype: ScalarType
- grid_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ReduceOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- axes: tuple[int, ...]
- axes_input: str | None
- axes_input_shape: tuple[int, ...] | None
- axes_input_dtype: ScalarType | None
- keepdims: bool
- noop_with_empty_axes: bool
- reduce_kind: str
- reduce_count: int | None
- dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ArgReduceOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- axis: int
- keepdims: bool
- select_last_index: bool
- reduce_kind: str
- input_dtype: ScalarType
- output_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class TopKOp:
- input0: str
- output_values: str
- output_indices: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- axis: int
- k: int
- largest: bool
- sorted: bool
- input_dtype: ScalarType
- output_values_dtype: ScalarType
- output_indices_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ConstantOfShapeOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- shape: tuple[int, ...]
- value: float | int | bool
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ShapeOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- values: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class SizeOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- value: int
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class NonZeroOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class ExpandOp:
- input0: str
- output: str
- input_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- input_shape_padded: tuple[int, ...]
- input_strides: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class CumSumOp:
- input0: str
- axis_input: str | None
- axis_input_dtype: ScalarType | None
- axis: int | None
- output: str
- input_shape: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
- exclusive: bool
- reverse: bool
-
-
- @dataclass(frozen=True)
- class RangeOp:
- start: str
- limit: str
- delta: str
- output: str
- output_shape: tuple[int, ...]
- length: int
- dtype: ScalarType
- input_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class OneHotOp:
- indices: str
- depth: str
- values: str
- output: str
- axis: int
- indices_shape: tuple[int, ...]
- values_shape: tuple[int, ...]
- output_shape: tuple[int, ...]
- depth_dim: int
- dtype: ScalarType
- indices_dtype: ScalarType
- depth_dtype: ScalarType
-
-
- @dataclass(frozen=True)
- class SplitOp:
- input0: str
- outputs: tuple[str, ...]
- input_shape: tuple[int, ...]
- output_shapes: tuple[tuple[int, ...], ...]
- axis: int
- split_sizes: tuple[int, ...]
- dtype: ScalarType
- input_dtype: ScalarType
-
-
  @dataclass(frozen=True)
  class ConstTensor:
  name: str
@@ -1135,78 +276,29 @@ class ModelHeader:

  @dataclass(frozen=True)
  class LoweredModel:
- name: str
- input_names: tuple[str, ...]
- input_shapes: tuple[tuple[int, ...], ...]
- input_dtypes: tuple[ScalarType, ...]
- output_names: tuple[str, ...]
- output_shapes: tuple[tuple[int, ...], ...]
- output_dtypes: tuple[ScalarType, ...]
- constants: tuple[ConstTensor, ...]
- ops: tuple[
- BinaryOp
- | MultiInputBinaryOp
- | WhereOp
- | UnaryOp
- | ClipOp
- | CastOp
- | QuantizeLinearOp
- | MatMulOp
- | EinsumOp
- | GemmOp
- | AttentionOp
- | ConvOp
- | ConvTransposeOp
- | AveragePoolOp
- | LpPoolOp
- | BatchNormOp
- | LpNormalizationOp
- | InstanceNormalizationOp
- | GroupNormalizationOp
- | LayerNormalizationOp
- | MeanVarianceNormalizationOp
- | RMSNormalizationOp
- | LrnOp
- | LstmOp
- | SoftmaxOp
- | LogSoftmaxOp
- | HardmaxOp
- | NegativeLogLikelihoodLossOp
- | SoftmaxCrossEntropyLossOp
- | MaxPoolOp
- | ConcatOp
- | GatherElementsOp
- | GatherOp
- | GatherNDOp
- | ScatterNDOp
- | TransposeOp
- | ReshapeOp
- | IdentityOp
- | EyeLikeOp
- | TriluOp
- | TileOp
- | PadOp
- | DepthToSpaceOp
- | SpaceToDepthOp
- | SliceOp
- | ResizeOp
- | GridSampleOp
- | ReduceOp
- | ArgReduceOp
- | TopKOp
- | ConstantOfShapeOp
- | ShapeOp
- | SizeOp
- | NonZeroOp
- | ExpandOp
- | CumSumOp
- | RangeOp
- | OneHotOp
- | SplitOp,
- ...,
- ]
+ name: str
+ input_names: tuple[str, ...]
+ input_shapes: tuple[tuple[int, ...], ...]
+ input_dtypes: tuple[ScalarType, ...]
+ output_names: tuple[str, ...]
+ output_shapes: tuple[tuple[int, ...], ...]
+ output_dtypes: tuple[ScalarType, ...]
+ constants: tuple[ConstTensor, ...]
+ ops: tuple[OpBase, ...]
  node_infos: tuple[NodeInfo, ...]
  header: ModelHeader
+ op_context: OpContext
+
+
+ @dataclass
+ class _EmitState:
+ model: LoweredModel
+ templates: dict[str, Template]
+ scalar_registry: ScalarFunctionRegistry
+ dim_args: str
+ tensor_dim_names: Mapping[str, Mapping[int, str]]
+ op_context: OpContext
+ value_name_map: Mapping[str, str]


  class CEmitter:
@@ -1235,6 +327,7 @@ class CEmitter:
  if large_weight_threshold < 0:
  raise CodegenError("large_weight_threshold must be >= 0")
  self._large_weight_threshold = large_weight_threshold
+ self._emit_state: _EmitState | None = None

  @staticmethod
  def _sanitize_identifier(name: str) -> str:
@@ -1297,6 +390,26 @@ class CEmitter:
  mapped[key] = unique
  return mapped

+ def _ctx_name(self, name: str) -> str:
+ if self._emit_state is None:
+ raise CodegenError("Emitter state not initialized")
+ return self._emit_state.value_name_map.get(name, name)
+
+ def _ctx_shape(self, name: str) -> tuple[int, ...]:
+ if self._emit_state is None:
+ raise CodegenError("Emitter state not initialized")
+ return self._emit_state.op_context.shape(self._ctx_name(name))
+
+ def _ctx_dtype(self, name: str) -> ScalarType:
+ if self._emit_state is None:
+ raise CodegenError("Emitter state not initialized")
+ return self._emit_state.op_context.dtype(self._ctx_name(name))
+
+ def _derived(self, op: OpBase, key: str) -> object:
+ if self._emit_state is None:
+ raise CodegenError("Emitter state not initialized")
+ return self._emit_state.op_context.require_derived(op, key)
+
  @staticmethod
  def _build_param_decls(
  specs: Sequence[tuple[str | None, str, str, bool]]
@@ -1334,10 +447,12 @@ class CEmitter:
  | ClipOp
  | CastOp
  | QuantizeLinearOp
+ | QLinearMatMulOp
  | MatMulOp
  | EinsumOp
  | GemmOp
  | AttentionOp
+ | RotaryEmbeddingOp
  | ConvOp
  | AveragePoolOp
  | BatchNormOp
@@ -1349,6 +464,7 @@ class CEmitter:
  | RMSNormalizationOp
  | LrnOp
  | LstmOp
+ | AdagradOp
  | SoftmaxOp
  | LogSoftmaxOp
  | HardmaxOp
@@ -1360,6 +476,7 @@ class CEmitter:
  | GatherOp
  | GatherNDOp
  | ScatterNDOp
+ | TensorScatterOp
  | TransposeOp
  | ReshapeOp
  | IdentityOp
@@ -1379,6 +496,7 @@ class CEmitter:
  | ShapeOp
  | SizeOp
  | NonZeroOp
+ | NonMaxSuppressionOp
  | ExpandOp
  | CumSumOp
  | RangeOp
@@ -1409,6 +527,18 @@ class CEmitter:
  names.append(op.zero_point)
  names.append(op.output)
  return tuple(names)
+ if isinstance(op, QLinearMatMulOp):
+ return (
+ op.input0,
+ op.input0_scale,
+ op.input0_zero_point,
+ op.input1,
+ op.input1_scale,
+ op.input1_zero_point,
+ op.output_scale,
+ op.output_zero_point,
+ op.output,
+ )
  if isinstance(op, MatMulOp):
  return (op.input0, op.input1, op.output)
  if isinstance(op, EinsumOp):
@@ -1437,6 +567,12 @@ class CEmitter:
  if op.output_qk_matmul is not None:
  names.append(op.output_qk_matmul)
  return tuple(names)
+ if isinstance(op, RotaryEmbeddingOp):
+ names = [op.input0, op.cos_cache, op.sin_cache]
+ if op.position_ids is not None:
+ names.append(op.position_ids)
+ names.append(op.output)
+ return tuple(names)
  if isinstance(op, ConvOp):
  names = [op.input0, op.weights]
  if op.bias is not None:
@@ -1494,6 +630,16 @@ class CEmitter:
  if op.output_y_c is not None:
  names.append(op.output_y_c)
  return tuple(names)
+ if isinstance(op, AdagradOp):
+ return (
+ op.rate,
+ op.timestep,
+ *op.inputs,
+ *op.gradients,
+ *op.accumulators,
+ *op.outputs,
+ *op.accumulator_outputs,
+ )
  if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
  return (op.input0, op.output)
  if isinstance(op, NegativeLogLikelihoodLossOp):
@@ -1523,6 +669,12 @@ class CEmitter:
  return (op.data, op.indices, op.output)
  if isinstance(op, ScatterNDOp):
  return (op.data, op.indices, op.updates, op.output)
+ if isinstance(op, TensorScatterOp):
+ names = [op.past_cache, op.update]
+ if op.write_indices is not None:
+ names.append(op.write_indices)
+ names.append(op.output)
+ return tuple(names)
  if isinstance(op, ConcatOp):
  return (*op.inputs, op.output)
  if isinstance(op, ConstantOfShapeOp):
@@ -1533,6 +685,16 @@ class CEmitter:
  return (op.input0, op.output)
  if isinstance(op, NonZeroOp):
  return (op.input0, op.output)
+ if isinstance(op, NonMaxSuppressionOp):
+ names = [op.boxes, op.scores]
+ if op.max_output_boxes_per_class is not None:
+ names.append(op.max_output_boxes_per_class)
+ if op.iou_threshold is not None:
+ names.append(op.iou_threshold)
+ if op.score_threshold is not None:
+ names.append(op.score_threshold)
+ names.append(op.output)
+ return tuple(names)
  if isinstance(op, ExpandOp):
  return (op.input0, op.output)
  if isinstance(op, CumSumOp):
@@ -1653,10 +815,12 @@ class CEmitter:
  | ClipOp
  | CastOp
  | QuantizeLinearOp
+ | QLinearMatMulOp
  | MatMulOp
  | EinsumOp
  | GemmOp
  | AttentionOp
+ | RotaryEmbeddingOp
  | ConvOp
  | ConvTransposeOp
  | AveragePoolOp
@@ -1670,6 +834,7 @@ class CEmitter:
  | RMSNormalizationOp
  | LrnOp
  | LstmOp
+ | AdagradOp
  | SoftmaxOp
  | LogSoftmaxOp
  | HardmaxOp
@@ -1681,6 +846,7 @@ class CEmitter:
  | GatherOp
  | GatherNDOp
  | ScatterNDOp
+ | TensorScatterOp
  | TransposeOp
  | ReshapeOp
  | IdentityOp
@@ -1700,6 +866,7 @@ class CEmitter:
  | ShapeOp
  | SizeOp
  | NonZeroOp
+ | NonMaxSuppressionOp
  | ExpandOp
  | CumSumOp
  | RangeOp
@@ -1714,10 +881,12 @@ class CEmitter:
  | ClipOp
  | CastOp
  | QuantizeLinearOp
+ | QLinearMatMulOp
  | MatMulOp
  | EinsumOp
  | GemmOp
  | AttentionOp
+ | RotaryEmbeddingOp
  | ConvOp
  | ConvTransposeOp
  | AveragePoolOp
@@ -1731,6 +900,7 @@ class CEmitter:
  | RMSNormalizationOp
  | LrnOp
  | LstmOp
+ | AdagradOp
  | SoftmaxOp
  | LogSoftmaxOp
  | HardmaxOp
@@ -1742,6 +912,7 @@ class CEmitter:
  | GatherOp
  | GatherNDOp
  | ScatterNDOp
+ | TensorScatterOp
  | TransposeOp
  | ReshapeOp
  | IdentityOp
@@ -1761,6 +932,7 @@ class CEmitter:
  | ShapeOp
  | SizeOp
  | NonZeroOp
+ | NonMaxSuppressionOp
  | ExpandOp
  | CumSumOp
  | RangeOp
@@ -1844,6 +1016,47 @@ class CEmitter:
  input_dtype=op.input_dtype,
  scale_dtype=op.scale_dtype,
  )
+ if isinstance(op, QLinearMatMulOp):
+ return QLinearMatMulOp(
+ input0=name_map.get(op.input0, op.input0),
+ input0_scale=name_map.get(op.input0_scale, op.input0_scale),
+ input0_zero_point=name_map.get(
+ op.input0_zero_point, op.input0_zero_point
+ ),
+ input1=name_map.get(op.input1, op.input1),
+ input1_scale=name_map.get(op.input1_scale, op.input1_scale),
+ input1_zero_point=name_map.get(
+ op.input1_zero_point, op.input1_zero_point
+ ),
+ output_scale=name_map.get(op.output_scale, op.output_scale),
+ output_zero_point=name_map.get(
+ op.output_zero_point, op.output_zero_point
+ ),
+ output=name_map.get(op.output, op.output),
+ input0_shape=op.input0_shape,
+ input1_shape=op.input1_shape,
+ output_shape=op.output_shape,
+ batch_shape=op.batch_shape,
+ input0_batch_shape=op.input0_batch_shape,
+ input1_batch_shape=op.input1_batch_shape,
+ m=op.m,
+ n=op.n,
+ k=op.k,
+ left_vector=op.left_vector,
+ right_vector=op.right_vector,
+ input0_dtype=op.input0_dtype,
+ input1_dtype=op.input1_dtype,
+ dtype=op.dtype,
+ input0_scale_dtype=op.input0_scale_dtype,
+ input1_scale_dtype=op.input1_scale_dtype,
+ output_scale_dtype=op.output_scale_dtype,
+ input0_scale_shape=op.input0_scale_shape,
+ input1_scale_shape=op.input1_scale_shape,
+ output_scale_shape=op.output_scale_shape,
+ input0_zero_shape=op.input0_zero_shape,
+ input1_zero_shape=op.input1_zero_shape,
+ output_zero_shape=op.output_zero_shape,
+ )
  if isinstance(op, MatMulOp):
  return MatMulOp(
  input0=name_map.get(op.input0, op.input0),
@@ -1946,6 +1159,30 @@ class CEmitter:
  head_group_size=op.head_group_size,
  dtype=op.dtype,
  )
+ if isinstance(op, RotaryEmbeddingOp):
+ return RotaryEmbeddingOp(
+ input0=name_map.get(op.input0, op.input0),
+ cos_cache=name_map.get(op.cos_cache, op.cos_cache),
+ sin_cache=name_map.get(op.sin_cache, op.sin_cache),
+ position_ids=self._map_optional_name(
+ name_map, op.position_ids
+ ),
+ output=name_map.get(op.output, op.output),
+ input_shape=op.input_shape,
+ cos_shape=op.cos_shape,
+ sin_shape=op.sin_shape,
+ position_ids_shape=op.position_ids_shape,
+ dtype=op.dtype,
+ position_ids_dtype=op.position_ids_dtype,
+ rotary_dim=op.rotary_dim,
+ rotary_dim_half=op.rotary_dim_half,
+ head_size=op.head_size,
+ num_heads=op.num_heads,
+ seq_len=op.seq_len,
+ batch=op.batch,
+ input_rank=op.input_rank,
+ interleaved=op.interleaved,
+ )
  if isinstance(op, ConvOp):
  return ConvOp(
  input0=name_map.get(op.input0, op.input0),
@@ -2168,6 +1405,33 @@ class CEmitter:
  dtype=op.dtype,
  sequence_lens_dtype=op.sequence_lens_dtype,
  )
+ if isinstance(op, AdagradOp):
+ return AdagradOp(
+ rate=name_map.get(op.rate, op.rate),
+ timestep=name_map.get(op.timestep, op.timestep),
+ inputs=tuple(name_map.get(name, name) for name in op.inputs),
+ gradients=tuple(
+ name_map.get(name, name) for name in op.gradients
+ ),
+ accumulators=tuple(
+ name_map.get(name, name) for name in op.accumulators
+ ),
+ outputs=tuple(name_map.get(name, name) for name in op.outputs),
+ accumulator_outputs=tuple(
+ name_map.get(name, name)
+ for name in op.accumulator_outputs
+ ),
+ rate_shape=op.rate_shape,
+ timestep_shape=op.timestep_shape,
+ tensor_shapes=op.tensor_shapes,
+ output_shapes=op.output_shapes,
+ dtype=op.dtype,
+ rate_dtype=op.rate_dtype,
+ timestep_dtype=op.timestep_dtype,
+ norm_coefficient=op.norm_coefficient,
+ epsilon=op.epsilon,
+ decay_factor=op.decay_factor,
+ )
  if isinstance(op, SoftmaxOp):
  return SoftmaxOp(
  input0=name_map.get(op.input0, op.input0),
@@ -2323,6 +1587,25 @@ class CEmitter:
  dtype=op.dtype,
  indices_dtype=op.indices_dtype,
  )
+ if isinstance(op, TensorScatterOp):
+ return TensorScatterOp(
+ past_cache=name_map.get(op.past_cache, op.past_cache),
+ update=name_map.get(op.update, op.update),
+ write_indices=(
+ name_map.get(op.write_indices, op.write_indices)
+ if op.write_indices is not None
+ else None
+ ),
+ output=name_map.get(op.output, op.output),
+ past_cache_shape=op.past_cache_shape,
+ update_shape=op.update_shape,
+ output_shape=op.output_shape,
+ write_indices_shape=op.write_indices_shape,
+ axis=op.axis,
+ mode=op.mode,
+ dtype=op.dtype,
+ write_indices_dtype=op.write_indices_dtype,
+ )
  if isinstance(op, TransposeOp):
  return TransposeOp(
  input0=name_map.get(op.input0, op.input0),
@@ -2583,6 +1866,33 @@ class CEmitter:
  dtype=op.dtype,
  input_dtype=op.input_dtype,
  )
+ if isinstance(op, NonMaxSuppressionOp):
+ return NonMaxSuppressionOp(
+ boxes=name_map.get(op.boxes, op.boxes),
+ scores=name_map.get(op.scores, op.scores),
+ max_output_boxes_per_class=self._map_optional_name(
+ name_map, op.max_output_boxes_per_class
+ ),
+ iou_threshold=self._map_optional_name(
+ name_map, op.iou_threshold
+ ),
+ score_threshold=self._map_optional_name(
+ name_map, op.score_threshold
+ ),
+ output=name_map.get(op.output, op.output),
+ boxes_shape=op.boxes_shape,
+ scores_shape=op.scores_shape,
+ output_shape=op.output_shape,
+ center_point_box=op.center_point_box,
+ boxes_dtype=op.boxes_dtype,
+ output_dtype=op.output_dtype,
+ max_output_dtype=op.max_output_dtype,
+ max_output_shape=op.max_output_shape,
+ iou_threshold_dtype=op.iou_threshold_dtype,
+ iou_threshold_shape=op.iou_threshold_shape,
+ score_threshold_dtype=op.score_threshold_dtype,
+ score_threshold_shape=op.score_threshold_shape,
+ )
  if isinstance(op, ExpandOp):
  return ExpandOp(
  input0=name_map.get(op.input0, op.input0),
@@ -2684,12 +1994,34 @@ class CEmitter:
  ops=ops,
  node_infos=model.node_infos,
  header=model.header,
+ op_context=model.op_context,
  )
  return sanitized, name_map

  def _sanitize_model_names(self, model: LoweredModel) -> LoweredModel:
  return self._sanitize_model_names_with_map(model)[0]

+ @staticmethod
+ def _copy_derived(
+ op_context: OpContext,
+ source_ops: Sequence[OpBase],
+ target_ops: Sequence[OpBase],
+ ) -> None:
+ for source_op, target_op in zip(source_ops, target_ops):
+ op_context.copy_derived(source_op, target_op)
+
+ @staticmethod
+ def _build_value_name_map(
+ name_map: Mapping[str, str],
+ temp_name_map: Mapping[str, str],
+ ) -> dict[str, str]:
+ reverse_name_map = {sanitized: original for original, sanitized in name_map.items()}
+ value_name_map = dict(reverse_name_map)
+ for sanitized_name, temp_name in temp_name_map.items():
+ original_name = reverse_name_map.get(sanitized_name, sanitized_name)
+ value_name_map[temp_name] = original_name
+ return value_name_map
+
  @staticmethod
  def _sanitize_testbench_inputs(
  testbench_inputs: Mapping[str, tuple[float | int | bool, ...]] | None,
@@ -2716,10 +2048,16 @@ class CEmitter:
  "quantize_linear": self._env.get_template(
  "quantize_linear_op.c.j2"
  ),
+ "qlinear_matmul": self._env.get_template(
+ "qlinear_matmul_op.c.j2"
+ ),
  "matmul": self._env.get_template("matmul_op.c.j2"),
  "einsum": self._env.get_template("einsum_op.c.j2"),
  "gemm": self._env.get_template("gemm_op.c.j2"),
  "attention": self._env.get_template("attention_op.c.j2"),
+ "rotary_embedding": self._env.get_template(
+ "rotary_embedding_op.c.j2"
+ ),
  "conv": self._env.get_template("conv_op.c.j2"),
  "conv_transpose": self._env.get_template(
  "conv_transpose_op.c.j2"
@@ -2743,6 +2081,7 @@ class CEmitter:
  "rms_norm": self._env.get_template("rms_normalization_op.c.j2"),
  "lrn": self._env.get_template("lrn_op.c.j2"),
  "lstm": self._env.get_template("lstm_op.c.j2"),
+ "adagrad": self._env.get_template("adagrad_op.c.j2"),
  "softmax": self._env.get_template("softmax_op.c.j2"),
  "logsoftmax": self._env.get_template("logsoftmax_op.c.j2"),
  "hardmax": self._env.get_template("hardmax_op.c.j2"),
@@ -2758,6 +2097,9 @@ class CEmitter:
  "gather": self._env.get_template("gather_op.c.j2"),
  "gather_nd": self._env.get_template("gather_nd_op.c.j2"),
  "scatter_nd": self._env.get_template("scatter_nd_op.c.j2"),
+ "tensor_scatter": self._env.get_template(
+ "tensor_scatter_op.c.j2"
+ ),
  "transpose": self._env.get_template("transpose_op.c.j2"),
  "reshape": self._env.get_template("reshape_op.c.j2"),
  "identity": self._env.get_template("identity_op.c.j2"),
@@ -2785,6 +2127,9 @@ class CEmitter:
  "shape": self._env.get_template("shape_op.c.j2"),
  "size": self._env.get_template("size_op.c.j2"),
  "nonzero": self._env.get_template("nonzero_op.c.j2"),
+ "nonmax_suppression": self._env.get_template(
+ "nonmax_suppression_op.c.j2"
+ ),
  "expand": self._env.get_template("expand_op.c.j2"),
  "cumsum": self._env.get_template("cumsum_op.c.j2"),
  "range": self._env.get_template("range_op.c.j2"),
@@ -2806,7 +2151,9 @@ class CEmitter:
  variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
  variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
  ) -> str:
+ original_model = model
  model, name_map = self._sanitize_model_names_with_map(model)
+ self._copy_derived(model.op_context, original_model.ops, model.ops)
  testbench_inputs = self._sanitize_testbench_inputs(
  testbench_inputs, name_map
  )
@@ -2832,68 +2179,17 @@ class CEmitter:
2832
2179
  self._env.globals["dim_args"] = dim_args
2833
2180
  templates = self._load_templates(emit_testbench)
2834
2181
  scalar_registry = ScalarFunctionRegistry()
2835
- binary_template = templates["binary"]
2836
- multi_input_template = templates["multi_input"]
2837
- where_template = templates["where"]
2838
- unary_template = templates["unary"]
2839
- clip_template = templates["clip"]
2840
- cast_template = templates["cast"]
2841
- quantize_linear_template = templates["quantize_linear"]
2842
- matmul_template = templates["matmul"]
2843
- einsum_template = templates["einsum"]
2844
- gemm_template = templates["gemm"]
2845
- attention_template = templates["attention"]
2846
- conv_template = templates["conv"]
2847
- conv_transpose_template = templates["conv_transpose"]
2848
- avg_pool_template = templates["avg_pool"]
2849
- lp_pool_template = templates["lp_pool"]
2850
- batch_norm_template = templates["batch_norm"]
2851
- lp_norm_template = templates["lp_norm"]
2852
- instance_norm_template = templates["instance_norm"]
2853
- group_norm_template = templates["group_norm"]
2854
- layer_norm_template = templates["layer_norm"]
2855
- mean_variance_norm_template = templates["mean_variance_norm"]
2856
- rms_norm_template = templates["rms_norm"]
2857
- lrn_template = templates["lrn"]
2858
- lstm_template = templates["lstm"]
2859
- softmax_template = templates["softmax"]
2860
- logsoftmax_template = templates["logsoftmax"]
2861
- hardmax_template = templates["hardmax"]
2862
- nllloss_template = templates["nllloss"]
2863
- softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
2864
- maxpool_template = templates["maxpool"]
2865
- concat_template = templates["concat"]
2866
- gather_elements_template = templates["gather_elements"]
2867
- gather_template = templates["gather"]
2868
- gather_nd_template = templates["gather_nd"]
2869
- scatter_nd_template = templates["scatter_nd"]
2870
- transpose_template = templates["transpose"]
2871
- reshape_template = templates["reshape"]
2872
- identity_template = templates["identity"]
2873
- eye_like_template = templates["eye_like"]
2874
- trilu_template = templates["trilu"]
2875
- tile_template = templates["tile"]
2876
- pad_template = templates["pad"]
2877
- depth_to_space_template = templates["depth_to_space"]
2878
- space_to_depth_template = templates["space_to_depth"]
2879
- slice_template = templates["slice"]
2880
- slice_dynamic_template = templates["slice_dynamic"]
2881
- resize_template = templates["resize"]
2882
- grid_sample_template = templates["grid_sample"]
2883
- reduce_template = templates["reduce"]
2884
- reduce_dynamic_template = templates["reduce_dynamic"]
2885
- arg_reduce_template = templates["arg_reduce"]
2886
- topk_template = templates["topk"]
2887
- constant_of_shape_template = templates["constant_of_shape"]
2888
- shape_template = templates["shape"]
2889
- size_template = templates["size"]
2890
- nonzero_template = templates["nonzero"]
2891
- expand_template = templates["expand"]
2892
- cumsum_template = templates["cumsum"]
2893
- range_template = templates["range"]
2894
- one_hot_template = templates["one_hot"]
2895
- split_template = templates["split"]
2896
2182
  testbench_template = templates.get("testbench")
2183
+ initial_name_map = self._build_value_name_map(name_map, {})
2184
+ self._emit_state = _EmitState(
2185
+ model=model,
2186
+ templates=templates,
2187
+ scalar_registry=scalar_registry,
2188
+ dim_args=dim_args,
2189
+ tensor_dim_names=tensor_dim_names,
2190
+ op_context=model.op_context,
2191
+ value_name_map=initial_name_map,
2192
+ )
2897
2193
  reserved_names = {
2898
2194
  model.name,
2899
2195
  *model.input_names,
@@ -2905,83 +2201,12 @@ class CEmitter:
2905
2201
  original: buffer.name for original, buffer in temp_buffers.items()
2906
2202
  }
2907
2203
  resolved_ops = [self._resolve_op(op, temp_name_map) for op in model.ops]
2204
+ self._copy_derived(model.op_context, model.ops, resolved_ops)
2205
+ value_name_map = self._build_value_name_map(name_map, temp_name_map)
2206
+ self._emit_state.value_name_map = value_name_map
2908
2207
  self._propagate_tensor_dim_names(resolved_ops, tensor_dim_names)
2909
2208
  operator_fns = "\n\n".join(
2910
- self._render_op(
2911
- model,
2912
- op,
2913
- index,
2914
- array_suffix="",
2915
- loop_vars=(),
2916
- c_type=self._op_output_dtype(op).c_type,
2917
- zero_literal=self._op_output_dtype(op).zero_literal,
2918
- min_literal=self._op_output_dtype(op).min_literal,
2919
- max_literal=self._op_output_dtype(op).max_literal,
2920
- binary_template=binary_template,
2921
- multi_input_template=multi_input_template,
2922
- where_template=where_template,
2923
- unary_template=unary_template,
2924
- clip_template=clip_template,
2925
- cast_template=cast_template,
2926
- quantize_linear_template=quantize_linear_template,
2927
- matmul_template=matmul_template,
2928
- einsum_template=einsum_template,
2929
- gemm_template=gemm_template,
2930
- attention_template=attention_template,
2931
- conv_template=conv_template,
2932
- conv_transpose_template=conv_transpose_template,
2933
- avg_pool_template=avg_pool_template,
2934
- lp_pool_template=lp_pool_template,
2935
- batch_norm_template=batch_norm_template,
2936
- lp_norm_template=lp_norm_template,
2937
- instance_norm_template=instance_norm_template,
2938
- group_norm_template=group_norm_template,
2939
- layer_norm_template=layer_norm_template,
2940
- mean_variance_norm_template=mean_variance_norm_template,
2941
- rms_norm_template=rms_norm_template,
2942
- lrn_template=lrn_template,
2943
- lstm_template=lstm_template,
2944
- softmax_template=softmax_template,
2945
- logsoftmax_template=logsoftmax_template,
2946
- hardmax_template=hardmax_template,
2947
- nllloss_template=nllloss_template,
2948
- softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
2949
- maxpool_template=maxpool_template,
2950
- concat_template=concat_template,
2951
- gather_elements_template=gather_elements_template,
2952
- gather_template=gather_template,
2953
- gather_nd_template=gather_nd_template,
2954
- scatter_nd_template=scatter_nd_template,
2955
- transpose_template=transpose_template,
2956
- reshape_template=reshape_template,
2957
- identity_template=identity_template,
2958
- eye_like_template=eye_like_template,
2959
- trilu_template=trilu_template,
2960
- tile_template=tile_template,
2961
- pad_template=pad_template,
2962
- depth_to_space_template=depth_to_space_template,
2963
- space_to_depth_template=space_to_depth_template,
2964
- slice_template=slice_template,
2965
- slice_dynamic_template=slice_dynamic_template,
2966
- resize_template=resize_template,
2967
- grid_sample_template=grid_sample_template,
2968
- reduce_template=reduce_template,
2969
- reduce_dynamic_template=reduce_dynamic_template,
2970
- arg_reduce_template=arg_reduce_template,
2971
- topk_template=topk_template,
2972
- constant_of_shape_template=constant_of_shape_template,
2973
- shape_template=shape_template,
2974
- size_template=size_template,
2975
- nonzero_template=nonzero_template,
2976
- expand_template=expand_template,
2977
- cumsum_template=cumsum_template,
2978
- range_template=range_template,
2979
- one_hot_template=one_hot_template,
2980
- split_template=split_template,
2981
- scalar_registry=scalar_registry,
2982
- dim_args=dim_args,
2983
- tensor_dim_names=tensor_dim_names,
2984
- )
2209
+ op.emit(self, EmitContext(op_index=index))
2985
2210
  for index, op in enumerate(resolved_ops)
2986
2211
  )
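This hunk replaces the long keyword-argument call into `_render_op(...)` with a per-op `op.emit(self, EmitContext(op_index=index))` dispatch, with templates and naming state moved into the shared `_EmitState` set up just above. A minimal sketch of how that dispatch presumably fits together (the real `OpBase` and `EmitContext` live in `emx_onnx_cgen/ir/op_base.py` and `op_context.py`; the exact fields and signatures here are assumptions):

```python
# Hypothetical sketch of the per-op dispatch introduced in this hunk; the real
# OpBase/EmitContext in emx_onnx_cgen.ir may carry more state than shown here.
from dataclasses import dataclass


@dataclass(frozen=True)
class EmitContext:
    op_index: int  # position of the op inside the lowered model


class OpBase:
    def emit(self, emitter, ctx: EmitContext) -> str:
        # Each op delegates back to the emitter, which owns templates and
        # naming state via the shared _EmitState configured before rendering.
        return emitter.render_op(self, ctx)


def emit_operator_functions(emitter, resolved_ops):
    # Mirrors the "\n\n".join(...) in the diff: one rendered C function per op.
    return "\n\n".join(
        op.emit(emitter, EmitContext(op_index=index))
        for index, op in enumerate(resolved_ops)
    )
```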
2987
2212
  wrapper_fn = self._emit_model_wrapper(
@@ -3073,7 +2298,9 @@ class CEmitter:
3073
2298
  variable_dim_inputs: Mapping[int, Mapping[int, str]] | None = None,
3074
2299
  variable_dim_outputs: Mapping[int, Mapping[int, str]] | None = None,
3075
2300
  ) -> tuple[str, str]:
2301
+ original_model = model
3076
2302
  model, name_map = self._sanitize_model_names_with_map(model)
2303
+ self._copy_derived(model.op_context, original_model.ops, model.ops)
3077
2304
  testbench_inputs = self._sanitize_testbench_inputs(
3078
2305
  testbench_inputs, name_map
3079
2306
  )
@@ -3099,68 +2326,17 @@ class CEmitter:
3099
2326
  self._env.globals["dim_args"] = dim_args
3100
2327
  templates = self._load_templates(emit_testbench)
3101
2328
  scalar_registry = ScalarFunctionRegistry()
3102
- binary_template = templates["binary"]
3103
- multi_input_template = templates["multi_input"]
3104
- where_template = templates["where"]
3105
- unary_template = templates["unary"]
3106
- clip_template = templates["clip"]
3107
- cast_template = templates["cast"]
3108
- quantize_linear_template = templates["quantize_linear"]
3109
- matmul_template = templates["matmul"]
3110
- einsum_template = templates["einsum"]
3111
- gemm_template = templates["gemm"]
3112
- attention_template = templates["attention"]
3113
- conv_template = templates["conv"]
3114
- conv_transpose_template = templates["conv_transpose"]
3115
- avg_pool_template = templates["avg_pool"]
3116
- lp_pool_template = templates["lp_pool"]
3117
- batch_norm_template = templates["batch_norm"]
3118
- lp_norm_template = templates["lp_norm"]
3119
- instance_norm_template = templates["instance_norm"]
3120
- group_norm_template = templates["group_norm"]
3121
- layer_norm_template = templates["layer_norm"]
3122
- mean_variance_norm_template = templates["mean_variance_norm"]
3123
- rms_norm_template = templates["rms_norm"]
3124
- lrn_template = templates["lrn"]
3125
- lstm_template = templates["lstm"]
3126
- softmax_template = templates["softmax"]
3127
- logsoftmax_template = templates["logsoftmax"]
3128
- hardmax_template = templates["hardmax"]
3129
- nllloss_template = templates["nllloss"]
3130
- softmax_cross_entropy_loss_template = templates["softmax_cross_entropy_loss"]
3131
- maxpool_template = templates["maxpool"]
3132
- concat_template = templates["concat"]
3133
- gather_elements_template = templates["gather_elements"]
3134
- gather_template = templates["gather"]
3135
- gather_nd_template = templates["gather_nd"]
3136
- scatter_nd_template = templates["scatter_nd"]
3137
- transpose_template = templates["transpose"]
3138
- reshape_template = templates["reshape"]
3139
- identity_template = templates["identity"]
3140
- eye_like_template = templates["eye_like"]
3141
- trilu_template = templates["trilu"]
3142
- tile_template = templates["tile"]
3143
- pad_template = templates["pad"]
3144
- depth_to_space_template = templates["depth_to_space"]
3145
- space_to_depth_template = templates["space_to_depth"]
3146
- slice_template = templates["slice"]
3147
- slice_dynamic_template = templates["slice_dynamic"]
3148
- resize_template = templates["resize"]
3149
- grid_sample_template = templates["grid_sample"]
3150
- reduce_template = templates["reduce"]
3151
- reduce_dynamic_template = templates["reduce_dynamic"]
3152
- arg_reduce_template = templates["arg_reduce"]
3153
- topk_template = templates["topk"]
3154
- constant_of_shape_template = templates["constant_of_shape"]
3155
- shape_template = templates["shape"]
3156
- size_template = templates["size"]
3157
- nonzero_template = templates["nonzero"]
3158
- expand_template = templates["expand"]
3159
- cumsum_template = templates["cumsum"]
3160
- range_template = templates["range"]
3161
- one_hot_template = templates["one_hot"]
3162
- split_template = templates["split"]
3163
2329
  testbench_template = templates.get("testbench")
2330
+ initial_name_map = self._build_value_name_map(name_map, {})
2331
+ self._emit_state = _EmitState(
2332
+ model=model,
2333
+ templates=templates,
2334
+ scalar_registry=scalar_registry,
2335
+ dim_args=dim_args,
2336
+ tensor_dim_names=tensor_dim_names,
2337
+ op_context=model.op_context,
2338
+ value_name_map=initial_name_map,
2339
+ )
3164
2340
  reserved_names = {
3165
2341
  model.name,
3166
2342
  *model.input_names,
@@ -3172,83 +2348,12 @@ class CEmitter:
3172
2348
  original: buffer.name for original, buffer in temp_buffers.items()
3173
2349
  }
3174
2350
  resolved_ops = [self._resolve_op(op, temp_name_map) for op in model.ops]
2351
+ self._copy_derived(model.op_context, model.ops, resolved_ops)
2352
+ value_name_map = self._build_value_name_map(name_map, temp_name_map)
2353
+ self._emit_state.value_name_map = value_name_map
3175
2354
  self._propagate_tensor_dim_names(resolved_ops, tensor_dim_names)
3176
2355
  operator_fns = "\n\n".join(
3177
- self._render_op(
3178
- model,
3179
- op,
3180
- index,
3181
- array_suffix="",
3182
- loop_vars=(),
3183
- c_type=self._op_output_dtype(op).c_type,
3184
- zero_literal=self._op_output_dtype(op).zero_literal,
3185
- min_literal=self._op_output_dtype(op).min_literal,
3186
- max_literal=self._op_output_dtype(op).max_literal,
3187
- binary_template=binary_template,
3188
- multi_input_template=multi_input_template,
3189
- where_template=where_template,
3190
- unary_template=unary_template,
3191
- clip_template=clip_template,
3192
- cast_template=cast_template,
3193
- quantize_linear_template=quantize_linear_template,
3194
- matmul_template=matmul_template,
3195
- einsum_template=einsum_template,
3196
- gemm_template=gemm_template,
3197
- attention_template=attention_template,
3198
- conv_template=conv_template,
3199
- conv_transpose_template=conv_transpose_template,
3200
- avg_pool_template=avg_pool_template,
3201
- lp_pool_template=lp_pool_template,
3202
- batch_norm_template=batch_norm_template,
3203
- lp_norm_template=lp_norm_template,
3204
- instance_norm_template=instance_norm_template,
3205
- group_norm_template=group_norm_template,
3206
- layer_norm_template=layer_norm_template,
3207
- mean_variance_norm_template=mean_variance_norm_template,
3208
- rms_norm_template=rms_norm_template,
3209
- lrn_template=lrn_template,
3210
- lstm_template=lstm_template,
3211
- softmax_template=softmax_template,
3212
- logsoftmax_template=logsoftmax_template,
3213
- hardmax_template=hardmax_template,
3214
- nllloss_template=nllloss_template,
3215
- softmax_cross_entropy_loss_template=softmax_cross_entropy_loss_template,
3216
- maxpool_template=maxpool_template,
3217
- concat_template=concat_template,
3218
- gather_elements_template=gather_elements_template,
3219
- gather_template=gather_template,
3220
- gather_nd_template=gather_nd_template,
3221
- scatter_nd_template=scatter_nd_template,
3222
- transpose_template=transpose_template,
3223
- reshape_template=reshape_template,
3224
- identity_template=identity_template,
3225
- eye_like_template=eye_like_template,
3226
- trilu_template=trilu_template,
3227
- tile_template=tile_template,
3228
- pad_template=pad_template,
3229
- depth_to_space_template=depth_to_space_template,
3230
- space_to_depth_template=space_to_depth_template,
3231
- slice_template=slice_template,
3232
- slice_dynamic_template=slice_dynamic_template,
3233
- resize_template=resize_template,
3234
- grid_sample_template=grid_sample_template,
3235
- reduce_template=reduce_template,
3236
- reduce_dynamic_template=reduce_dynamic_template,
3237
- arg_reduce_template=arg_reduce_template,
3238
- topk_template=topk_template,
3239
- constant_of_shape_template=constant_of_shape_template,
3240
- shape_template=shape_template,
3241
- size_template=size_template,
3242
- nonzero_template=nonzero_template,
3243
- expand_template=expand_template,
3244
- cumsum_template=cumsum_template,
3245
- range_template=range_template,
3246
- one_hot_template=one_hot_template,
3247
- split_template=split_template,
3248
- scalar_registry=scalar_registry,
3249
- dim_args=dim_args,
3250
- tensor_dim_names=tensor_dim_names,
3251
- )
2356
+ op.emit(self, EmitContext(op_index=index))
3252
2357
  for index, op in enumerate(resolved_ops)
3253
2358
  )
3254
2359
  wrapper_fn = self._emit_model_wrapper(
@@ -3536,6 +2641,8 @@ class CEmitter:
3536
2641
  ScalarFunction.SCALED_TANH,
3537
2642
  ScalarFunction.THRESHOLDED_RELU,
3538
2643
  ScalarFunction.LOGICAL_XOR,
2644
+ ScalarFunction.ISNEGINF,
2645
+ ScalarFunction.ISPOSINF,
3539
2646
  }
3540
2647
  if function in {ScalarFunction.MAXIMUM, ScalarFunction.MINIMUM}:
3541
2648
  if dtype in {ScalarType.F32, ScalarType.F64}:
@@ -3598,6 +2705,7 @@ class CEmitter:
3598
2705
  | ClipOp
3599
2706
  | CastOp
3600
2707
  | QuantizeLinearOp
2708
+ | QLinearMatMulOp
3601
2709
  | MatMulOp
3602
2710
  | EinsumOp
3603
2711
  | GemmOp
@@ -3615,6 +2723,7 @@ class CEmitter:
3615
2723
  | RMSNormalizationOp
3616
2724
  | LrnOp
3617
2725
  | LstmOp
2726
+ | AdagradOp
3618
2727
  | SoftmaxOp
3619
2728
  | LogSoftmaxOp
3620
2729
  | HardmaxOp
@@ -3626,6 +2735,7 @@ class CEmitter:
3626
2735
  | GatherOp
3627
2736
  | GatherNDOp
3628
2737
  | ScatterNDOp
2738
+ | TensorScatterOp
3629
2739
  | TransposeOp
3630
2740
  | ReshapeOp
3631
2741
  | IdentityOp
@@ -3644,6 +2754,7 @@ class CEmitter:
3644
2754
  | ShapeOp
3645
2755
  | SizeOp
3646
2756
  | NonZeroOp
2757
+ | NonMaxSuppressionOp
3647
2758
  | ExpandOp
3648
2759
  | CumSumOp
3649
2760
  | RangeOp
@@ -3830,6 +2941,7 @@ class CEmitter:
3830
2941
  | ClipOp
3831
2942
  | CastOp
3832
2943
  | QuantizeLinearOp
2944
+ | QLinearMatMulOp
3833
2945
  | MatMulOp
3834
2946
  | EinsumOp
3835
2947
  | GemmOp
@@ -3847,6 +2959,7 @@ class CEmitter:
3847
2959
  | RMSNormalizationOp
3848
2960
  | LrnOp
3849
2961
  | LstmOp
2962
+ | AdagradOp
3850
2963
  | SoftmaxOp
3851
2964
  | LogSoftmaxOp
3852
2965
  | HardmaxOp
@@ -3858,6 +2971,7 @@ class CEmitter:
3858
2971
  | GatherOp
3859
2972
  | GatherNDOp
3860
2973
  | ScatterNDOp
2974
+ | TensorScatterOp
3861
2975
  | TransposeOp
3862
2976
  | ReshapeOp
3863
2977
  | IdentityOp
@@ -3876,6 +2990,7 @@ class CEmitter:
3876
2990
  | ShapeOp
3877
2991
  | SizeOp
3878
2992
  | NonZeroOp
2993
+ | NonMaxSuppressionOp
3879
2994
  | ExpandOp
3880
2995
  | CumSumOp
3881
2996
  | RangeOp
@@ -3948,6 +3063,7 @@ class CEmitter:
3948
3063
  RMSNormalizationOp,
3949
3064
  LrnOp,
3950
3065
  LstmOp,
3066
+ AdagradOp,
3951
3067
  SoftmaxOp,
3952
3068
  LogSoftmaxOp,
3953
3069
  SoftmaxCrossEntropyLossOp,
@@ -3977,7 +3093,7 @@ class CEmitter:
3977
3093
  ):
3978
3094
  return True
3979
3095
  if any(
3980
- isinstance(op, (LpPoolOp, QuantizeLinearOp))
3096
+ isinstance(op, (LpPoolOp, QuantizeLinearOp, QLinearMatMulOp))
3981
3097
  for op in resolved_ops
3982
3098
  ):
3983
3099
  return True
@@ -3991,6 +3107,7 @@ class CEmitter:
3991
3107
  | ClipOp
3992
3108
  | CastOp
3993
3109
  | QuantizeLinearOp
3110
+ | QLinearMatMulOp
3994
3111
  | MatMulOp
3995
3112
  | EinsumOp
3996
3113
  | GemmOp
@@ -4036,6 +3153,7 @@ class CEmitter:
4036
3153
  | ShapeOp
4037
3154
  | SizeOp
4038
3155
  | NonZeroOp
3156
+ | NonMaxSuppressionOp
4039
3157
  | ExpandOp
4040
3158
  | CumSumOp
4041
3159
  | RangeOp
@@ -4070,10 +3188,13 @@ class CEmitter:
4070
3188
  ):
4071
3189
  return True
4072
3190
  if any(
4073
- isinstance(op, QuantizeLinearOp) and op.dtype.is_integer
3191
+ isinstance(op, (QuantizeLinearOp, QLinearMatMulOp))
3192
+ and op.dtype.is_integer
4074
3193
  for op in resolved_ops
4075
3194
  ):
4076
3195
  return True
3196
+ if any(isinstance(op, NonMaxSuppressionOp) for op in resolved_ops):
3197
+ return True
4077
3198
  return False
4078
3199
 
4079
3200
  def _emit_model_wrapper(
@@ -4086,6 +3207,7 @@ class CEmitter:
4086
3207
  | ClipOp
4087
3208
  | CastOp
4088
3209
  | QuantizeLinearOp
3210
+ | QLinearMatMulOp
4089
3211
  | MatMulOp
4090
3212
  | EinsumOp
4091
3213
  | GemmOp
@@ -4131,6 +3253,7 @@ class CEmitter:
4131
3253
  | ShapeOp
4132
3254
  | SizeOp
4133
3255
  | NonZeroOp
3256
+ | NonMaxSuppressionOp
4134
3257
  | ExpandOp
4135
3258
  | CumSumOp
4136
3259
  | RangeOp
@@ -4195,10 +3318,12 @@ class CEmitter:
4195
3318
  | ClipOp
4196
3319
  | CastOp
4197
3320
  | QuantizeLinearOp
3321
+ | QLinearMatMulOp
4198
3322
  | MatMulOp
4199
3323
  | EinsumOp
4200
3324
  | GemmOp
4201
3325
  | AttentionOp
3326
+ | RotaryEmbeddingOp
4202
3327
  | ConvOp
4203
3328
  | ConvTransposeOp
4204
3329
  | AveragePoolOp
@@ -4212,6 +3337,7 @@ class CEmitter:
4212
3337
  | RMSNormalizationOp
4213
3338
  | LrnOp
4214
3339
  | LstmOp
3340
+ | AdagradOp
4215
3341
  | SoftmaxOp
4216
3342
  | LogSoftmaxOp
4217
3343
  | HardmaxOp
@@ -4223,6 +3349,7 @@ class CEmitter:
4223
3349
  | GatherOp
4224
3350
  | GatherNDOp
4225
3351
  | ScatterNDOp
3352
+ | TensorScatterOp
4226
3353
  | TransposeOp
4227
3354
  | ReshapeOp
4228
3355
  | IdentityOp
@@ -4242,6 +3369,7 @@ class CEmitter:
4242
3369
  | ShapeOp
4243
3370
  | SizeOp
4244
3371
  | NonZeroOp
3372
+ | NonMaxSuppressionOp
4245
3373
  | ExpandOp
4246
3374
  | CumSumOp
4247
3375
  | RangeOp
@@ -4261,6 +3389,21 @@ class CEmitter:
4261
3389
  if isinstance(op, WhereOp):
4262
3390
  args.extend([op.condition, op.input_x, op.input_y, op.output])
4263
3391
  return ", ".join(args)
3392
+ if isinstance(op, QLinearMatMulOp):
3393
+ args.extend(
3394
+ [
3395
+ op.input0,
3396
+ op.input0_scale,
3397
+ op.input0_zero_point,
3398
+ op.input1,
3399
+ op.input1_scale,
3400
+ op.input1_zero_point,
3401
+ op.output_scale,
3402
+ op.output_zero_point,
3403
+ op.output,
3404
+ ]
3405
+ )
3406
+ return ", ".join(args)
4264
3407
  if isinstance(op, MatMulOp):
4265
3408
  args.extend([op.input0, op.input1, op.output])
4266
3409
  return ", ".join(args)
@@ -4380,6 +3523,19 @@ class CEmitter:
4380
3523
  call_parts.append(op.output_y_c)
4381
3524
  args.extend(call_parts)
4382
3525
  return ", ".join(args)
3526
+ if isinstance(op, AdagradOp):
3527
+ args.extend(
3528
+ [
3529
+ op.rate,
3530
+ op.timestep,
3531
+ *op.inputs,
3532
+ *op.gradients,
3533
+ *op.accumulators,
3534
+ *op.outputs,
3535
+ *op.accumulator_outputs,
3536
+ ]
3537
+ )
3538
+ return ", ".join(args)
4383
3539
  if isinstance(op, (SoftmaxOp, LogSoftmaxOp, HardmaxOp)):
4384
3540
  args.extend([op.input0, op.output])
4385
3541
  return ", ".join(args)
@@ -4417,6 +3573,12 @@ class CEmitter:
4417
3573
  if isinstance(op, ScatterNDOp):
4418
3574
  args.extend([op.data, op.indices, op.updates, op.output])
4419
3575
  return ", ".join(args)
3576
+ if isinstance(op, TensorScatterOp):
3577
+ args.extend([op.past_cache, op.update])
3578
+ if op.write_indices is not None:
3579
+ args.append(op.write_indices)
3580
+ args.append(op.output)
3581
+ return ", ".join(args)
4420
3582
  if isinstance(op, ConcatOp):
4421
3583
  args.extend([*op.inputs, op.output])
4422
3584
  return ", ".join(args)
@@ -4432,6 +3594,17 @@ class CEmitter:
4432
3594
  if isinstance(op, NonZeroOp):
4433
3595
  args.extend([op.input0, op.output])
4434
3596
  return ", ".join(args)
3597
+ if isinstance(op, NonMaxSuppressionOp):
3598
+ call_parts = [op.boxes, op.scores]
3599
+ if op.max_output_boxes_per_class is not None:
3600
+ call_parts.append(op.max_output_boxes_per_class)
3601
+ if op.iou_threshold is not None:
3602
+ call_parts.append(op.iou_threshold)
3603
+ if op.score_threshold is not None:
3604
+ call_parts.append(op.score_threshold)
3605
+ call_parts.append(op.output)
3606
+ args.extend(call_parts)
3607
+ return ", ".join(args)
4435
3608
  if isinstance(op, ExpandOp):
4436
3609
  args.extend([op.input0, op.output])
4437
3610
  return ", ".join(args)
@@ -4566,10 +3739,12 @@ class CEmitter:
4566
3739
  | ClipOp
4567
3740
  | CastOp
4568
3741
  | QuantizeLinearOp
3742
+ | QLinearMatMulOp
4569
3743
  | MatMulOp
4570
3744
  | EinsumOp
4571
3745
  | GemmOp
4572
3746
  | AttentionOp
3747
+ | RotaryEmbeddingOp
4573
3748
  | ConvOp
4574
3749
  | ConvTransposeOp
4575
3750
  | AveragePoolOp
@@ -4583,6 +3758,7 @@ class CEmitter:
4583
3758
  | RMSNormalizationOp
4584
3759
  | LrnOp
4585
3760
  | LstmOp
3761
+ | AdagradOp
4586
3762
  | SoftmaxOp
4587
3763
  | LogSoftmaxOp
4588
3764
  | HardmaxOp
@@ -4594,6 +3770,7 @@ class CEmitter:
4594
3770
  | GatherOp
4595
3771
  | GatherNDOp
4596
3772
  | ScatterNDOp
3773
+ | TensorScatterOp
4597
3774
  | TransposeOp
4598
3775
  | ReshapeOp
4599
3776
  | IdentityOp
@@ -4612,6 +3789,7 @@ class CEmitter:
4612
3789
  | ShapeOp
4613
3790
  | SizeOp
4614
3791
  | NonZeroOp
3792
+ | NonMaxSuppressionOp
4615
3793
  | ExpandOp
4616
3794
  | CumSumOp
4617
3795
  | RangeOp
@@ -4626,6 +3804,7 @@ class CEmitter:
4626
3804
  | ClipOp
4627
3805
  | CastOp
4628
3806
  | QuantizeLinearOp
3807
+ | QLinearMatMulOp
4629
3808
  | MatMulOp
4630
3809
  | EinsumOp
4631
3810
  | GemmOp
@@ -4643,6 +3822,7 @@ class CEmitter:
4643
3822
  | RMSNormalizationOp
4644
3823
  | LrnOp
4645
3824
  | LstmOp
3825
+ | AdagradOp
4646
3826
  | SoftmaxOp
4647
3827
  | LogSoftmaxOp
4648
3828
  | HardmaxOp
@@ -4654,6 +3834,7 @@ class CEmitter:
4654
3834
  | GatherOp
4655
3835
  | GatherNDOp
4656
3836
  | ScatterNDOp
3837
+ | TensorScatterOp
4657
3838
  | TransposeOp
4658
3839
  | ReshapeOp
4659
3840
  | IdentityOp
@@ -4672,6 +3853,7 @@ class CEmitter:
4672
3853
  | ShapeOp
4673
3854
  | SizeOp
4674
3855
  | NonZeroOp
3856
+ | NonMaxSuppressionOp
4675
3857
  | ExpandOp
4676
3858
  | CumSumOp
4677
3859
  | RangeOp
@@ -4791,6 +3973,47 @@ class CEmitter:
4791
3973
  input_dtype=op.input_dtype,
4792
3974
  scale_dtype=op.scale_dtype,
4793
3975
  )
3976
+ if isinstance(op, QLinearMatMulOp):
3977
+ return QLinearMatMulOp(
3978
+ input0=temp_map.get(op.input0, op.input0),
3979
+ input0_scale=temp_map.get(op.input0_scale, op.input0_scale),
3980
+ input0_zero_point=temp_map.get(
3981
+ op.input0_zero_point, op.input0_zero_point
3982
+ ),
3983
+ input1=temp_map.get(op.input1, op.input1),
3984
+ input1_scale=temp_map.get(op.input1_scale, op.input1_scale),
3985
+ input1_zero_point=temp_map.get(
3986
+ op.input1_zero_point, op.input1_zero_point
3987
+ ),
3988
+ output_scale=temp_map.get(op.output_scale, op.output_scale),
3989
+ output_zero_point=temp_map.get(
3990
+ op.output_zero_point, op.output_zero_point
3991
+ ),
3992
+ output=temp_map.get(op.output, op.output),
3993
+ input0_shape=op.input0_shape,
3994
+ input1_shape=op.input1_shape,
3995
+ output_shape=op.output_shape,
3996
+ batch_shape=op.batch_shape,
3997
+ input0_batch_shape=op.input0_batch_shape,
3998
+ input1_batch_shape=op.input1_batch_shape,
3999
+ m=op.m,
4000
+ n=op.n,
4001
+ k=op.k,
4002
+ left_vector=op.left_vector,
4003
+ right_vector=op.right_vector,
4004
+ input0_dtype=op.input0_dtype,
4005
+ input1_dtype=op.input1_dtype,
4006
+ dtype=op.dtype,
4007
+ input0_scale_dtype=op.input0_scale_dtype,
4008
+ input1_scale_dtype=op.input1_scale_dtype,
4009
+ output_scale_dtype=op.output_scale_dtype,
4010
+ input0_scale_shape=op.input0_scale_shape,
4011
+ input1_scale_shape=op.input1_scale_shape,
4012
+ output_scale_shape=op.output_scale_shape,
4013
+ input0_zero_shape=op.input0_zero_shape,
4014
+ input1_zero_shape=op.input1_zero_shape,
4015
+ output_zero_shape=op.output_zero_shape,
4016
+ )
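Each newly supported op gets its own `_resolve_op` branch like the `QLinearMatMulOp` one above, rebuilding the dataclass field by field just to remap tensor names through `temp_map`. A possible simplification, not part of this diff and assuming the op classes are dataclasses that can declare which fields hold tensor names, would be a single generic remap:

```python
import dataclasses


# Hypothetical generic alternative to the per-op branches in _resolve_op:
# remap only the fields that hold tensor names; everything else is copied
# unchanged by dataclasses.replace. Illustrative only.
def remap_op_names(op, temp_map, name_fields):
    changes = {}
    for field in name_fields:
        value = getattr(op, field)
        if isinstance(value, str):
            changes[field] = temp_map.get(value, value)
        elif isinstance(value, tuple):
            changes[field] = tuple(temp_map.get(v, v) for v in value)
    return dataclasses.replace(op, **changes)
```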
4794
4017
  if isinstance(op, GemmOp):
4795
4018
  return GemmOp(
4796
4019
  input_a=temp_map.get(op.input_a, op.input_a),
@@ -4885,6 +4108,32 @@ class CEmitter:
4885
4108
  head_group_size=op.head_group_size,
4886
4109
  dtype=op.dtype,
4887
4110
  )
4111
+ if isinstance(op, RotaryEmbeddingOp):
4112
+ return RotaryEmbeddingOp(
4113
+ input0=temp_map.get(op.input0, op.input0),
4114
+ cos_cache=temp_map.get(op.cos_cache, op.cos_cache),
4115
+ sin_cache=temp_map.get(op.sin_cache, op.sin_cache),
4116
+ position_ids=(
4117
+ temp_map.get(op.position_ids, op.position_ids)
4118
+ if op.position_ids is not None
4119
+ else None
4120
+ ),
4121
+ output=temp_map.get(op.output, op.output),
4122
+ input_shape=op.input_shape,
4123
+ cos_shape=op.cos_shape,
4124
+ sin_shape=op.sin_shape,
4125
+ position_ids_shape=op.position_ids_shape,
4126
+ dtype=op.dtype,
4127
+ position_ids_dtype=op.position_ids_dtype,
4128
+ rotary_dim=op.rotary_dim,
4129
+ rotary_dim_half=op.rotary_dim_half,
4130
+ head_size=op.head_size,
4131
+ num_heads=op.num_heads,
4132
+ seq_len=op.seq_len,
4133
+ batch=op.batch,
4134
+ input_rank=op.input_rank,
4135
+ interleaved=op.interleaved,
4136
+ )
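`RotaryEmbeddingOp` carries cos/sin caches, optional `position_ids`, `rotary_dim`/`rotary_dim_half`, and an `interleaved` flag, so the `rotary_embedding` template presumably applies the usual RoPE rotation to the first `rotary_dim` channels of each head and passes the remaining channels through unchanged. A sketch for the non-interleaved ("half") layout; the interleaved variant rotates adjacent channel pairs instead:

```python
import numpy as np


def rope_half_layout(x, cos, sin, rotary_dim):
    """RoPE sketch, non-interleaved layout.
    x: (batch, heads, seq, head_size); cos/sin: (seq, rotary_dim // 2)."""
    half = rotary_dim // 2
    x1 = x[..., :half]
    x2 = x[..., half:rotary_dim]
    rot1 = x1 * cos - x2 * sin   # rotate each (x1, x2) pair by the cached angle
    rot2 = x1 * sin + x2 * cos
    return np.concatenate([rot1, rot2, x[..., rotary_dim:]], axis=-1)
```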
4888
4137
  if isinstance(op, LstmOp):
4889
4138
  return LstmOp(
4890
4139
  input_x=temp_map.get(op.input_x, op.input_x),
@@ -4945,6 +4194,33 @@ class CEmitter:
4945
4194
  dtype=op.dtype,
4946
4195
  sequence_lens_dtype=op.sequence_lens_dtype,
4947
4196
  )
4197
+ if isinstance(op, AdagradOp):
4198
+ return AdagradOp(
4199
+ rate=temp_map.get(op.rate, op.rate),
4200
+ timestep=temp_map.get(op.timestep, op.timestep),
4201
+ inputs=tuple(temp_map.get(name, name) for name in op.inputs),
4202
+ gradients=tuple(
4203
+ temp_map.get(name, name) for name in op.gradients
4204
+ ),
4205
+ accumulators=tuple(
4206
+ temp_map.get(name, name) for name in op.accumulators
4207
+ ),
4208
+ outputs=tuple(temp_map.get(name, name) for name in op.outputs),
4209
+ accumulator_outputs=tuple(
4210
+ temp_map.get(name, name)
4211
+ for name in op.accumulator_outputs
4212
+ ),
4213
+ rate_shape=op.rate_shape,
4214
+ timestep_shape=op.timestep_shape,
4215
+ tensor_shapes=op.tensor_shapes,
4216
+ output_shapes=op.output_shapes,
4217
+ dtype=op.dtype,
4218
+ rate_dtype=op.rate_dtype,
4219
+ timestep_dtype=op.timestep_dtype,
4220
+ norm_coefficient=op.norm_coefficient,
4221
+ epsilon=op.epsilon,
4222
+ decay_factor=op.decay_factor,
4223
+ )
4948
4224
  if isinstance(op, ConvOp):
4949
4225
  return ConvOp(
4950
4226
  input0=temp_map.get(op.input0, op.input0),
@@ -5301,6 +4577,25 @@ class CEmitter:
5301
4577
  dtype=op.dtype,
5302
4578
  indices_dtype=op.indices_dtype,
5303
4579
  )
4580
+ if isinstance(op, TensorScatterOp):
4581
+ return TensorScatterOp(
4582
+ past_cache=temp_map.get(op.past_cache, op.past_cache),
4583
+ update=temp_map.get(op.update, op.update),
4584
+ write_indices=(
4585
+ temp_map.get(op.write_indices, op.write_indices)
4586
+ if op.write_indices is not None
4587
+ else None
4588
+ ),
4589
+ output=temp_map.get(op.output, op.output),
4590
+ past_cache_shape=op.past_cache_shape,
4591
+ update_shape=op.update_shape,
4592
+ output_shape=op.output_shape,
4593
+ write_indices_shape=op.write_indices_shape,
4594
+ axis=op.axis,
4595
+ mode=op.mode,
4596
+ dtype=op.dtype,
4597
+ write_indices_dtype=op.write_indices_dtype,
4598
+ )
5304
4599
  if isinstance(op, ConcatOp):
5305
4600
  return ConcatOp(
5306
4601
  inputs=tuple(temp_map.get(name, name) for name in op.inputs),
@@ -5349,6 +4644,33 @@ class CEmitter:
5349
4644
  dtype=op.dtype,
5350
4645
  input_dtype=op.input_dtype,
5351
4646
  )
4647
+ if isinstance(op, NonMaxSuppressionOp):
4648
+ return NonMaxSuppressionOp(
4649
+ boxes=temp_map.get(op.boxes, op.boxes),
4650
+ scores=temp_map.get(op.scores, op.scores),
4651
+ max_output_boxes_per_class=CEmitter._map_optional_name(
4652
+ temp_map, op.max_output_boxes_per_class
4653
+ ),
4654
+ iou_threshold=CEmitter._map_optional_name(
4655
+ temp_map, op.iou_threshold
4656
+ ),
4657
+ score_threshold=CEmitter._map_optional_name(
4658
+ temp_map, op.score_threshold
4659
+ ),
4660
+ output=temp_map.get(op.output, op.output),
4661
+ boxes_shape=op.boxes_shape,
4662
+ scores_shape=op.scores_shape,
4663
+ output_shape=op.output_shape,
4664
+ center_point_box=op.center_point_box,
4665
+ boxes_dtype=op.boxes_dtype,
4666
+ output_dtype=op.output_dtype,
4667
+ max_output_dtype=op.max_output_dtype,
4668
+ max_output_shape=op.max_output_shape,
4669
+ iou_threshold_dtype=op.iou_threshold_dtype,
4670
+ iou_threshold_shape=op.iou_threshold_shape,
4671
+ score_threshold_dtype=op.score_threshold_dtype,
4672
+ score_threshold_shape=op.score_threshold_shape,
4673
+ )
5352
4674
  if isinstance(op, ExpandOp):
5353
4675
  return ExpandOp(
5354
4676
  input0=temp_map.get(op.input0, op.input0),
@@ -5673,67 +4995,98 @@ class CEmitter:
5673
4995
  dtype=op.dtype,
5674
4996
  )
5675
4997
 
4998
+ def render_op(self, op: OpBase, ctx: EmitContext) -> str:
4999
+ if self._emit_state is None:
5000
+ raise CodegenError("Emitter state not initialized")
5001
+ state = self._emit_state
5002
+ dtype = self._op_output_dtype(op)
5003
+ templates = state.templates
5004
+ return self._render_op(
5005
+ state.model,
5006
+ op,
5007
+ ctx.op_index,
5008
+ array_suffix="",
5009
+ loop_vars=(),
5010
+ c_type=dtype.c_type,
5011
+ zero_literal=dtype.zero_literal,
5012
+ min_literal=dtype.min_literal,
5013
+ max_literal=dtype.max_literal,
5014
+ binary_template=templates["binary"],
5015
+ multi_input_template=templates["multi_input"],
5016
+ where_template=templates["where"],
5017
+ unary_template=templates["unary"],
5018
+ clip_template=templates["clip"],
5019
+ cast_template=templates["cast"],
5020
+ quantize_linear_template=templates["quantize_linear"],
5021
+ qlinear_matmul_template=templates["qlinear_matmul"],
5022
+ matmul_template=templates["matmul"],
5023
+ einsum_template=templates["einsum"],
5024
+ gemm_template=templates["gemm"],
5025
+ attention_template=templates["attention"],
5026
+ rotary_embedding_template=templates["rotary_embedding"],
5027
+ conv_template=templates["conv"],
5028
+ conv_transpose_template=templates["conv_transpose"],
5029
+ avg_pool_template=templates["avg_pool"],
5030
+ lp_pool_template=templates["lp_pool"],
5031
+ batch_norm_template=templates["batch_norm"],
5032
+ lp_norm_template=templates["lp_norm"],
5033
+ instance_norm_template=templates["instance_norm"],
5034
+ group_norm_template=templates["group_norm"],
5035
+ layer_norm_template=templates["layer_norm"],
5036
+ mean_variance_norm_template=templates["mean_variance_norm"],
5037
+ rms_norm_template=templates["rms_norm"],
5038
+ lrn_template=templates["lrn"],
5039
+ lstm_template=templates["lstm"],
5040
+ adagrad_template=templates["adagrad"],
5041
+ softmax_template=templates["softmax"],
5042
+ logsoftmax_template=templates["logsoftmax"],
5043
+ hardmax_template=templates["hardmax"],
5044
+ nllloss_template=templates["nllloss"],
5045
+ softmax_cross_entropy_loss_template=templates[
5046
+ "softmax_cross_entropy_loss"
5047
+ ],
5048
+ maxpool_template=templates["maxpool"],
5049
+ concat_template=templates["concat"],
5050
+ gather_elements_template=templates["gather_elements"],
5051
+ gather_template=templates["gather"],
5052
+ gather_nd_template=templates["gather_nd"],
5053
+ scatter_nd_template=templates["scatter_nd"],
5054
+ transpose_template=templates["transpose"],
5055
+ reshape_template=templates["reshape"],
5056
+ identity_template=templates["identity"],
5057
+ eye_like_template=templates["eye_like"],
5058
+ trilu_template=templates["trilu"],
5059
+ tile_template=templates["tile"],
5060
+ pad_template=templates["pad"],
5061
+ depth_to_space_template=templates["depth_to_space"],
5062
+ space_to_depth_template=templates["space_to_depth"],
5063
+ slice_template=templates["slice"],
5064
+ slice_dynamic_template=templates["slice_dynamic"],
5065
+ resize_template=templates["resize"],
5066
+ grid_sample_template=templates["grid_sample"],
5067
+ reduce_template=templates["reduce"],
5068
+ reduce_dynamic_template=templates["reduce_dynamic"],
5069
+ arg_reduce_template=templates["arg_reduce"],
5070
+ topk_template=templates["topk"],
5071
+ constant_of_shape_template=templates["constant_of_shape"],
5072
+ shape_template=templates["shape"],
5073
+ size_template=templates["size"],
5074
+ nonzero_template=templates["nonzero"],
5075
+ nonmax_suppression_template=templates["nonmax_suppression"],
5076
+ expand_template=templates["expand"],
5077
+ cumsum_template=templates["cumsum"],
5078
+ range_template=templates["range"],
5079
+ one_hot_template=templates["one_hot"],
5080
+ split_template=templates["split"],
5081
+ scalar_registry=state.scalar_registry,
5082
+ dim_args=state.dim_args,
5083
+ tensor_dim_names=state.tensor_dim_names,
5084
+ )
5085
+
5676
5086
  def _render_op(
5677
5087
  self,
5678
5088
  model: LoweredModel,
5679
- op: BinaryOp
5680
- | MultiInputBinaryOp
5681
- | WhereOp
5682
- | UnaryOp
5683
- | ClipOp
5684
- | CastOp
5685
- | QuantizeLinearOp
5686
- | MatMulOp
5687
- | EinsumOp
5688
- | GemmOp
5689
- | AttentionOp
5690
- | ConvOp
5691
- | ConvTransposeOp
5692
- | AveragePoolOp
5693
- | LpPoolOp
5694
- | BatchNormOp
5695
- | LpNormalizationOp
5696
- | InstanceNormalizationOp
5697
- | GroupNormalizationOp
5698
- | LayerNormalizationOp
5699
- | MeanVarianceNormalizationOp
5700
- | RMSNormalizationOp
5701
- | LrnOp
5702
- | LstmOp
5703
- | SoftmaxOp
5704
- | LogSoftmaxOp
5705
- | HardmaxOp
5706
- | NegativeLogLikelihoodLossOp
5707
- | SoftmaxCrossEntropyLossOp
5708
- | MaxPoolOp
5709
- | ConcatOp
5710
- | GatherElementsOp
5711
- | GatherOp
5712
- | GatherNDOp
5713
- | ScatterNDOp
5714
- | TransposeOp
5715
- | ReshapeOp
5716
- | IdentityOp
5717
- | EyeLikeOp
5718
- | TriluOp
5719
- | TileOp
5720
- | DepthToSpaceOp
5721
- | SpaceToDepthOp
5722
- | SliceOp
5723
- | ResizeOp
5724
- | GridSampleOp
5725
- | ReduceOp
5726
- | ArgReduceOp
5727
- | TopKOp
5728
- | ConstantOfShapeOp
5729
- | ShapeOp
5730
- | SizeOp
5731
- | NonZeroOp
5732
- | ExpandOp
5733
- | CumSumOp
5734
- | RangeOp
5735
- | OneHotOp
5736
- | SplitOp,
5089
+ op: OpBase,
5737
5090
  index: int,
5738
5091
  *,
5739
5092
  array_suffix: str,
@@ -5749,10 +5102,12 @@ class CEmitter:
5749
5102
  clip_template,
5750
5103
  cast_template,
5751
5104
  quantize_linear_template,
5105
+ qlinear_matmul_template,
5752
5106
  matmul_template,
5753
5107
  einsum_template,
5754
5108
  gemm_template,
5755
5109
  attention_template,
5110
+ rotary_embedding_template,
5756
5111
  conv_template,
5757
5112
  conv_transpose_template,
5758
5113
  avg_pool_template,
@@ -5766,6 +5121,7 @@ class CEmitter:
5766
5121
  rms_norm_template,
5767
5122
  lrn_template,
5768
5123
  lstm_template,
5124
+ adagrad_template,
5769
5125
  softmax_template,
5770
5126
  logsoftmax_template,
5771
5127
  hardmax_template,
@@ -5798,6 +5154,7 @@ class CEmitter:
5798
5154
  shape_template,
5799
5155
  size_template,
5800
5156
  nonzero_template,
5157
+ nonmax_suppression_template,
5801
5158
  expand_template,
5802
5159
  cumsum_template,
5803
5160
  range_template,
@@ -5819,6 +5176,11 @@ class CEmitter:
5819
5176
  return f"{node_comment}\n{_format_c_indentation(rendered)}"
5820
5177
 
5821
5178
  if isinstance(op, BinaryOp):
5179
+ input0_shape = self._ctx_shape(op.input0)
5180
+ input1_shape = self._ctx_shape(op.input1)
5181
+ output_shape = self._ctx_shape(op.output)
5182
+ input_dtype = self._ctx_dtype(op.input0)
5183
+ output_dtype = self._ctx_dtype(op.output)
5822
5184
  params = self._shared_param_map(
5823
5185
  [
5824
5186
  ("input0", op.input0),
@@ -5832,11 +5194,11 @@ class CEmitter:
5832
5194
  and op.function not in COMPARE_FUNCTIONS
5833
5195
  ):
5834
5196
  scalar_operator = self._scalar_function_name(
5835
- op.function, op.input_dtype, scalar_registry
5197
+ op.function, input_dtype, scalar_registry
5836
5198
  )
5837
5199
  op_spec = binary_op_symbol(
5838
5200
  op.function,
5839
- dtype=op.input_dtype,
5201
+ dtype=input_dtype,
5840
5202
  validate_attrs=False,
5841
5203
  )
5842
5204
  if op_spec is None:
@@ -5844,17 +5206,19 @@ class CEmitter:
5844
5206
  f"Unsupported binary operator for rendering: {op.function.value}"
5845
5207
  )
5846
5208
  output_dim_names = _dim_names_for(op.output)
5847
- shape = CEmitter._shape_dim_exprs(op.shape, output_dim_names)
5848
- loop_vars = CEmitter._loop_vars(op.shape)
5849
- output_suffix = self._param_array_suffix(op.shape, output_dim_names)
5209
+ shape = CEmitter._shape_dim_exprs(output_shape, output_dim_names)
5210
+ loop_vars = CEmitter._loop_vars(output_shape)
5211
+ output_suffix = self._param_array_suffix(
5212
+ output_shape, output_dim_names
5213
+ )
5850
5214
  input0_suffix = self._param_array_suffix(
5851
- op.input0_shape, _dim_names_for(op.input0)
5215
+ input0_shape, _dim_names_for(op.input0)
5852
5216
  )
5853
5217
  input1_suffix = self._param_array_suffix(
5854
- op.input1_shape, _dim_names_for(op.input1)
5218
+ input1_shape, _dim_names_for(op.input1)
5855
5219
  )
5856
- input_c_type = op.input_dtype.c_type
5857
- output_c_type = op.dtype.c_type
5220
+ input_c_type = input_dtype.c_type
5221
+ output_c_type = output_dtype.c_type
5858
5222
  param_decls = self._build_param_decls(
5859
5223
  [
5860
5224
  (params["input0"], input_c_type, input0_suffix, True),
@@ -5877,14 +5241,14 @@ class CEmitter:
5877
5241
  }
5878
5242
  left_expr = CEmitter._broadcast_index_expr(
5879
5243
  params["input0"],
5880
- op.input0_shape,
5881
- op.shape,
5244
+ input0_shape,
5245
+ output_shape,
5882
5246
  loop_vars,
5883
5247
  )
5884
5248
  right_expr = CEmitter._broadcast_index_expr(
5885
5249
  params["input1"],
5886
- op.input1_shape,
5887
- op.shape,
5250
+ input1_shape,
5251
+ output_shape,
5888
5252
  loop_vars,
5889
5253
  )
5890
5254
  operator_expr = None
@@ -5910,6 +5274,9 @@ class CEmitter:
5910
5274
  ).rstrip()
5911
5275
  return with_node_comment(rendered)
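The BinaryOp branch now reads shapes through `_ctx_shape` and builds its index expressions with `_broadcast_index_expr`. The underlying rule is standard ONNX/numpy broadcasting: right-align the input shape against the output shape and index with 0 on any dimension where the input has size 1. A small illustrative helper (the name, argument order, and output format here are assumptions, not the package's actual `_broadcast_index_expr`):

```python
def broadcast_index_expr(name, input_shape, output_shape, loop_vars):
    """Build e.g. "input1[0][i2]" for an input broadcast against output_shape."""
    # Right-align the input shape against the output shape (numpy-style).
    offset = len(output_shape) - len(input_shape)
    parts = []
    for axis, dim in enumerate(input_shape):
        var = loop_vars[axis + offset]
        parts.append("[0]" if dim == 1 else f"[{var}]")
    return name + "".join(parts)


# broadcast_index_expr("input1", (1, 4), (2, 3, 4), ("i0", "i1", "i2"))
# -> "input1[0][i2]"
```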
5912
5276
  if isinstance(op, MultiInputBinaryOp):
5277
+ output_shape = self._ctx_shape(op.output)
5278
+ input_dtype = self._ctx_dtype(op.inputs[0])
5279
+ output_dtype = self._ctx_dtype(op.output)
5913
5280
  params = self._shared_param_map(
5914
5281
  [
5915
5282
  *( (f"input{idx}", name) for idx, name in enumerate(op.inputs) ),
@@ -5923,11 +5290,11 @@ class CEmitter:
5923
5290
  and op.function != ScalarFunction.MEAN
5924
5291
  ):
5925
5292
  scalar_operator = self._scalar_function_name(
5926
- op.function, op.input_dtype, scalar_registry
5293
+ op.function, input_dtype, scalar_registry
5927
5294
  )
5928
5295
  op_spec = binary_op_symbol(
5929
5296
  op.function,
5930
- dtype=op.input_dtype,
5297
+ dtype=input_dtype,
5931
5298
  validate_attrs=False,
5932
5299
  )
5933
5300
  if op_spec is None:
@@ -5936,11 +5303,13 @@ class CEmitter:
5936
5303
  f"{op.function.value}"
5937
5304
  )
5938
5305
  output_dim_names = _dim_names_for(op.output)
5939
- shape = CEmitter._shape_dim_exprs(op.shape, output_dim_names)
5940
- loop_vars = CEmitter._loop_vars(op.shape)
5941
- array_suffix = self._param_array_suffix(op.shape, output_dim_names)
5942
- input_c_type = op.input_dtype.c_type
5943
- output_c_type = op.dtype.c_type
5306
+ shape = CEmitter._shape_dim_exprs(output_shape, output_dim_names)
5307
+ loop_vars = CEmitter._loop_vars(output_shape)
5308
+ array_suffix = self._param_array_suffix(
5309
+ output_shape, output_dim_names
5310
+ )
5311
+ input_c_type = input_dtype.c_type
5312
+ output_c_type = output_dtype.c_type
5944
5313
  input_names = [
5945
5314
  params[f"input{idx}"] for idx in range(len(op.inputs))
5946
5315
  ]
@@ -5999,6 +5368,11 @@ class CEmitter:
5999
5368
  ).rstrip()
6000
5369
  return with_node_comment(rendered)
6001
5370
  if isinstance(op, WhereOp):
5371
+ output_shape_raw = self._ctx_shape(op.output)
5372
+ condition_shape = self._ctx_shape(op.condition)
5373
+ x_shape = self._ctx_shape(op.input_x)
5374
+ y_shape = self._ctx_shape(op.input_y)
5375
+ output_dtype = self._ctx_dtype(op.output)
6002
5376
  params = self._shared_param_map(
6003
5377
  [
6004
5378
  ("condition", op.condition),
@@ -6009,32 +5383,32 @@ class CEmitter:
6009
5383
  )
6010
5384
  output_dim_names = _dim_names_for(op.output)
6011
5385
  output_shape = CEmitter._shape_dim_exprs(
6012
- op.output_shape, output_dim_names
5386
+ output_shape_raw, output_dim_names
6013
5387
  )
6014
- loop_vars = CEmitter._loop_vars(op.output_shape)
5388
+ loop_vars = CEmitter._loop_vars(output_shape_raw)
6015
5389
  output_array_suffix = self._param_array_suffix(
6016
- op.output_shape, output_dim_names
5390
+ output_shape_raw, output_dim_names
6017
5391
  )
6018
5392
  condition_array_suffix = self._param_array_suffix(
6019
- op.condition_shape, _dim_names_for(op.condition)
5393
+ condition_shape, _dim_names_for(op.condition)
6020
5394
  )
6021
5395
  x_array_suffix = self._param_array_suffix(
6022
- op.x_shape, _dim_names_for(op.input_x)
5396
+ x_shape, _dim_names_for(op.input_x)
6023
5397
  )
6024
5398
  y_array_suffix = self._param_array_suffix(
6025
- op.y_shape, _dim_names_for(op.input_y)
5399
+ y_shape, _dim_names_for(op.input_y)
6026
5400
  )
6027
5401
  condition_expr = CEmitter._broadcast_index_expr(
6028
5402
  params["condition"],
6029
- op.condition_shape,
6030
- op.output_shape,
5403
+ condition_shape,
5404
+ output_shape_raw,
6031
5405
  loop_vars,
6032
5406
  )
6033
5407
  x_expr = CEmitter._broadcast_index_expr(
6034
- params["input_x"], op.x_shape, op.output_shape, loop_vars
5408
+ params["input_x"], x_shape, output_shape_raw, loop_vars
6035
5409
  )
6036
5410
  y_expr = CEmitter._broadcast_index_expr(
6037
- params["input_y"], op.y_shape, op.output_shape, loop_vars
5411
+ params["input_y"], y_shape, output_shape_raw, loop_vars
6038
5412
  )
6039
5413
  output_expr = f"{params['output']}" + "".join(
6040
5414
  f"[{var}]" for var in loop_vars
@@ -6047,11 +5421,11 @@ class CEmitter:
6047
5421
  condition_array_suffix,
6048
5422
  True,
6049
5423
  ),
6050
- (params["input_x"], op.dtype.c_type, x_array_suffix, True),
6051
- (params["input_y"], op.dtype.c_type, y_array_suffix, True),
5424
+ (params["input_x"], output_dtype.c_type, x_array_suffix, True),
5425
+ (params["input_y"], output_dtype.c_type, y_array_suffix, True),
6052
5426
  (
6053
5427
  params["output"],
6054
- op.dtype.c_type,
5428
+ output_dtype.c_type,
6055
5429
  output_array_suffix,
6056
5430
  False,
6057
5431
  ),
@@ -6074,8 +5448,8 @@ class CEmitter:
6074
5448
  x_expr=x_expr,
6075
5449
  y_expr=y_expr,
6076
5450
  output_expr=output_expr,
6077
- input_c_type=op.dtype.c_type,
6078
- output_c_type=op.dtype.c_type,
5451
+ input_c_type=output_dtype.c_type,
5452
+ output_c_type=output_dtype.c_type,
6079
5453
  condition_c_type=ScalarType.BOOL.c_type,
6080
5454
  dim_args=dim_args,
6081
5455
  params=param_decls,
@@ -6363,6 +5737,17 @@ class CEmitter:
6363
5737
  ).rstrip()
6364
5738
  return with_node_comment(rendered)
6365
5739
  if isinstance(op, AttentionOp):
5740
+ if scalar_registry is None:
5741
+ raise CodegenError(
5742
+ "Scalar function registry is required for Attention codegen."
5743
+ )
5744
+ max_fn = self._scalar_function_name(
5745
+ ScalarFunction.MAXIMUM, op.dtype, scalar_registry
5746
+ )
5747
+ if max_fn is None:
5748
+ raise CodegenError(
5749
+ "Failed to resolve scalar maximum function for Attention."
5750
+ )
6366
5751
  params = self._shared_param_map(
6367
5752
  [
6368
5753
  ("input_q", op.input_q),
@@ -6543,6 +5928,7 @@ class CEmitter:
6543
5928
  scale_literal=CEmitter._format_floating(op.scale, op.dtype),
6544
5929
  softcap_literal=CEmitter._format_floating(op.softcap, op.dtype),
6545
5930
  one_literal=CEmitter._format_literal(op.dtype, 1),
5931
+ max_fn=max_fn,
6546
5932
  exp_fn=CEmitter._math_fn(op.dtype, "expf", "exp"),
6547
5933
  tanh_fn=CEmitter._math_fn(op.dtype, "tanhf", "tanh"),
6548
5934
  is_causal=int(op.is_causal),
@@ -6580,9 +5966,74 @@ class CEmitter:
6580
5966
  input_past_value_suffix=input_past_value_suffix,
6581
5967
  input_nonpad_suffix=input_nonpad_suffix,
6582
5968
  output_suffix=output_suffix,
6583
- output_present_key_suffix=output_present_key_suffix,
6584
- output_present_value_suffix=output_present_value_suffix,
6585
- output_qk_matmul_suffix=output_qk_matmul_suffix,
5969
+ output_present_key_suffix=output_present_key_suffix,
5970
+ output_present_value_suffix=output_present_value_suffix,
5971
+ output_qk_matmul_suffix=output_qk_matmul_suffix,
5972
+ ).rstrip()
5973
+ return with_node_comment(rendered)
5974
+ if isinstance(op, RotaryEmbeddingOp):
5975
+ params = self._shared_param_map(
5976
+ [
5977
+ ("input0", op.input0),
5978
+ ("cos_cache", op.cos_cache),
5979
+ ("sin_cache", op.sin_cache),
5980
+ ("position_ids", op.position_ids),
5981
+ ("output", op.output),
5982
+ ]
5983
+ )
5984
+ input_suffix = self._param_array_suffix(
5985
+ op.input_shape, _dim_names_for(op.input0)
5986
+ )
5987
+ cos_suffix = self._param_array_suffix(op.cos_shape)
5988
+ sin_suffix = self._param_array_suffix(op.sin_shape)
5989
+ position_suffix = (
5990
+ self._param_array_suffix(op.position_ids_shape)
5991
+ if op.position_ids_shape is not None
5992
+ else ""
5993
+ )
5994
+ output_suffix = self._param_array_suffix(
5995
+ op.input_shape, _dim_names_for(op.output)
5996
+ )
5997
+ param_decls = self._build_param_decls(
5998
+ [
5999
+ (params["input0"], c_type, input_suffix, True),
6000
+ (params["cos_cache"], c_type, cos_suffix, True),
6001
+ (params["sin_cache"], c_type, sin_suffix, True),
6002
+ (
6003
+ params["position_ids"],
6004
+ op.position_ids_dtype.c_type,
6005
+ position_suffix,
6006
+ True,
6007
+ )
6008
+ if params["position_ids"]
6009
+ else (None, "", "", True),
6010
+ (params["output"], c_type, output_suffix, False),
6011
+ ]
6012
+ )
6013
+ rendered = rotary_embedding_template.render(
6014
+ model_name=model.name,
6015
+ op_name=op_name,
6016
+ input0=params["input0"],
6017
+ cos_cache=params["cos_cache"],
6018
+ sin_cache=params["sin_cache"],
6019
+ position_ids=params["position_ids"],
6020
+ output=params["output"],
6021
+ params=param_decls,
6022
+ c_type=c_type,
6023
+ input_suffix=input_suffix,
6024
+ cos_suffix=cos_suffix,
6025
+ sin_suffix=sin_suffix,
6026
+ position_suffix=position_suffix,
6027
+ output_suffix=output_suffix,
6028
+ batch=op.batch,
6029
+ seq_len=op.seq_len,
6030
+ num_heads=op.num_heads,
6031
+ head_size=op.head_size,
6032
+ rotary_dim=op.rotary_dim,
6033
+ rotary_dim_half=op.rotary_dim_half,
6034
+ input_rank=op.input_rank,
6035
+ interleaved=int(op.interleaved),
6036
+ has_position_ids=int(op.position_ids is not None),
6586
6037
  ).rstrip()
6587
6038
  return with_node_comment(rendered)
6588
6039
  if isinstance(op, ConvOp):
@@ -7432,15 +6883,142 @@ class CEmitter:
7432
6883
  activation_functions=activation_functions,
7433
6884
  ).rstrip()
7434
6885
  return with_node_comment(rendered)
6886
+ if isinstance(op, AdagradOp):
6887
+ params = self._shared_param_map(
6888
+ [
6889
+ ("rate", op.rate),
6890
+ ("timestep", op.timestep),
6891
+ *(
6892
+ (f"input{idx}", name)
6893
+ for idx, name in enumerate(op.inputs)
6894
+ ),
6895
+ *(
6896
+ (f"grad{idx}", name)
6897
+ for idx, name in enumerate(op.gradients)
6898
+ ),
6899
+ *(
6900
+ (f"acc{idx}", name)
6901
+ for idx, name in enumerate(op.accumulators)
6902
+ ),
6903
+ *(
6904
+ (f"output{idx}", name)
6905
+ for idx, name in enumerate(op.outputs)
6906
+ ),
6907
+ *(
6908
+ (f"acc_output{idx}", name)
6909
+ for idx, name in enumerate(op.accumulator_outputs)
6910
+ ),
6911
+ ]
6912
+ )
6913
+ rate_suffix = self._param_array_suffix(
6914
+ op.rate_shape, _dim_names_for(op.rate)
6915
+ )
6916
+ timestep_suffix = self._param_array_suffix(
6917
+ op.timestep_shape, _dim_names_for(op.timestep)
6918
+ )
6919
+ param_specs = [
6920
+ (params["rate"], op.rate_dtype.c_type, rate_suffix, True),
6921
+ (
6922
+ params["timestep"],
6923
+ op.timestep_dtype.c_type,
6924
+ timestep_suffix,
6925
+ True,
6926
+ ),
6927
+ ]
6928
+ tensor_specs = []
6929
+ for idx, shape in enumerate(op.output_shapes):
6930
+ input_suffix = self._param_array_suffix(
6931
+ op.tensor_shapes[idx], _dim_names_for(op.inputs[idx])
6932
+ )
6933
+ grad_suffix = self._param_array_suffix(
6934
+ op.tensor_shapes[idx], _dim_names_for(op.gradients[idx])
6935
+ )
6936
+ acc_suffix = self._param_array_suffix(
6937
+ op.tensor_shapes[idx], _dim_names_for(op.accumulators[idx])
6938
+ )
6939
+ output_suffix = self._param_array_suffix(
6940
+ op.output_shapes[idx], _dim_names_for(op.outputs[idx])
6941
+ )
6942
+ acc_output_suffix = self._param_array_suffix(
6943
+ op.output_shapes[idx],
6944
+ _dim_names_for(op.accumulator_outputs[idx]),
6945
+ )
6946
+ param_specs.extend(
6947
+ [
6948
+ (params[f"input{idx}"], c_type, input_suffix, True),
6949
+ (params[f"grad{idx}"], c_type, grad_suffix, True),
6950
+ (params[f"acc{idx}"], c_type, acc_suffix, True),
6951
+ (params[f"output{idx}"], c_type, output_suffix, False),
6952
+ (
6953
+ params[f"acc_output{idx}"],
6954
+ c_type,
6955
+ acc_output_suffix,
6956
+ False,
6957
+ ),
6958
+ ]
6959
+ )
6960
+ output_dim_names = _dim_names_for(op.outputs[idx])
6961
+ shape_exprs = CEmitter._shape_dim_exprs(
6962
+ shape, output_dim_names
6963
+ )
6964
+ loop_vars = CEmitter._loop_vars(shape)
6965
+ index_suffix = "".join(f"[{var}]" for var in loop_vars)
6966
+ tensor_specs.append(
6967
+ {
6968
+ "shape": shape_exprs,
6969
+ "loop_vars": loop_vars,
6970
+ "input_expr": f"{params[f'input{idx}']}{index_suffix}",
6971
+ "grad_expr": f"{params[f'grad{idx}']}{index_suffix}",
6972
+ "acc_expr": f"{params[f'acc{idx}']}{index_suffix}",
6973
+ "output_expr": f"{params[f'output{idx}']}{index_suffix}",
6974
+ "acc_output_expr": f"{params[f'acc_output{idx}']}{index_suffix}",
6975
+ }
6976
+ )
6977
+ param_decls = self._build_param_decls(param_specs)
6978
+ rendered = adagrad_template.render(
6979
+ model_name=model.name,
6980
+ op_name=op_name,
6981
+ rate=params["rate"],
6982
+ timestep=params["timestep"],
6983
+ params=param_decls,
6984
+ c_type=c_type,
6985
+ one_literal=CEmitter._format_literal(op.dtype, 1),
6986
+ decay_factor_literal=CEmitter._format_floating(
6987
+ op.decay_factor, op.dtype
6988
+ ),
6989
+ norm_coefficient_literal=CEmitter._format_floating(
6990
+ op.norm_coefficient, op.dtype
6991
+ ),
6992
+ epsilon_literal=CEmitter._format_floating(op.epsilon, op.dtype),
6993
+ sqrt_fn=CEmitter._math_fn(op.dtype, "sqrtf", "sqrt"),
6994
+ tensors=tensor_specs,
6995
+ ).rstrip()
6996
+ return with_node_comment(rendered)
7435
6997
  if isinstance(op, SoftmaxOp):
6998
+ if scalar_registry is None:
6999
+ raise CodegenError(
7000
+ "Scalar function registry is required for Softmax rendering."
7001
+ )
7002
+ output_shape = self._ctx_shape(op.output)
7003
+ output_dtype = self._ctx_dtype(op.output)
7004
+ outer = self._derived(op, "outer")
7005
+ axis_size = self._derived(op, "axis_size")
7006
+ inner = self._derived(op, "inner")
7007
+ max_fn = self._scalar_function_name(
7008
+ ScalarFunction.MAXIMUM, output_dtype, scalar_registry
7009
+ )
7010
+ if max_fn is None:
7011
+ raise CodegenError(
7012
+ "Failed to resolve scalar maximum function for Softmax."
7013
+ )
7436
7014
  params = self._shared_param_map(
7437
7015
  [("input0", op.input0), ("output", op.output)]
7438
7016
  )
7439
- array_suffix = self._param_array_suffix(op.shape)
7017
+ array_suffix = self._param_array_suffix(output_shape)
7440
7018
  param_decls = self._build_param_decls(
7441
7019
  [
7442
- (params["input0"], c_type, array_suffix, True),
7443
- (params["output"], c_type, array_suffix, False),
7020
+ (params["input0"], output_dtype.c_type, array_suffix, True),
7021
+ (params["output"], output_dtype.c_type, array_suffix, False),
7444
7022
  ]
7445
7023
  )
7446
7024
  rendered = softmax_template.render(
@@ -7449,23 +7027,40 @@ class CEmitter:
7449
7027
  input0=params["input0"],
7450
7028
  output=params["output"],
7451
7029
  params=param_decls,
7452
- c_type=c_type,
7030
+ c_type=output_dtype.c_type,
7453
7031
  array_suffix=array_suffix,
7454
- outer=op.outer,
7455
- axis_size=op.axis_size,
7456
- inner=op.inner,
7457
- exp_fn=CEmitter._math_fn(op.dtype, "expf", "exp"),
7032
+ outer=outer,
7033
+ axis_size=axis_size,
7034
+ inner=inner,
7035
+ max_fn=max_fn,
7036
+ exp_fn=CEmitter._math_fn(output_dtype, "expf", "exp"),
7458
7037
  ).rstrip()
7459
7038
  return with_node_comment(rendered)
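The Softmax branch now resolves a scalar maximum function (`max_fn`) and passes `outer`/`axis_size`/`inner` derived from the op context, which indicates the emitted kernel subtracts the per-slice maximum before exponentiating for numerical stability. A flat-layout sketch of that loop structure; the row-major `outer x axis_size x inner` layout matches the template parameters, the rest is an assumption about the generated code:

```python
import math


def softmax_flat(x, outer, axis_size, inner):
    """Numerically stable softmax over the middle axis of a flat buffer."""
    y = [0.0] * len(x)
    for o in range(outer):
        for i in range(inner):
            base = o * axis_size * inner + i
            m = max(x[base + a * inner] for a in range(axis_size))  # role of max_fn
            total = 0.0
            for a in range(axis_size):
                e = math.exp(x[base + a * inner] - m)
                y[base + a * inner] = e
                total += e
            for a in range(axis_size):
                y[base + a * inner] /= total
    return y
```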
7460
7039
  if isinstance(op, LogSoftmaxOp):
7040
+ if scalar_registry is None:
7041
+ raise CodegenError(
7042
+ "Scalar function registry is required for LogSoftmax rendering."
7043
+ )
7044
+ output_shape = self._ctx_shape(op.output)
7045
+ output_dtype = self._ctx_dtype(op.output)
7046
+ outer = self._derived(op, "outer")
7047
+ axis_size = self._derived(op, "axis_size")
7048
+ inner = self._derived(op, "inner")
7049
+ max_fn = self._scalar_function_name(
7050
+ ScalarFunction.MAXIMUM, output_dtype, scalar_registry
7051
+ )
7052
+ if max_fn is None:
7053
+ raise CodegenError(
7054
+ "Failed to resolve scalar maximum function for LogSoftmax."
7055
+ )
7461
7056
  params = self._shared_param_map(
7462
7057
  [("input0", op.input0), ("output", op.output)]
7463
7058
  )
7464
- array_suffix = self._param_array_suffix(op.shape)
7059
+ array_suffix = self._param_array_suffix(output_shape)
7465
7060
  param_decls = self._build_param_decls(
7466
7061
  [
7467
- (params["input0"], c_type, array_suffix, True),
7468
- (params["output"], c_type, array_suffix, False),
7062
+ (params["input0"], output_dtype.c_type, array_suffix, True),
7063
+ (params["output"], output_dtype.c_type, array_suffix, False),
7469
7064
  ]
7470
7065
  )
7471
7066
  rendered = logsoftmax_template.render(
@@ -7474,24 +7069,41 @@ class CEmitter:
7474
7069
  input0=params["input0"],
7475
7070
  output=params["output"],
7476
7071
  params=param_decls,
7477
- c_type=c_type,
7072
+ c_type=output_dtype.c_type,
7478
7073
  array_suffix=array_suffix,
7479
- outer=op.outer,
7480
- axis_size=op.axis_size,
7481
- inner=op.inner,
7482
- exp_fn=CEmitter._math_fn(op.dtype, "expf", "exp"),
7483
- log_fn=CEmitter._math_fn(op.dtype, "logf", "log"),
7074
+ outer=outer,
7075
+ axis_size=axis_size,
7076
+ inner=inner,
7077
+ max_fn=max_fn,
7078
+ exp_fn=CEmitter._math_fn(output_dtype, "expf", "exp"),
7079
+ log_fn=CEmitter._math_fn(output_dtype, "logf", "log"),
7484
7080
  ).rstrip()
7485
7081
  return with_node_comment(rendered)
7486
7082
  if isinstance(op, HardmaxOp):
7083
+ if scalar_registry is None:
7084
+ raise CodegenError(
7085
+ "Scalar function registry is required for Hardmax rendering."
7086
+ )
7087
+ output_shape = self._ctx_shape(op.output)
7088
+ output_dtype = self._ctx_dtype(op.output)
7089
+ outer = self._derived(op, "outer")
7090
+ axis_size = self._derived(op, "axis_size")
7091
+ inner = self._derived(op, "inner")
7092
+ max_fn = self._scalar_function_name(
7093
+ ScalarFunction.MAXIMUM, output_dtype, scalar_registry
7094
+ )
7095
+ if max_fn is None:
7096
+ raise CodegenError(
7097
+ "Failed to resolve scalar maximum function for Hardmax."
7098
+ )
7487
7099
  params = self._shared_param_map(
7488
7100
  [("input0", op.input0), ("output", op.output)]
7489
7101
  )
7490
- array_suffix = self._param_array_suffix(op.shape)
7102
+ array_suffix = self._param_array_suffix(output_shape)
7491
7103
  param_decls = self._build_param_decls(
7492
7104
  [
7493
- (params["input0"], c_type, array_suffix, True),
7494
- (params["output"], c_type, array_suffix, False),
7105
+ (params["input0"], output_dtype.c_type, array_suffix, True),
7106
+ (params["output"], output_dtype.c_type, array_suffix, False),
7495
7107
  ]
7496
7108
  )
7497
7109
  rendered = hardmax_template.render(
@@ -7500,13 +7112,14 @@ class CEmitter:
7500
7112
  input0=params["input0"],
7501
7113
  output=params["output"],
7502
7114
  params=param_decls,
7503
- c_type=c_type,
7115
+ c_type=output_dtype.c_type,
7504
7116
  array_suffix=array_suffix,
7505
- outer=op.outer,
7506
- axis_size=op.axis_size,
7507
- inner=op.inner,
7117
+ outer=outer,
7118
+ axis_size=axis_size,
7119
+ inner=inner,
7508
7120
  zero_literal=zero_literal,
7509
- one_literal=CEmitter._format_literal(op.dtype, 1),
7121
+ one_literal=CEmitter._format_literal(output_dtype, 1),
7122
+ max_fn=max_fn,
7510
7123
  ).rstrip()
7511
7124
  return with_node_comment(rendered)
7512
7125
  if isinstance(op, NegativeLogLikelihoodLossOp):
@@ -7576,6 +7189,17 @@ class CEmitter:
7576
7189
  if op.dtype in {ScalarType.F16, ScalarType.F32}
7577
7190
  else op.dtype
7578
7191
  )
7192
+ if scalar_registry is None:
7193
+ raise CodegenError(
7194
+ "Scalar function registry is required for SoftmaxCrossEntropyLoss."
7195
+ )
7196
+ max_fn = self._scalar_function_name(
7197
+ ScalarFunction.MAXIMUM, acc_dtype, scalar_registry
7198
+ )
7199
+ if max_fn is None:
7200
+ raise CodegenError(
7201
+ "Failed to resolve scalar maximum function for SoftmaxCrossEntropyLoss."
7202
+ )
7579
7203
  acc_type = acc_dtype.c_type
7580
7204
  acc_zero_literal = CEmitter._format_literal(acc_dtype, 0)
7581
7205
  acc_one_literal = CEmitter._format_literal(acc_dtype, 1)
@@ -7652,9 +7276,21 @@ class CEmitter:
7652
7276
  acc_one_literal=acc_one_literal,
7653
7277
  acc_exp_fn=acc_exp_fn,
7654
7278
  acc_log_fn=acc_log_fn,
7279
+ max_fn=max_fn,
7655
7280
  ).rstrip()
7656
7281
  return with_node_comment(rendered)
7657
7282
  if isinstance(op, MaxPoolOp):
7283
+ if scalar_registry is None:
7284
+ raise CodegenError(
7285
+ "Scalar function registry is required for MaxPool rendering."
7286
+ )
7287
+ max_fn = self._scalar_function_name(
7288
+ ScalarFunction.MAXIMUM, op.dtype, scalar_registry
7289
+ )
7290
+ if max_fn is None:
7291
+ raise CodegenError(
7292
+ "Failed to resolve scalar maximum function for MaxPool."
7293
+ )
7658
7294
  params = self._shared_param_map(
7659
7295
  [
7660
7296
  ("input0", op.input0),
@@ -7699,6 +7335,7 @@ class CEmitter:
7699
7335
  output_suffix=output_suffix,
7700
7336
  indices_suffix=indices_suffix,
7701
7337
  indices_c_type=indices_c_type,
7338
+ max_fn=max_fn,
7702
7339
  batch=op.batch,
7703
7340
  channels=op.channels,
7704
7341
  spatial_rank=op.spatial_rank,
@@ -8032,21 +7669,133 @@ class CEmitter:
  reduction=op.reduction,
  ).rstrip()
  return with_node_comment(rendered)
+ if isinstance(op, TensorScatterOp):
+ param_pairs = [
+ ("past_cache", op.past_cache),
+ ("update", op.update),
+ ("output", op.output),
+ ]
+ if op.write_indices is not None:
+ param_pairs.insert(2, ("write_indices", op.write_indices))
+ params = self._shared_param_map(param_pairs)
+ output_dim_names = _dim_names_for(op.output)
+ update_dim_names = _dim_names_for(op.update)
+ past_dim_names = _dim_names_for(op.past_cache)
+ write_indices_dim_names = (
+ _dim_names_for(op.write_indices) if op.write_indices else None
+ )
+ output_shape = CEmitter._shape_dim_exprs(
+ op.output_shape, output_dim_names
+ )
+ update_shape = CEmitter._shape_dim_exprs(
+ op.update_shape, update_dim_names
+ )
+ prefix_shape = output_shape[: op.axis]
+ prefix_loop_vars = (
+ CEmitter._loop_vars(op.output_shape[: op.axis])
+ if op.output_shape[: op.axis]
+ else ()
+ )
+ tail_shape = output_shape[op.axis + 1 :]
+ tail_loop_vars = (
+ tuple(
+ f"t{index}"
+ for index in range(len(op.output_shape[op.axis + 1 :]))
+ )
+ if op.output_shape[op.axis + 1 :]
+ else ()
+ )
+ output_loop_vars = CEmitter._loop_vars(op.output_shape)
+ sequence_loop_var = "seq"
+ cache_index_var = "cache_index"
+ write_index_var = "write_index"
+ index_vars = (*prefix_loop_vars, cache_index_var, *tail_loop_vars)
+ output_index_expr = f"{params['output']}" + "".join(
+ f"[{var}]" for var in index_vars
+ )
+ update_index_vars = (
+ *prefix_loop_vars,
+ sequence_loop_var,
+ *tail_loop_vars,
+ )
+ update_index_expr = f"{params['update']}" + "".join(
+ f"[{var}]" for var in update_index_vars
+ )
+ past_suffix = self._param_array_suffix(
+ op.past_cache_shape, past_dim_names
+ )
+ update_suffix = self._param_array_suffix(
+ op.update_shape, update_dim_names
+ )
+ output_suffix = self._param_array_suffix(
+ op.output_shape, output_dim_names
+ )
+ param_decls = [
+ (params["past_cache"], c_type, past_suffix, True),
+ (params["update"], c_type, update_suffix, True),
+ ]
+ if op.write_indices is not None and op.write_indices_dtype is not None:
+ write_indices_suffix = self._param_array_suffix(
+ op.write_indices_shape or (), write_indices_dim_names
+ )
+ param_decls.append(
+ (
+ params["write_indices"],
+ op.write_indices_dtype.c_type,
+ write_indices_suffix,
+ True,
+ )
+ )
+ param_decls.append((params["output"], c_type, output_suffix, False))
+ param_decls_rendered = self._build_param_decls(param_decls)
+ rendered = tensor_scatter_template.render(
+ model_name=model.name,
+ op_name=op_name,
+ past_cache=params["past_cache"],
+ update=params["update"],
+ write_indices=(
+ params.get("write_indices") if op.write_indices else None
+ ),
+ output=params["output"],
+ params=param_decls_rendered,
+ c_type=c_type,
+ output_shape=output_shape,
+ output_loop_vars=output_loop_vars,
+ prefix_shape=prefix_shape,
+ prefix_loop_vars=prefix_loop_vars,
+ sequence_dim=update_shape[op.axis],
+ sequence_loop_var=sequence_loop_var,
+ tail_shape=tail_shape,
+ tail_loop_vars=tail_loop_vars,
+ output_index_expr=output_index_expr,
+ update_index_expr=update_index_expr,
+ max_sequence_length=output_shape[op.axis],
+ write_indices_present=op.write_indices is not None,
+ batch_index_var=prefix_loop_vars[0]
+ if prefix_loop_vars
+ else "0",
+ write_index_var=write_index_var,
+ cache_index_var=cache_index_var,
+ circular=op.mode == "circular",
+ ).rstrip()
+ return with_node_comment(rendered)
  if isinstance(op, TransposeOp):
+ input_shape = self._ctx_shape(op.input0)
+ output_shape_raw = self._ctx_shape(op.output)
  params = self._shared_param_map(
  [("input0", op.input0), ("output", op.output)]
  )
- output_shape = CEmitter._codegen_shape(op.output_shape)
+ output_shape = CEmitter._codegen_shape(output_shape_raw)
  loop_vars = CEmitter._loop_vars(output_shape)
  output_suffix = self._param_array_suffix(output_shape)
- input_suffix = self._param_array_suffix(op.input_shape)
+ input_suffix = self._param_array_suffix(input_shape)
  param_decls = self._build_param_decls(
  [
  (params["input0"], c_type, input_suffix, True),
  (params["output"], c_type, output_suffix, False),
  ]
  )
- if not op.input_shape:
+ if not input_shape:
  input_indices = [loop_vars[0]]
  else:
  input_indices = [None] * len(op.perm)
@@ -8067,19 +7816,21 @@ class CEmitter:
  ).rstrip()
  return with_node_comment(rendered)
  if isinstance(op, ReshapeOp):
+ input_shape = self._ctx_shape(op.input0)
+ output_shape_raw = self._ctx_shape(op.output)
  params = self._shared_param_map(
  [("input0", op.input0), ("output", op.output)]
  )
- input_suffix = self._param_array_suffix(op.input_shape)
- output_shape = CEmitter._codegen_shape(op.output_shape)
- output_suffix = self._param_array_suffix(op.output_shape)
+ input_suffix = self._param_array_suffix(input_shape)
+ output_shape = CEmitter._codegen_shape(output_shape_raw)
+ output_suffix = self._param_array_suffix(output_shape_raw)
  param_decls = self._build_param_decls(
  [
  (params["input0"], c_type, input_suffix, True),
  (params["output"], c_type, output_suffix, False),
  ]
  )
- loop_vars = CEmitter._loop_vars(op.output_shape)
+ loop_vars = CEmitter._loop_vars(output_shape_raw)
  rendered = reshape_template.render(
  model_name=model.name,
  op_name=op_name,
@@ -8089,20 +7840,27 @@ class CEmitter:
  c_type=c_type,
  input_suffix=input_suffix,
  output_suffix=output_suffix,
- element_count=CEmitter._element_count(op.output_shape),
+ element_count=CEmitter._element_count(output_shape_raw),
  output_shape=output_shape,
  loop_vars=loop_vars,
  ).rstrip()
  return with_node_comment(rendered)
  if isinstance(op, IdentityOp):
+ output_shape_raw = self._ctx_shape(op.output)
  params = self._shared_param_map(
  [("input0", op.input0), ("output", op.output)]
  )
  output_dim_names = _dim_names_for(op.output)
- shape = CEmitter._shape_dim_exprs(op.shape, output_dim_names)
- loop_vars = CEmitter._loop_vars(op.shape)
- output_suffix = self._param_array_suffix(shape, output_dim_names)
- input_suffix = self._param_array_suffix(shape, _dim_names_for(op.input0))
+ shape = CEmitter._shape_dim_exprs(
+ output_shape_raw, output_dim_names
+ )
+ loop_vars = CEmitter._loop_vars(output_shape_raw)
+ output_suffix = self._param_array_suffix(
+ output_shape_raw, output_dim_names
+ )
+ input_suffix = self._param_array_suffix(
+ output_shape_raw, _dim_names_for(op.input0)
+ )
  param_decls = self._build_param_decls(
  [
  (params["input0"], c_type, input_suffix, True),
@@ -8704,39 +8462,41 @@ class CEmitter:
  ).rstrip()
  return with_node_comment(rendered)
  if isinstance(op, ReduceOp) and op.axes_input is None:
+ input_shape = self._ctx_shape(op.input0)
+ output_shape_raw = self._ctx_shape(op.output)
+ axes = self._derived(op, "axes")
+ output_dtype = self._ctx_dtype(op.output)
  params = self._shared_param_map(
  [("input0", op.input0), ("output", op.output)]
  )
- output_shape = CEmitter._codegen_shape(op.output_shape)
+ output_shape = CEmitter._codegen_shape(output_shape_raw)
  output_loop_vars = CEmitter._loop_vars(output_shape)
- if not op.input_shape:
+ if not input_shape:
  reduce_loop_vars = ("r0",)
  reduce_dims = (1,)
  else:
- reduce_loop_vars = tuple(
- f"r{idx}" for idx in range(len(op.axes))
- )
- reduce_dims = tuple(op.input_shape[axis] for axis in op.axes)
- if not op.input_shape:
+ reduce_loop_vars = tuple(f"r{idx}" for idx in range(len(axes)))
+ reduce_dims = tuple(input_shape[axis] for axis in axes)
+ if not input_shape:
  input_indices = [reduce_loop_vars[0]]
  elif op.keepdims:
  input_indices = [
- reduce_loop_vars[op.axes.index(axis)]
- if axis in op.axes
+ reduce_loop_vars[axes.index(axis)]
+ if axis in axes
  else output_loop_vars[axis]
- for axis in range(len(op.input_shape))
+ for axis in range(len(input_shape))
  ]
  else:
  kept_axes = [
  axis
- for axis in range(len(op.input_shape))
- if axis not in op.axes
+ for axis in range(len(input_shape))
+ if axis not in axes
  ]
  input_indices = [
- reduce_loop_vars[op.axes.index(axis)]
- if axis in op.axes
+ reduce_loop_vars[axes.index(axis)]
+ if axis in axes
  else output_loop_vars[kept_axes.index(axis)]
- for axis in range(len(op.input_shape))
+ for axis in range(len(input_shape))
  ]
  input_index_expr = "".join(f"[{var}]" for var in input_indices)
  output_index_expr = "".join(
@@ -8748,16 +8508,16 @@ class CEmitter:
  final_expr = "acc"
  use_kahan = False
  kahan_value_expr = None
- fabs_fn = CEmitter._math_fn(op.dtype, "fabsf", "fabs")
- exp_fn = CEmitter._math_fn(op.dtype, "expf", "exp")
- log_fn = CEmitter._math_fn(op.dtype, "logf", "log")
- sqrt_fn = CEmitter._math_fn(op.dtype, "sqrtf", "sqrt")
+ fabs_fn = CEmitter._math_fn(output_dtype, "fabsf", "fabs")
+ exp_fn = CEmitter._math_fn(output_dtype, "expf", "exp")
+ log_fn = CEmitter._math_fn(output_dtype, "logf", "log")
+ sqrt_fn = CEmitter._math_fn(output_dtype, "sqrtf", "sqrt")
  if op.reduce_kind == "sum":
  init_literal = zero_literal
  update_expr = f"acc += {value_expr};"
  elif op.reduce_kind == "mean":
  count_literal = CEmitter._format_literal(
- op.dtype, op.reduce_count
+ output_dtype, op.reduce_count
  )
  init_literal = zero_literal
  update_expr = f"acc += {value_expr};"
@@ -8769,7 +8529,7 @@ class CEmitter:
  init_literal = max_literal
  update_expr = f"if ({value_expr} < acc) acc = {value_expr};"
  elif op.reduce_kind == "prod":
- init_literal = CEmitter._format_literal(op.dtype, 1)
+ init_literal = CEmitter._format_literal(output_dtype, 1)
  update_expr = f"acc *= {value_expr};"
  elif op.reduce_kind == "l1":
  init_literal = zero_literal
@@ -8793,7 +8553,7 @@ class CEmitter:
  raise CodegenError(
  f"Unsupported reduce kind {op.reduce_kind}"
  )
- if op.dtype in {ScalarType.F16, ScalarType.F32} and op.reduce_kind in {
+ if output_dtype in {ScalarType.F16, ScalarType.F32} and op.reduce_kind in {
  "sum",
  "mean",
  "logsum",
@@ -8811,8 +8571,8 @@ class CEmitter:
  kahan_value_expr = f"{value_expr} * {value_expr}"
  else:
  kahan_value_expr = value_expr
- input_suffix = self._param_array_suffix(op.input_shape)
- output_suffix = self._param_array_suffix(op.output_shape)
+ input_suffix = self._param_array_suffix(input_shape)
+ output_suffix = self._param_array_suffix(output_shape_raw)
  param_decls = self._build_param_decls(
  [
  (params["input0"], c_type, input_suffix, True),
@@ -8842,33 +8602,40 @@ class CEmitter:
  ).rstrip()
  return with_node_comment(rendered)
  if isinstance(op, ArgReduceOp):
+ input_shape = self._ctx_shape(op.input0)
+ output_shape_raw = self._ctx_shape(op.output)
+ axis = self._derived(op, "axis")
+ input_dtype = self._ctx_dtype(op.input0)
+ output_dtype = self._ctx_dtype(op.output)
  params = self._shared_param_map(
  [("input0", op.input0), ("output", op.output)]
  )
- output_shape = CEmitter._codegen_shape(op.output_shape)
+ output_shape = CEmitter._codegen_shape(output_shape_raw)
  output_loop_vars = CEmitter._loop_vars(output_shape)
  reduce_var = "r0"
- reduce_dim = op.input_shape[op.axis]
+ reduce_dim = input_shape[axis]
  if op.keepdims:
  input_indices = [
- reduce_var if axis == op.axis else output_loop_vars[axis]
- for axis in range(len(op.input_shape))
+ reduce_var
+ if axis_index == axis
+ else output_loop_vars[axis_index]
+ for axis_index in range(len(input_shape))
  ]
  else:
  kept_axes = [
- axis
- for axis in range(len(op.input_shape))
- if axis != op.axis
+ axis_index
+ for axis_index in range(len(input_shape))
+ if axis_index != axis
  ]
  input_indices = [
  reduce_var
- if axis == op.axis
- else output_loop_vars[kept_axes.index(axis)]
- for axis in range(len(op.input_shape))
+ if axis_index == axis
+ else output_loop_vars[kept_axes.index(axis_index)]
+ for axis_index in range(len(input_shape))
  ]
  init_indices = [
- "0" if axis == op.axis else input_indices[axis]
- for axis in range(len(op.input_shape))
+ "0" if axis_index == axis else input_indices[axis_index]
+ for axis_index in range(len(input_shape))
  ]
  input_index_expr = "".join(f"[{var}]" for var in input_indices)
  init_index_expr = "".join(f"[{var}]" for var in init_indices)
@@ -8883,12 +8650,12 @@ class CEmitter:
  raise CodegenError(
  f"Unsupported arg reduce kind {op.reduce_kind}"
  )
- input_suffix = self._param_array_suffix(op.input_shape)
- output_suffix = self._param_array_suffix(op.output_shape)
+ input_suffix = self._param_array_suffix(input_shape)
+ output_suffix = self._param_array_suffix(output_shape_raw)
  param_decls = self._build_param_decls(
  [
- (params["input0"], op.input_dtype.c_type, input_suffix, True),
- (params["output"], op.output_dtype.c_type, output_suffix, False),
+ (params["input0"], input_dtype.c_type, input_suffix, True),
+ (params["output"], output_dtype.c_type, output_suffix, False),
  ]
  )
  rendered = arg_reduce_template.render(
@@ -8897,8 +8664,8 @@ class CEmitter:
  input0=params["input0"],
  output=params["output"],
  params=param_decls,
- input_c_type=op.input_dtype.c_type,
- output_c_type=op.output_dtype.c_type,
+ input_c_type=input_dtype.c_type,
+ output_c_type=output_dtype.c_type,
  input_suffix=input_suffix,
  output_suffix=output_suffix,
  output_shape=output_shape,
@@ -8913,6 +8680,11 @@ class CEmitter:
  ).rstrip()
  return with_node_comment(rendered)
  if isinstance(op, TopKOp):
+ input_shape = self._ctx_shape(op.input0)
+ output_shape_raw = self._ctx_shape(op.output_values)
+ input_dtype = self._ctx_dtype(op.input0)
+ output_values_dtype = self._ctx_dtype(op.output_values)
+ output_indices_dtype = self._ctx_dtype(op.output_indices)
  params = self._shared_param_map(
  [
  ("input0", op.input0),
@@ -8920,7 +8692,7 @@ class CEmitter:
  ("output_indices", op.output_indices),
  ]
  )
- output_shape = CEmitter._codegen_shape(op.output_shape)
+ output_shape = CEmitter._codegen_shape(output_shape_raw)
  outer_shape = tuple(
  dim for axis, dim in enumerate(output_shape) if axis != op.axis
  )
@@ -8930,7 +8702,7 @@ class CEmitter:
  input_indices: list[str] = []
  output_indices: list[str] = []
  outer_index = 0
- for axis in range(len(op.input_shape)):
+ for axis in range(len(input_shape)):
  if axis == op.axis:
  input_indices.append(reduce_var)
  output_indices.append(k_var)
@@ -8945,20 +8717,20 @@ class CEmitter:
8945
8717
  if op.largest
8946
8718
  else "(a < b) || ((a == b) && (ai < bi))"
8947
8719
  )
8948
- input_suffix = self._param_array_suffix(op.input_shape)
8949
- output_suffix = self._param_array_suffix(op.output_shape)
8720
+ input_suffix = self._param_array_suffix(input_shape)
8721
+ output_suffix = self._param_array_suffix(output_shape_raw)
8950
8722
  param_decls = self._build_param_decls(
8951
8723
  [
8952
- (params["input0"], op.input_dtype.c_type, input_suffix, True),
8724
+ (params["input0"], input_dtype.c_type, input_suffix, True),
8953
8725
  (
8954
8726
  params["output_values"],
8955
- op.output_values_dtype.c_type,
8727
+ output_values_dtype.c_type,
8956
8728
  output_suffix,
8957
8729
  False,
8958
8730
  ),
8959
8731
  (
8960
8732
  params["output_indices"],
8961
- op.output_indices_dtype.c_type,
8733
+ output_indices_dtype.c_type,
8962
8734
  output_suffix,
8963
8735
  False,
8964
8736
  ),
@@ -9216,27 +8988,150 @@ class CEmitter:
9216
8988
  )
9217
8989
  param_decls = self._build_param_decls(
9218
8990
  [
9219
- (params["input0"], op.input_dtype.c_type, input_suffix, True),
9220
- (params["output"], c_type, output_suffix, False),
8991
+ (params["input0"], op.input_dtype.c_type, input_suffix, True),
8992
+ (params["output"], c_type, output_suffix, False),
8993
+ ]
8994
+ )
8995
+ input_expr = f"{params['input0']}" + "".join(
8996
+ f"[{var}]" for var in loop_vars
8997
+ )
8998
+ rendered = nonzero_template.render(
8999
+ model_name=model.name,
9000
+ op_name=op_name,
9001
+ input0=params["input0"],
9002
+ output=params["output"],
9003
+ params=param_decls,
9004
+ input_c_type=op.input_dtype.c_type,
9005
+ output_c_type=c_type,
9006
+ input_suffix=input_suffix,
9007
+ output_suffix=output_suffix,
9008
+ input_shape=input_shape,
9009
+ loop_vars=loop_vars,
9010
+ input_expr=input_expr,
9011
+ zero_literal=op.input_dtype.zero_literal,
9012
+ ).rstrip()
9013
+ return with_node_comment(rendered)
9014
+ if isinstance(op, NonMaxSuppressionOp):
9015
+ if scalar_registry is None:
9016
+ raise CodegenError(
9017
+ "Scalar function registry is required for NonMaxSuppression."
9018
+ )
9019
+ min_fn = self._scalar_function_name(
9020
+ ScalarFunction.MINIMUM, op.boxes_dtype, scalar_registry
9021
+ )
9022
+ max_fn = self._scalar_function_name(
9023
+ ScalarFunction.MAXIMUM, op.boxes_dtype, scalar_registry
9024
+ )
9025
+ if min_fn is None or max_fn is None:
9026
+ raise CodegenError(
9027
+ "Failed to resolve scalar min/max functions for NonMaxSuppression."
9028
+ )
9029
+ params = self._shared_param_map(
9030
+ [
9031
+ ("boxes", op.boxes),
9032
+ ("scores", op.scores),
9033
+ ("max_output_boxes_per_class", op.max_output_boxes_per_class),
9034
+ ("iou_threshold", op.iou_threshold),
9035
+ ("score_threshold", op.score_threshold),
9036
+ ("output", op.output),
9037
+ ]
9038
+ )
9039
+ boxes_suffix = self._param_array_suffix(
9040
+ op.boxes_shape, _dim_names_for(op.boxes)
9041
+ )
9042
+ scores_suffix = self._param_array_suffix(
9043
+ op.scores_shape, _dim_names_for(op.scores)
9044
+ )
9045
+ output_suffix = self._param_array_suffix(
9046
+ op.output_shape, _dim_names_for(op.output)
9047
+ )
9048
+ max_output_suffix = (
9049
+ self._param_array_suffix(
9050
+ op.max_output_shape,
9051
+ _dim_names_for(op.max_output_boxes_per_class or ""),
9052
+ )
9053
+ if op.max_output_shape is not None
9054
+ else ""
9055
+ )
9056
+ iou_threshold_suffix = (
9057
+ self._param_array_suffix(
9058
+ op.iou_threshold_shape,
9059
+ _dim_names_for(op.iou_threshold or ""),
9060
+ )
9061
+ if op.iou_threshold_shape is not None
9062
+ else ""
9063
+ )
9064
+ score_threshold_suffix = (
9065
+ self._param_array_suffix(
9066
+ op.score_threshold_shape,
9067
+ _dim_names_for(op.score_threshold or ""),
9068
+ )
9069
+ if op.score_threshold_shape is not None
9070
+ else ""
9071
+ )
9072
+ param_decls = self._build_param_decls(
9073
+ [
9074
+ (params["boxes"], op.boxes_dtype.c_type, boxes_suffix, True),
9075
+ (params["scores"], op.boxes_dtype.c_type, scores_suffix, True),
9076
+ (
9077
+ params["max_output_boxes_per_class"],
9078
+ op.max_output_dtype.c_type if op.max_output_dtype else "",
9079
+ max_output_suffix,
9080
+ True,
9081
+ )
9082
+ if params["max_output_boxes_per_class"]
9083
+ else (None, "", "", True),
9084
+ (
9085
+ params["iou_threshold"],
9086
+ (
9087
+ op.iou_threshold_dtype.c_type
9088
+ if op.iou_threshold_dtype
9089
+ else ""
9090
+ ),
9091
+ iou_threshold_suffix,
9092
+ True,
9093
+ )
9094
+ if params["iou_threshold"]
9095
+ else (None, "", "", True),
9096
+ (
9097
+ params["score_threshold"],
9098
+ (
9099
+ op.score_threshold_dtype.c_type
9100
+ if op.score_threshold_dtype
9101
+ else ""
9102
+ ),
9103
+ score_threshold_suffix,
9104
+ True,
9105
+ )
9106
+ if params["score_threshold"]
9107
+ else (None, "", "", True),
9108
+ (params["output"], op.output_dtype.c_type, output_suffix, False),
9221
9109
  ]
9222
9110
  )
9223
- input_expr = f"{params['input0']}" + "".join(
9224
- f"[{var}]" for var in loop_vars
9225
- )
9226
- rendered = nonzero_template.render(
9111
+ rendered = nonmax_suppression_template.render(
9227
9112
  model_name=model.name,
9228
9113
  op_name=op_name,
9229
- input0=params["input0"],
9114
+ boxes=params["boxes"],
9115
+ scores=params["scores"],
9116
+ max_output_boxes_per_class=params["max_output_boxes_per_class"],
9117
+ iou_threshold=params["iou_threshold"],
9118
+ score_threshold=params["score_threshold"],
9230
9119
  output=params["output"],
9231
9120
  params=param_decls,
9232
- input_c_type=op.input_dtype.c_type,
9233
- output_c_type=c_type,
9234
- input_suffix=input_suffix,
9235
- output_suffix=output_suffix,
9236
- input_shape=input_shape,
9237
- loop_vars=loop_vars,
9238
- input_expr=input_expr,
9239
- zero_literal=op.input_dtype.zero_literal,
9121
+ input_c_type=op.boxes_dtype.c_type,
9122
+ output_c_type=op.output_dtype.c_type,
9123
+ compute_type=op.boxes_dtype.c_type,
9124
+ output_capacity=op.output_shape[0],
9125
+ num_batches=op.boxes_shape[0],
9126
+ num_boxes=op.boxes_shape[1],
9127
+ num_classes=op.scores_shape[1],
9128
+ center_point_box=op.center_point_box,
9129
+ min_fn=min_fn,
9130
+ max_fn=max_fn,
9131
+ iou_threshold_default=op.boxes_dtype.zero_literal,
9132
+ score_threshold_default=op.boxes_dtype.zero_literal,
9133
+ score_threshold_enabled=op.score_threshold is not None,
9134
+ dim_args=dim_args,
9240
9135
  ).rstrip()
9241
9136
  return with_node_comment(rendered)
9242
9137
  if isinstance(op, ExpandOp):
@@ -9476,17 +9371,24 @@ class CEmitter:
9476
9371
  ).rstrip()
9477
9372
  return with_node_comment(rendered)
9478
9373
  if isinstance(op, CastOp):
9374
+ input_dtype = self._ctx_dtype(op.input0)
9375
+ output_dtype = self._ctx_dtype(op.output)
9376
+ output_shape_raw = self._ctx_shape(op.output)
9479
9377
  params = self._shared_param_map(
9480
9378
  [("input0", op.input0), ("output", op.output)]
9481
9379
  )
9482
9380
  output_dim_names = _dim_names_for(op.output)
9483
- shape = CEmitter._shape_dim_exprs(op.shape, output_dim_names)
9484
- loop_vars = CEmitter._loop_vars(op.shape)
9485
- array_suffix = self._param_array_suffix(op.shape, output_dim_names)
9381
+ shape = CEmitter._shape_dim_exprs(
9382
+ output_shape_raw, output_dim_names
9383
+ )
9384
+ loop_vars = CEmitter._loop_vars(output_shape_raw)
9385
+ array_suffix = self._param_array_suffix(
9386
+ output_shape_raw, output_dim_names
9387
+ )
9486
9388
  param_decls = self._build_param_decls(
9487
9389
  [
9488
- (params["input0"], op.input_dtype.c_type, array_suffix, True),
9489
- (params["output"], op.dtype.c_type, array_suffix, False),
9390
+ (params["input0"], input_dtype.c_type, array_suffix, True),
9391
+ (params["output"], output_dtype.c_type, array_suffix, False),
9490
9392
  ]
9491
9393
  )
9492
9394
  rendered = cast_template.render(
@@ -9495,8 +9397,8 @@ class CEmitter:
9495
9397
  input0=params["input0"],
9496
9398
  output=params["output"],
9497
9399
  params=param_decls,
9498
- input_c_type=op.input_dtype.c_type,
9499
- output_c_type=op.dtype.c_type,
9400
+ input_c_type=input_dtype.c_type,
9401
+ output_c_type=output_dtype.c_type,
9500
9402
  array_suffix=array_suffix,
9501
9403
  shape=shape,
9502
9404
  loop_vars=loop_vars,
@@ -9504,6 +9406,10 @@ class CEmitter:
9504
9406
  ).rstrip()
9505
9407
  return with_node_comment(rendered)
9506
9408
  if isinstance(op, QuantizeLinearOp):
9409
+ if scalar_registry is None:
9410
+ raise CodegenError(
9411
+ "Scalar function registry is required for QuantizeLinear."
9412
+ )
9507
9413
  params = self._shared_param_map(
9508
9414
  [
9509
9415
  ("input0", op.input0),
@@ -9545,6 +9451,21 @@ class CEmitter:
9545
9451
  ]
9546
9452
  )
9547
9453
  compute_type = "double" if op.input_dtype == ScalarType.F64 else "float"
9454
+ compute_dtype = (
9455
+ ScalarType.F64
9456
+ if compute_type == "double"
9457
+ else ScalarType.F32
9458
+ )
9459
+ max_fn = self._scalar_function_name(
9460
+ ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
9461
+ )
9462
+ min_fn = self._scalar_function_name(
9463
+ ScalarFunction.MINIMUM, compute_dtype, scalar_registry
9464
+ )
9465
+ if max_fn is None or min_fn is None:
9466
+ raise CodegenError(
9467
+ "Failed to resolve scalar min/max functions for QuantizeLinear."
9468
+ )
9548
9469
  round_fn = CEmitter._math_fn(
9549
9470
  op.input_dtype, "nearbyintf", "nearbyint"
9550
9471
  )
@@ -9580,10 +9501,221 @@ class CEmitter:
9580
9501
  round_fn=round_fn,
9581
9502
  min_literal=op.dtype.min_literal,
9582
9503
  max_literal=op.dtype.max_literal,
9504
+ min_fn=min_fn,
9505
+ max_fn=max_fn,
9506
+ dim_args=dim_args,
9507
+ ).rstrip()
9508
+ return with_node_comment(rendered)
9509
+ if isinstance(op, QLinearMatMulOp):
9510
+ if scalar_registry is None:
9511
+ raise CodegenError(
9512
+ "Scalar function registry is required for QLinearMatMul."
9513
+ )
9514
+ params = self._shared_param_map(
9515
+ [
9516
+ ("input0", op.input0),
9517
+ ("input0_scale", op.input0_scale),
9518
+ ("input0_zero_point", op.input0_zero_point),
9519
+ ("input1", op.input1),
9520
+ ("input1_scale", op.input1_scale),
9521
+ ("input1_zero_point", op.input1_zero_point),
9522
+ ("output_scale", op.output_scale),
9523
+ ("output_zero_point", op.output_zero_point),
9524
+ ("output", op.output),
9525
+ ]
9526
+ )
9527
+ output_shape = CEmitter._codegen_shape(op.output_shape)
9528
+ output_loop_vars = CEmitter._loop_vars(output_shape)
9529
+ output_index_expr = f"{params['output']}" + "".join(
9530
+ f"[{var}]" for var in output_loop_vars
9531
+ )
9532
+ batch_rank = len(op.batch_shape)
9533
+ batch_vars = output_loop_vars[:batch_rank]
9534
+ if op.left_vector and op.right_vector:
9535
+ row_var = None
9536
+ col_var = None
9537
+ elif op.left_vector:
9538
+ row_var = None
9539
+ col_var = output_loop_vars[-1]
9540
+ elif op.right_vector:
9541
+ row_var = output_loop_vars[-1]
9542
+ col_var = None
9543
+ else:
9544
+ row_var = output_loop_vars[-2]
9545
+ col_var = output_loop_vars[-1]
9546
+ input0_index_expr, input1_index_expr = CEmitter._matmul_index_exprs(
9547
+ op,
9548
+ batch_vars,
9549
+ row_var,
9550
+ col_var,
9551
+ batch_rank,
9552
+ input0=params["input0"],
9553
+ input1=params["input1"],
9554
+ )
9555
+ input0_suffix = self._param_array_suffix(op.input0_shape)
9556
+ input1_suffix = self._param_array_suffix(op.input1_shape)
9557
+ input0_scale_suffix = self._param_array_suffix(
9558
+ op.input0_scale_shape
9559
+ )
9560
+ input1_scale_suffix = self._param_array_suffix(
9561
+ op.input1_scale_shape
9562
+ )
9563
+ output_scale_suffix = self._param_array_suffix(
9564
+ op.output_scale_shape
9565
+ )
9566
+ input0_zero_suffix = self._param_array_suffix(op.input0_zero_shape)
9567
+ input1_zero_suffix = self._param_array_suffix(op.input1_zero_shape)
9568
+ output_zero_suffix = self._param_array_suffix(op.output_zero_shape)
9569
+ output_suffix = self._param_array_suffix(op.output_shape)
9570
+ param_decls = self._build_param_decls(
9571
+ [
9572
+ (
9573
+ params["input0"],
9574
+ op.input0_dtype.c_type,
9575
+ input0_suffix,
9576
+ True,
9577
+ ),
9578
+ (
9579
+ params["input0_scale"],
9580
+ op.input0_scale_dtype.c_type,
9581
+ input0_scale_suffix,
9582
+ True,
9583
+ ),
9584
+ (
9585
+ params["input0_zero_point"],
9586
+ op.input0_dtype.c_type,
9587
+ input0_zero_suffix,
9588
+ True,
9589
+ ),
9590
+ (
9591
+ params["input1"],
9592
+ op.input1_dtype.c_type,
9593
+ input1_suffix,
9594
+ True,
9595
+ ),
9596
+ (
9597
+ params["input1_scale"],
9598
+ op.input1_scale_dtype.c_type,
9599
+ input1_scale_suffix,
9600
+ True,
9601
+ ),
9602
+ (
9603
+ params["input1_zero_point"],
9604
+ op.input1_dtype.c_type,
9605
+ input1_zero_suffix,
9606
+ True,
9607
+ ),
9608
+ (
9609
+ params["output_scale"],
9610
+ op.output_scale_dtype.c_type,
9611
+ output_scale_suffix,
9612
+ True,
9613
+ ),
9614
+ (
9615
+ params["output_zero_point"],
9616
+ op.dtype.c_type,
9617
+ output_zero_suffix,
9618
+ True,
9619
+ ),
9620
+ (
9621
+ params["output"],
9622
+ op.dtype.c_type,
9623
+ output_suffix,
9624
+ False,
9625
+ ),
9626
+ ]
9627
+ )
9628
+ compute_dtype = (
9629
+ ScalarType.F64
9630
+ if ScalarType.F64
9631
+ in {
9632
+ op.input0_scale_dtype,
9633
+ op.input1_scale_dtype,
9634
+ op.output_scale_dtype,
9635
+ }
9636
+ else ScalarType.F32
9637
+ )
9638
+ compute_type = (
9639
+ "double" if compute_dtype == ScalarType.F64 else "float"
9640
+ )
9641
+ max_fn = self._scalar_function_name(
9642
+ ScalarFunction.MAXIMUM, compute_dtype, scalar_registry
9643
+ )
9644
+ min_fn = self._scalar_function_name(
9645
+ ScalarFunction.MINIMUM, compute_dtype, scalar_registry
9646
+ )
9647
+ if max_fn is None or min_fn is None:
9648
+ raise CodegenError(
9649
+ "Failed to resolve scalar min/max functions for QLinearMatMul."
9650
+ )
9651
+ round_fn = CEmitter._math_fn(
9652
+ compute_dtype, "nearbyintf", "nearbyint"
9653
+ )
9654
+ scale_index = "0"
9655
+ rendered = qlinear_matmul_template.render(
9656
+ model_name=model.name,
9657
+ op_name=op_name,
9658
+ input0=params["input0"],
9659
+ input1=params["input1"],
9660
+ input0_scale=params["input0_scale"],
9661
+ input0_zero_point=params["input0_zero_point"],
9662
+ input1_scale=params["input1_scale"],
9663
+ input1_zero_point=params["input1_zero_point"],
9664
+ output_scale=params["output_scale"],
9665
+ output_zero_point=params["output_zero_point"],
9666
+ output=params["output"],
9667
+ params=param_decls,
9668
+ compute_type=compute_type,
9669
+ output_c_type=op.dtype.c_type,
9670
+ input0_index_expr=input0_index_expr,
9671
+ input1_index_expr=input1_index_expr,
9672
+ input0_scale_expr=f"{params['input0_scale']}[{scale_index}]",
9673
+ input1_scale_expr=f"{params['input1_scale']}[{scale_index}]",
9674
+ output_scale_expr=f"{params['output_scale']}[{scale_index}]",
9675
+ input0_zero_expr=f"{params['input0_zero_point']}[{scale_index}]",
9676
+ input1_zero_expr=f"{params['input1_zero_point']}[{scale_index}]",
9677
+ output_zero_expr=f"{params['output_zero_point']}[{scale_index}]",
9678
+ output_loop_vars=output_loop_vars,
9679
+ output_loop_bounds=output_shape,
9680
+ output_index_expr=output_index_expr,
9681
+ k=op.k,
9682
+ round_fn=round_fn,
9683
+ min_literal=op.dtype.min_literal,
9684
+ max_literal=op.dtype.max_literal,
9685
+ min_fn=min_fn,
9686
+ max_fn=max_fn,
9583
9687
  dim_args=dim_args,
9584
9688
  ).rstrip()
9585
9689
  return with_node_comment(rendered)
9586
9690
  if isinstance(op, ClipOp):
9691
+ if scalar_registry is None:
9692
+ raise CodegenError(
9693
+ "Scalar function registry is required for Clip rendering."
9694
+ )
9695
+ input_shape = self._ctx_shape(op.input0)
9696
+ output_shape_raw = self._ctx_shape(op.output)
9697
+ input_dtype = self._ctx_dtype(op.input0)
9698
+ output_dtype = self._ctx_dtype(op.output)
9699
+ min_shape = (
9700
+ self._ctx_shape(op.input_min)
9701
+ if op.input_min is not None
9702
+ else None
9703
+ )
9704
+ max_shape = (
9705
+ self._ctx_shape(op.input_max)
9706
+ if op.input_max is not None
9707
+ else None
9708
+ )
9709
+ min_fn = self._scalar_function_name(
9710
+ ScalarFunction.MINIMUM, input_dtype, scalar_registry
9711
+ )
9712
+ max_fn = self._scalar_function_name(
9713
+ ScalarFunction.MAXIMUM, input_dtype, scalar_registry
9714
+ )
9715
+ if min_fn is None or max_fn is None:
9716
+ raise CodegenError(
9717
+ "Failed to resolve scalar min/max functions for Clip."
9718
+ )
9587
9719
  params = self._shared_param_map(
9588
9720
  [
9589
9721
  ("input0", op.input0),
@@ -9594,61 +9726,61 @@ class CEmitter:
9594
9726
  )
9595
9727
  output_dim_names = _dim_names_for(op.output)
9596
9728
  output_shape = CEmitter._shape_dim_exprs(
9597
- op.output_shape, output_dim_names
9729
+ output_shape_raw, output_dim_names
9598
9730
  )
9599
- loop_vars = CEmitter._loop_vars(op.output_shape)
9731
+ loop_vars = CEmitter._loop_vars(output_shape_raw)
9600
9732
  input_expr = CEmitter._broadcast_index_expr(
9601
9733
  params["input0"],
9602
- op.input_shape,
9603
- op.output_shape,
9734
+ input_shape,
9735
+ output_shape_raw,
9604
9736
  loop_vars,
9605
9737
  )
9606
9738
  min_expr = (
9607
9739
  CEmitter._broadcast_index_expr(
9608
9740
  params["input_min"],
9609
- op.min_shape,
9610
- op.output_shape,
9741
+ min_shape,
9742
+ output_shape_raw,
9611
9743
  loop_vars,
9612
9744
  )
9613
9745
  if op.input_min is not None
9614
- else op.dtype.min_literal
9746
+ else output_dtype.min_literal
9615
9747
  )
9616
9748
  max_expr = (
9617
9749
  CEmitter._broadcast_index_expr(
9618
9750
  params["input_max"],
9619
- op.max_shape,
9620
- op.output_shape,
9751
+ max_shape,
9752
+ output_shape_raw,
9621
9753
  loop_vars,
9622
9754
  )
9623
9755
  if op.input_max is not None
9624
- else op.dtype.max_literal
9756
+ else output_dtype.max_literal
9625
9757
  )
9626
9758
  input_suffix = self._param_array_suffix(
9627
- op.input_shape, _dim_names_for(op.input0)
9759
+ input_shape, _dim_names_for(op.input0)
9628
9760
  )
9629
9761
  min_suffix = (
9630
9762
  self._param_array_suffix(
9631
- op.min_shape, _dim_names_for(op.input_min)
9763
+ min_shape, _dim_names_for(op.input_min)
9632
9764
  )
9633
- if op.min_shape is not None
9765
+ if min_shape is not None
9634
9766
  else ""
9635
9767
  )
9636
9768
  max_suffix = (
9637
9769
  self._param_array_suffix(
9638
- op.max_shape, _dim_names_for(op.input_max)
9770
+ max_shape, _dim_names_for(op.input_max)
9639
9771
  )
9640
- if op.max_shape is not None
9772
+ if max_shape is not None
9641
9773
  else ""
9642
9774
  )
9643
9775
  output_suffix = self._param_array_suffix(
9644
- op.output_shape, output_dim_names
9776
+ output_shape_raw, output_dim_names
9645
9777
  )
9646
9778
  param_decls = self._build_param_decls(
9647
9779
  [
9648
- (params["input0"], op.dtype.c_type, input_suffix, True),
9780
+ (params["input0"], input_dtype.c_type, input_suffix, True),
9649
9781
  (
9650
9782
  params["input_min"],
9651
- op.dtype.c_type,
9783
+ input_dtype.c_type,
9652
9784
  min_suffix,
9653
9785
  True,
9654
9786
  )
@@ -9656,13 +9788,13 @@ class CEmitter:
9656
9788
  else (None, "", "", True),
9657
9789
  (
9658
9790
  params["input_max"],
9659
- op.dtype.c_type,
9791
+ input_dtype.c_type,
9660
9792
  max_suffix,
9661
9793
  True,
9662
9794
  )
9663
9795
  if params["input_max"]
9664
9796
  else (None, "", "", True),
9665
- (params["output"], op.dtype.c_type, output_suffix, False),
9797
+ (params["output"], output_dtype.c_type, output_suffix, False),
9666
9798
  ]
9667
9799
  )
9668
9800
  rendered = clip_template.render(
@@ -9673,8 +9805,8 @@ class CEmitter:
9673
9805
  input_max=params["input_max"],
9674
9806
  output=params["output"],
9675
9807
  params=param_decls,
9676
- input_c_type=op.dtype.c_type,
9677
- output_c_type=op.dtype.c_type,
9808
+ input_c_type=input_dtype.c_type,
9809
+ output_c_type=output_dtype.c_type,
9678
9810
  input_suffix=input_suffix,
9679
9811
  min_suffix=min_suffix,
9680
9812
  max_suffix=max_suffix,
@@ -9684,30 +9816,51 @@ class CEmitter:
9684
9816
  input_expr=input_expr,
9685
9817
  min_expr=min_expr,
9686
9818
  max_expr=max_expr,
9819
+ min_fn=min_fn,
9820
+ max_fn=max_fn,
9687
9821
  dim_args=dim_args,
9688
9822
  ).rstrip()
9689
9823
  return with_node_comment(rendered)
9690
9824
  if isinstance(op, UnaryOp):
9825
+ input_dtype = self._ctx_dtype(op.input0)
9826
+ output_dtype = self._ctx_dtype(op.output)
9827
+ output_shape_raw = self._ctx_shape(op.output)
9691
9828
  params = self._shared_param_map(
9692
9829
  [("input0", op.input0), ("output", op.output)]
9693
9830
  )
9694
9831
  scalar_operator = None
9695
9832
  if scalar_registry is not None:
9696
9833
  scalar_operator = self._scalar_function_name(
9697
- op.function, op.dtype, scalar_registry, params=op.params
9834
+ op.function, input_dtype, scalar_registry, params=op.params
9698
9835
  )
9699
9836
  output_dim_names = _dim_names_for(op.output)
9700
- shape = CEmitter._shape_dim_exprs(op.shape, output_dim_names)
9701
- loop_vars = CEmitter._loop_vars(op.shape)
9702
- array_suffix = self._param_array_suffix(op.shape, output_dim_names)
9837
+ shape = CEmitter._shape_dim_exprs(
9838
+ output_shape_raw, output_dim_names
9839
+ )
9840
+ loop_vars = CEmitter._loop_vars(output_shape_raw)
9841
+ array_suffix = self._param_array_suffix(
9842
+ output_shape_raw, output_dim_names
9843
+ )
9703
9844
  param_decls = self._build_param_decls(
9704
9845
  [
9705
- (params["input0"], op.input_dtype.c_type, array_suffix, True),
9706
- (params["output"], op.dtype.c_type, array_suffix, False),
9846
+ (params["input0"], input_dtype.c_type, array_suffix, True),
9847
+ (params["output"], output_dtype.c_type, array_suffix, False),
9707
9848
  ]
9708
9849
  )
9709
- operator_symbol = unary_op_symbol(op.function, dtype=op.dtype)
9710
- if op.function in {ScalarFunction.ISINF, ScalarFunction.ISNAN}:
9850
+ operator_symbol = unary_op_symbol(op.function, dtype=output_dtype)
9851
+ if op.function == ScalarFunction.ISINF and len(op.params) == 2:
9852
+ detect_negative, detect_positive = op.params
9853
+ detect_negative = int(detect_negative)
9854
+ detect_positive = int(detect_positive)
9855
+ if detect_negative and detect_positive:
9856
+ operator_symbol = "isinf"
9857
+ elif detect_negative:
9858
+ operator_symbol = "isneginf"
9859
+ elif detect_positive:
9860
+ operator_symbol = "isposinf"
9861
+ else:
9862
+ operator_symbol = "zero"
9863
+ elif op.function in {ScalarFunction.ISINF, ScalarFunction.ISNAN}:
9711
9864
  operator_symbol = (
9712
9865
  "isinf" if op.function == ScalarFunction.ISINF else "isnan"
9713
9866
  )
@@ -9722,8 +9875,8 @@ class CEmitter:
9722
9875
  "array_suffix": array_suffix,
9723
9876
  "shape": shape,
9724
9877
  "loop_vars": loop_vars,
9725
- "input_c_type": op.input_dtype.c_type,
9726
- "output_c_type": op.dtype.c_type,
9878
+ "input_c_type": input_dtype.c_type,
9879
+ "output_c_type": output_dtype.c_type,
9727
9880
  "zero_literal": zero_literal,
9728
9881
  "dim_args": dim_args,
9729
9882
  "params": param_decls,
@@ -9774,6 +9927,7 @@ class CEmitter:
9774
9927
  | GatherOp
9775
9928
  | GatherNDOp
9776
9929
  | ScatterNDOp
9930
+ | TensorScatterOp
9777
9931
  | TransposeOp
9778
9932
  | ReshapeOp
9779
9933
  | IdentityOp
@@ -9803,8 +9957,8 @@ class CEmitter:
9803
9957
  return op.output_values
9804
9958
  return op.output
9805
9959
 
9806
- @staticmethod
9807
9960
  def _op_inputs(
9961
+ self,
9808
9962
  op: BinaryOp
9809
9963
  | MultiInputBinaryOp
9810
9964
  | WhereOp
@@ -9840,6 +9994,7 @@ class CEmitter:
9840
9994
  | GatherOp
9841
9995
  | GatherNDOp
9842
9996
  | ScatterNDOp
9997
+ | TensorScatterOp
9843
9998
  | TransposeOp
9844
9999
  | ReshapeOp
9845
10000
  | IdentityOp
@@ -9865,18 +10020,24 @@ class CEmitter:
9865
10020
  ) -> tuple[tuple[str, tuple[int, ...]], ...]:
9866
10021
  if isinstance(op, BinaryOp):
9867
10022
  return (
9868
- (op.input0, op.input0_shape),
9869
- (op.input1, op.input1_shape),
10023
+ (op.input0, self._ctx_shape(op.input0)),
10024
+ (op.input1, self._ctx_shape(op.input1)),
9870
10025
  )
9871
10026
  if isinstance(op, MultiInputBinaryOp):
9872
- return tuple((name, op.shape) for name in op.inputs)
10027
+ return tuple((name, self._ctx_shape(name)) for name in op.inputs)
10028
+ if isinstance(op, WhereOp):
10029
+ return (
10030
+ (op.condition, self._ctx_shape(op.condition)),
10031
+ (op.input_x, self._ctx_shape(op.input_x)),
10032
+ (op.input_y, self._ctx_shape(op.input_y)),
10033
+ )
9873
10034
  if isinstance(op, EinsumOp):
9874
10035
  return tuple(
9875
10036
  (name, shape)
9876
10037
  for name, shape in zip(op.inputs, op.input_shapes)
9877
10038
  )
9878
10039
  if isinstance(op, UnaryOp):
9879
- return ((op.input0, op.shape),)
10040
+ return ((op.input0, self._ctx_shape(op.input0)),)
9880
10041
  if isinstance(op, LpNormalizationOp):
9881
10042
  return ((op.input0, op.shape),)
9882
10043
  if isinstance(op, InstanceNormalizationOp):
@@ -9901,32 +10062,57 @@ class CEmitter:
9901
10062
  if isinstance(op, RMSNormalizationOp):
9902
10063
  return ((op.input0, op.shape), (op.scale, op.scale_shape))
9903
10064
  if isinstance(op, ClipOp):
9904
- inputs = [(op.input0, op.input_shape)]
9905
- if op.input_min is not None and op.min_shape is not None:
9906
- inputs.append((op.input_min, op.min_shape))
9907
- if op.input_max is not None and op.max_shape is not None:
9908
- inputs.append((op.input_max, op.max_shape))
10065
+ inputs = [(op.input0, self._ctx_shape(op.input0))]
10066
+ if op.input_min is not None:
10067
+ inputs.append((op.input_min, self._ctx_shape(op.input_min)))
10068
+ if op.input_max is not None:
10069
+ inputs.append((op.input_max, self._ctx_shape(op.input_max)))
9909
10070
  return tuple(inputs)
9910
10071
  if isinstance(op, CastOp):
9911
- return ((op.input0, op.shape),)
10072
+ return ((op.input0, self._ctx_shape(op.input0)),)
9912
10073
  if isinstance(op, NonZeroOp):
9913
10074
  return ((op.input0, op.input_shape),)
10075
+ if isinstance(op, NonMaxSuppressionOp):
10076
+ inputs = [
10077
+ (op.boxes, op.boxes_shape),
10078
+ (op.scores, op.scores_shape),
10079
+ ]
10080
+ if (
10081
+ op.max_output_boxes_per_class is not None
10082
+ and op.max_output_shape is not None
10083
+ ):
10084
+ inputs.append(
10085
+ (op.max_output_boxes_per_class, op.max_output_shape)
10086
+ )
10087
+ if (
10088
+ op.iou_threshold is not None
10089
+ and op.iou_threshold_shape is not None
10090
+ ):
10091
+ inputs.append((op.iou_threshold, op.iou_threshold_shape))
10092
+ if (
10093
+ op.score_threshold is not None
10094
+ and op.score_threshold_shape is not None
10095
+ ):
10096
+ inputs.append(
10097
+ (op.score_threshold, op.score_threshold_shape)
10098
+ )
10099
+ return tuple(inputs)
9914
10100
  if isinstance(op, QuantizeLinearOp):
9915
10101
  scale_shape = (
9916
10102
  ()
9917
10103
  if op.axis is None
9918
- else (op.input_shape[op.axis],)
10104
+ else (self._ctx_shape(op.input0)[op.axis],)
9919
10105
  )
9920
- inputs = [(op.input0, op.input_shape), (op.scale, scale_shape)]
10106
+ inputs = [(op.input0, self._ctx_shape(op.input0)), (op.scale, scale_shape)]
9921
10107
  if op.zero_point is not None:
9922
10108
  inputs.append((op.zero_point, scale_shape))
9923
10109
  return tuple(inputs)
9924
10110
  if isinstance(op, IdentityOp):
9925
- return ((op.input0, op.shape),)
10111
+ return ((op.input0, self._ctx_shape(op.input0)),)
9926
10112
  if isinstance(op, EyeLikeOp):
9927
10113
  return ((op.input0, op.output_shape),)
9928
10114
  if isinstance(op, TriluOp):
9929
- inputs = [(op.input0, op.input_shape)]
10115
+ inputs = [(op.input0, self._ctx_shape(op.input0))]
9930
10116
  if op.k_input is not None and op.k_input_shape is not None:
9931
10117
  inputs.append((op.k_input, op.k_input_shape))
9932
10118
  return tuple(inputs)
@@ -9943,6 +10129,14 @@ class CEmitter:
9943
10129
  return tuple(inputs)
9944
10130
  if isinstance(op, ScatterNDOp):
9945
10131
  return ((op.data, op.data_shape),)
10132
+ if isinstance(op, TensorScatterOp):
10133
+ inputs = [
10134
+ (op.past_cache, op.past_cache_shape),
10135
+ (op.update, op.update_shape),
10136
+ ]
10137
+ if op.write_indices is not None and op.write_indices_shape is not None:
10138
+ inputs.append((op.write_indices, op.write_indices_shape))
10139
+ return tuple(inputs)
9946
10140
  if isinstance(op, CumSumOp):
9947
10141
  return ((op.input0, op.input_shape),)
9948
10142
  if isinstance(op, RangeOp):
@@ -9956,7 +10150,9 @@ class CEmitter:
9956
10150
  if isinstance(op, SplitOp):
9957
10151
  return ((op.input0, op.input_shape),)
9958
10152
  if isinstance(op, TopKOp):
9959
- return ((op.input0, op.input_shape),)
10153
+ return ((op.input0, self._ctx_shape(op.input0)),)
10154
+ if isinstance(op, (TransposeOp, ReshapeOp, ReduceOp, ArgReduceOp)):
10155
+ return ((op.input0, self._ctx_shape(op.input0)),)
9960
10156
  return ()
9961
10157
 
9962
10158
  def _propagate_tensor_dim_names(
@@ -10014,6 +10210,7 @@ class CEmitter:
10014
10210
  | ShapeOp
10015
10211
  | SizeOp
10016
10212
  | NonZeroOp
10213
+ | NonMaxSuppressionOp
10017
10214
  | ExpandOp
10018
10215
  | RangeOp
10019
10216
  | OneHotOp
@@ -10031,8 +10228,8 @@ class CEmitter:
10031
10228
  tensor_dim_names[output_name] = dict(dim_names)
10032
10229
  break
10033
10230
 
10034
- @staticmethod
10035
10231
  def _op_outputs(
10232
+ self,
10036
10233
  op: BinaryOp
10037
10234
  | MultiInputBinaryOp
10038
10235
  | WhereOp
@@ -10068,6 +10265,7 @@ class CEmitter:
10068
10265
  | GatherOp
10069
10266
  | GatherNDOp
10070
10267
  | ScatterNDOp
10268
+ | TensorScatterOp
10071
10269
  | TransposeOp
10072
10270
  | ReshapeOp
10073
10271
  | IdentityOp
@@ -10086,14 +10284,40 @@ class CEmitter:
10086
10284
  | ShapeOp
10087
10285
  | SizeOp
10088
10286
  | NonZeroOp
10287
+ | NonMaxSuppressionOp
10089
10288
  | ExpandOp
10090
10289
  | RangeOp
10091
10290
  | OneHotOp
10092
10291
  | SplitOp,
10093
- ) -> tuple[tuple[str, tuple[int, ...], str], ...]:
10292
+ ) -> tuple[tuple[str, tuple[int, ...], ScalarType], ...]:
10293
+ if isinstance(
10294
+ op,
10295
+ (
10296
+ BinaryOp,
10297
+ MultiInputBinaryOp,
10298
+ WhereOp,
10299
+ UnaryOp,
10300
+ ClipOp,
10301
+ CastOp,
10302
+ TransposeOp,
10303
+ ReshapeOp,
10304
+ IdentityOp,
10305
+ SoftmaxOp,
10306
+ LogSoftmaxOp,
10307
+ HardmaxOp,
10308
+ ReduceOp,
10309
+ ),
10310
+ ):
10311
+ return (
10312
+ (
10313
+ op.output,
10314
+ self._op_output_shape(op),
10315
+ self._ctx_dtype(op.output),
10316
+ ),
10317
+ )
10094
10318
  if isinstance(op, AttentionOp):
10095
- outputs: list[tuple[str, tuple[int, ...], str]] = [
10096
- (op.output, CEmitter._op_output_shape(op), op.dtype)
10319
+ outputs: list[tuple[str, tuple[int, ...], ScalarType]] = [
10320
+ (op.output, self._op_output_shape(op), op.dtype)
10097
10321
  ]
10098
10322
  if op.output_present_key is not None:
10099
10323
  outputs.append(
@@ -10121,7 +10345,7 @@ class CEmitter:
10121
10345
  )
10122
10346
  return tuple(outputs)
10123
10347
  if isinstance(op, LstmOp):
10124
- outputs: list[tuple[str, tuple[int, ...], str]] = []
10348
+ outputs: list[tuple[str, tuple[int, ...], ScalarType]] = []
10125
10349
  if op.output_y is not None:
10126
10350
  if op.layout == 0:
10127
10351
  y_shape = (
@@ -10155,13 +10379,25 @@ class CEmitter:
10155
10379
  )
10156
10380
  )
10157
10381
  return tuple(outputs)
10382
+ if isinstance(op, AdagradOp):
10383
+ outputs = [
10384
+ (name, shape, op.dtype)
10385
+ for name, shape in zip(op.outputs, op.output_shapes)
10386
+ ]
10387
+ outputs.extend(
10388
+ (name, shape, op.dtype)
10389
+ for name, shape in zip(
10390
+ op.accumulator_outputs, op.output_shapes
10391
+ )
10392
+ )
10393
+ return tuple(outputs)
10158
10394
  if isinstance(op, SoftmaxCrossEntropyLossOp):
10159
10395
  outputs = [(op.output, op.output_shape, op.dtype)]
10160
10396
  if op.log_prob is not None and op.log_prob_shape is not None:
10161
10397
  outputs.append((op.log_prob, op.log_prob_shape, op.dtype))
10162
10398
  return tuple(outputs)
10163
10399
  if isinstance(op, LayerNormalizationOp):
10164
- outputs: list[tuple[str, tuple[int, ...], str]] = [
10400
+ outputs: list[tuple[str, tuple[int, ...], ScalarType]] = [
10165
10401
  (op.output, op.shape, op.dtype)
10166
10402
  ]
10167
10403
  if op.mean_output is not None:
@@ -10172,10 +10408,10 @@ class CEmitter:
10172
10408
  outputs.append((op.invstd_output, invstd_shape, op.dtype))
10173
10409
  return tuple(outputs)
10174
10410
  if isinstance(op, MaxPoolOp):
10175
- outputs = [(op.output, CEmitter._op_output_shape(op), op.dtype)]
10411
+ outputs = [(op.output, self._op_output_shape(op), op.dtype)]
10176
10412
  if op.indices is not None and op.indices_dtype is not None:
10177
10413
  outputs.append(
10178
- (op.indices, CEmitter._op_output_shape(op), op.indices_dtype)
10414
+ (op.indices, self._op_output_shape(op), op.indices_dtype)
10179
10415
  )
10180
10416
  return tuple(outputs)
10181
10417
  if isinstance(op, SplitOp):
@@ -10184,30 +10420,40 @@ class CEmitter:
10184
10420
  for name, shape in zip(op.outputs, op.output_shapes)
10185
10421
  )
10186
10422
  if isinstance(op, ArgReduceOp):
10187
- return ((op.output, CEmitter._op_output_shape(op), op.output_dtype),)
10423
+ return (
10424
+ (
10425
+ op.output,
10426
+ self._op_output_shape(op),
10427
+ self._ctx_dtype(op.output),
10428
+ ),
10429
+ )
10188
10430
  if isinstance(op, TopKOp):
10189
10431
  return (
10190
10432
  (
10191
10433
  op.output_values,
10192
- CEmitter._op_output_shape(op),
10193
- op.output_values_dtype,
10434
+ self._op_output_shape(op),
10435
+ self._ctx_dtype(op.output_values),
10194
10436
  ),
10195
10437
  (
10196
10438
  op.output_indices,
10197
- CEmitter._op_output_shape(op),
10198
- op.output_indices_dtype,
10439
+ self._op_output_shape(op),
10440
+ self._ctx_dtype(op.output_indices),
10199
10441
  ),
10200
10442
  )
10201
- return ((op.output, CEmitter._op_output_shape(op), op.dtype),)
10443
+ if isinstance(op, NonMaxSuppressionOp):
10444
+ return ((op.output, op.output_shape, op.output_dtype),)
10445
+ return ((op.output, self._op_output_shape(op), op.dtype),)
10202
10446
 
10203
- @staticmethod
10204
10447
  def _op_output_shape(
10448
+ self,
10205
10449
  op: BinaryOp
10206
10450
  | MultiInputBinaryOp
10207
10451
  | WhereOp
10208
10452
  | UnaryOp
10209
10453
  | ClipOp
10210
10454
  | CastOp
10455
+ | QuantizeLinearOp
10456
+ | QLinearMatMulOp
10211
10457
  | MatMulOp
10212
10458
  | EinsumOp
10213
10459
  | GemmOp
@@ -10249,6 +10495,7 @@ class CEmitter:
10249
10495
  | ShapeOp
10250
10496
  | SizeOp
10251
10497
  | NonZeroOp
10498
+ | NonMaxSuppressionOp
10252
10499
  | ExpandOp
10253
10500
  | CumSumOp
10254
10501
  | RangeOp
@@ -10257,19 +10504,21 @@ class CEmitter:
10257
10504
  | PadOp,
10258
10505
  ) -> tuple[int, ...]:
10259
10506
  if isinstance(op, BinaryOp):
10260
- return op.shape
10507
+ return self._ctx_shape(op.output)
10261
10508
  if isinstance(op, MultiInputBinaryOp):
10262
- return op.shape
10509
+ return self._ctx_shape(op.output)
10263
10510
  if isinstance(op, WhereOp):
10264
- return op.output_shape
10511
+ return self._ctx_shape(op.output)
10265
10512
  if isinstance(op, UnaryOp):
10266
- return op.shape
10513
+ return self._ctx_shape(op.output)
10267
10514
  if isinstance(op, ClipOp):
10268
- return op.output_shape
10515
+ return self._ctx_shape(op.output)
10269
10516
  if isinstance(op, QuantizeLinearOp):
10270
10517
  return op.input_shape
10271
10518
  if isinstance(op, CastOp):
10272
- return op.shape
10519
+ return self._ctx_shape(op.output)
10520
+ if isinstance(op, QLinearMatMulOp):
10521
+ return op.output_shape
10273
10522
  if isinstance(op, MatMulOp):
10274
10523
  return op.output_shape
10275
10524
  if isinstance(op, EinsumOp):
@@ -10301,11 +10550,11 @@ class CEmitter:
10301
10550
  if isinstance(op, LrnOp):
10302
10551
  return op.shape
10303
10552
  if isinstance(op, SoftmaxOp):
10304
- return op.shape
10553
+ return self._ctx_shape(op.output)
10305
10554
  if isinstance(op, LogSoftmaxOp):
10306
- return op.shape
10555
+ return self._ctx_shape(op.output)
10307
10556
  if isinstance(op, HardmaxOp):
10308
- return op.shape
10557
+ return self._ctx_shape(op.output)
10309
10558
  if isinstance(op, NegativeLogLikelihoodLossOp):
10310
10559
  return op.output_shape
10311
10560
  if isinstance(op, SoftmaxCrossEntropyLossOp):
@@ -10322,12 +10571,14 @@ class CEmitter:
10322
10571
  return op.output_shape
10323
10572
  if isinstance(op, ScatterNDOp):
10324
10573
  return op.output_shape
10325
- if isinstance(op, TransposeOp):
10574
+ if isinstance(op, TensorScatterOp):
10326
10575
  return op.output_shape
10576
+ if isinstance(op, TransposeOp):
10577
+ return self._ctx_shape(op.output)
10327
10578
  if isinstance(op, ReshapeOp):
10328
- return op.output_shape
10579
+ return self._ctx_shape(op.output)
10329
10580
  if isinstance(op, IdentityOp):
10330
- return op.shape
10581
+ return self._ctx_shape(op.output)
10331
10582
  if isinstance(op, EyeLikeOp):
10332
10583
  return op.output_shape
10333
10584
  if isinstance(op, TriluOp):
@@ -10347,11 +10598,11 @@ class CEmitter:
10347
10598
  if isinstance(op, GridSampleOp):
10348
10599
  return op.output_shape
10349
10600
  if isinstance(op, ReduceOp):
10350
- return op.output_shape
10601
+ return self._ctx_shape(op.output)
10351
10602
  if isinstance(op, ArgReduceOp):
10352
- return op.output_shape
10603
+ return self._ctx_shape(op.output)
10353
10604
  if isinstance(op, TopKOp):
10354
- return op.output_shape
10605
+ return self._ctx_shape(op.output_values)
10355
10606
  if isinstance(op, ConstantOfShapeOp):
10356
10607
  return op.shape
10357
10608
  if isinstance(op, ShapeOp):
@@ -10360,6 +10611,8 @@ class CEmitter:
10360
10611
  return op.output_shape
10361
10612
  if isinstance(op, NonZeroOp):
10362
10613
  return op.output_shape
10614
+ if isinstance(op, NonMaxSuppressionOp):
10615
+ return op.output_shape
10363
10616
  if isinstance(op, ExpandOp):
10364
10617
  return op.output_shape
10365
10618
  if isinstance(op, CumSumOp):
@@ -10372,8 +10625,8 @@ class CEmitter:
10372
10625
  return (op.batch, op.q_seq, op.q_heads * op.v_head_size)
10373
10626
  return (op.batch, op.q_heads, op.q_seq, op.v_head_size)
10374
10627
 
10375
- @staticmethod
10376
10628
  def _op_output_dtype(
10629
+ self,
10377
10630
  op: BinaryOp
10378
10631
  | MultiInputBinaryOp
10379
10632
  | WhereOp
@@ -10399,6 +10652,7 @@ class CEmitter:
10399
10652
  | SoftmaxOp
10400
10653
  | LogSoftmaxOp
10401
10654
  | HardmaxOp
10655
+ | AdagradOp
10402
10656
  | NegativeLogLikelihoodLossOp
10403
10657
  | SoftmaxCrossEntropyLossOp
10404
10658
  | MaxPoolOp
@@ -10420,6 +10674,7 @@ class CEmitter:
10420
10674
  | ShapeOp
10421
10675
  | SizeOp
10422
10676
  | NonZeroOp
10677
+ | NonMaxSuppressionOp
10423
10678
  | ExpandOp
10424
10679
  | CumSumOp
10425
10680
  | RangeOp
@@ -10428,9 +10683,30 @@ class CEmitter:
10428
10683
  | PadOp,
10429
10684
  ) -> ScalarType:
10430
10685
  if isinstance(op, ArgReduceOp):
10431
- return op.output_dtype
10686
+ return self._ctx_dtype(op.output)
10432
10687
  if isinstance(op, TopKOp):
10433
- return op.output_values_dtype
10688
+ return self._ctx_dtype(op.output_values)
10689
+ if isinstance(op, NonMaxSuppressionOp):
10690
+ return op.output_dtype
10691
+ if isinstance(
10692
+ op,
10693
+ (
10694
+ BinaryOp,
10695
+ MultiInputBinaryOp,
10696
+ WhereOp,
10697
+ UnaryOp,
10698
+ ClipOp,
10699
+ CastOp,
10700
+ SoftmaxOp,
10701
+ LogSoftmaxOp,
10702
+ HardmaxOp,
10703
+ TransposeOp,
10704
+ ReshapeOp,
10705
+ IdentityOp,
10706
+ ReduceOp,
10707
+ ),
10708
+ ):
10709
+ return self._ctx_dtype(op.output)
10434
10710
  return op.dtype
10435
10711
 
10436
10712
  @staticmethod
@@ -10815,7 +11091,7 @@ class CEmitter:
  self, constants: tuple[ConstTensor, ...]
  ) -> tuple[tuple[ConstTensor, ...], tuple[ConstTensor, ...]]:
  if self._large_weight_threshold <= 0:
- return (), constants
+ return constants, ()
  inline: list[ConstTensor] = []
  large: list[ConstTensor] = []
  for const in constants: