emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
@@ -10,37 +10,48 @@ from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from ..lowering.attention import resolve_attention_spec
  from ..lowering.average_pool import lower_average_pool, lower_global_average_pool
+ from ..lowering.adagrad import lower_adagrad
  from ..lowering.batch_normalization import lower_batch_normalization
  from ..lowering.concat import lower_concat
  from ..lowering.constant_of_shape import lower_constant_of_shape
  from ..lowering.conv import resolve_conv_spec
+ from ..lowering.conv_transpose import resolve_conv_transpose_spec
  from ..lowering.dropout import lower_dropout
  from ..lowering.cumsum import lower_cumsum
+ from ..lowering.einsum import lower_einsum
  from ..lowering.flatten import lower_flatten
  from ..lowering.gemm import resolve_gemm_spec
  from ..lowering.logsoftmax import lower_logsoftmax
+ from ..lowering.hardmax import lower_hardmax
  from ..lowering.lp_normalization import lower_lp_normalization
+ from ..lowering.lp_pool import lower_lp_pool
  from ..lowering.grid_sample import lower_grid_sample
  from ..lowering.instance_normalization import lower_instance_normalization
  from ..lowering.group_normalization import lower_group_normalization
  from ..lowering.layer_normalization import lower_layer_normalization
+ from ..lowering.non_max_suppression import lower_non_max_suppression
  from ..lowering.mean_variance_normalization import (
      lower_mean_variance_normalization,
  )
+ from ..lowering.global_max_pool import lower_global_max_pool
  from ..lowering.negative_log_likelihood_loss import (
      lower_negative_log_likelihood_loss,
  )
+ from ..lowering.nonzero import lower_nonzero
  from ..lowering.pad import lower_pad
  from ..lowering.expand import lower_expand
  from ..lowering.range import lower_range
+ from ..lowering.one_hot import lower_onehot
  from ..lowering.split import lower_split
  from ..lowering.softmax_cross_entropy_loss import (
      lower_softmax_cross_entropy_loss,
  )
  from ..lowering.arg_reduce import lower_arg_reduce
+ from ..lowering.topk import lower_topk
  from ..lowering.lstm import ACTIVATION_KIND_BY_NAME, resolve_lstm_spec
  from ..lowering.lrn import resolve_lrn_spec
  from ..lowering.matmul import lower_matmul
+ from ..lowering.qlinear_matmul import lower_qlinear_matmul
  from ..lowering.maxpool import resolve_maxpool_spec
  from ..lowering.reduce import (
      REDUCE_KIND_BY_OP,
@@ -49,15 +60,19 @@ from ..lowering.reduce import (
      resolve_reduce_axes,
  )
  from ..lowering.reshape import lower_reshape
+ from ..lowering.scatter_nd import lower_scatternd
+ from ..lowering.tensor_scatter import lower_tensor_scatter
  from ..lowering.slice import _normalize_slices
  from ..lowering.shape import lower_shape
  from ..lowering.size import lower_size
  from ..lowering.softmax import lower_softmax
  from ..lowering.rms_normalization import lower_rms_normalization
+ from ..lowering.rotary_embedding import lower_rotary_embedding
  from ..lowering.squeeze import lower_squeeze
  from ..lowering.transpose import lower_transpose
  from ..lowering.unsqueeze import lower_unsqueeze
  from ..lowering.where import lower_where
+ from ..lowering.quantize_linear import resolve_quantize_spec
  from ..lowering.variadic import BINARY_ONLY_OPS, VARIADIC_OP_FUNCTIONS
  from ..lowering.registry import resolve_dispatch
  from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
@@ -133,6 +148,52 @@ def _eval_matmul(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[node.outputs[0]] = _apply_matmul(left, right)


+ @register_evaluator("Einsum")
+ def _eval_einsum(evaluator: Evaluator, node: Node) -> None:
+     lower_einsum(evaluator.graph, node)
+     equation_value = node.attrs.get("equation")
+     if equation_value is None:
+         raise UnsupportedOpError("Einsum equation attribute is required")
+     equation = (
+         equation_value.decode()
+         if isinstance(equation_value, (bytes, bytearray))
+         else str(equation_value)
+     )
+     inputs = [evaluator.values[name] for name in node.inputs]
+     evaluator.values[node.outputs[0]] = np.einsum(equation, *inputs)
+
+
+ @register_evaluator("Adagrad")
+ def _eval_adagrad(evaluator: Evaluator, node: Node) -> None:
+     op = lower_adagrad(evaluator.graph, node)
+     rate = evaluator.values[op.rate]
+     timestep = evaluator.values[op.timestep]
+     rate_value = (
+         np.array(rate, dtype=op.dtype.np_dtype).reshape(-1)[0].item()
+     )
+     timestep_value = (
+         np.array(timestep, dtype=np.int64).reshape(-1)[0].item()
+     )
+     r = op.dtype.np_dtype.type(
+         rate_value / (1.0 + float(timestep_value) * op.decay_factor)
+     )
+     for x_name, g_name, h_name, out_name, h_out_name in zip(
+         op.inputs,
+         op.gradients,
+         op.accumulators,
+         op.outputs,
+         op.accumulator_outputs,
+     ):
+         x = evaluator.values[x_name]
+         g = evaluator.values[g_name]
+         h = evaluator.values[h_name]
+         g_regularized = op.norm_coefficient * x + g
+         h_new = h + g_regularized * g_regularized
+         h_adaptive = np.sqrt(h_new) + op.epsilon
+         evaluator.values[out_name] = x - r * g_regularized / h_adaptive
+         evaluator.values[h_out_name] = h_new
+
+
  @register_evaluator("Clip")
  def _eval_clip(evaluator: Evaluator, node: Node) -> None:
      if not node.inputs or len(node.outputs) != 1:
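
Note on the hunks above (which, judging by the register_evaluator decorators and the +785 -43 count in the file list, appear to come from emx_onnx_cgen/runtime/evaluator.py): the Adagrad evaluator follows the standard ONNX Adagrad update. A standalone numpy restatement of the same rule, with made-up values, for checking the algebra:

    import numpy as np

    x = np.array([1.0, 2.0], dtype=np.float32)   # parameter tensor
    g = np.array([0.1, -0.2], dtype=np.float32)  # gradient
    h = np.array([0.5, 0.5], dtype=np.float32)   # squared-gradient accumulator
    rate, timestep = 0.01, 3
    norm_coefficient, epsilon, decay_factor = 0.0, 1e-6, 0.0

    r = rate / (1.0 + timestep * decay_factor)   # decayed learning rate
    g_reg = norm_coefficient * x + g             # L2-regularized gradient
    h_new = h + g_reg * g_reg                    # updated accumulator
    x_new = x - r * g_reg / (np.sqrt(h_new) + epsilon)
    print(x_new, h_new)
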
@@ -163,6 +224,79 @@ def _eval_clip(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[node.outputs[0]] = np.clip(x, min_val, max_val)


+ def _max_min(lhs: float, rhs: float) -> tuple[float, float]:
+     if lhs >= rhs:
+         return rhs, lhs
+     return lhs, rhs
+
+
+ def _suppress_by_iou(
+     boxes: np.ndarray,
+     box_index1: int,
+     box_index2: int,
+     *,
+     center_point_box: int,
+     iou_threshold: float,
+ ) -> bool:
+     box1 = boxes[box_index1]
+     box2 = boxes[box_index2]
+     if center_point_box == 0:
+         x1_min, x1_max = _max_min(float(box1[1]), float(box1[3]))
+         x2_min, x2_max = _max_min(float(box2[1]), float(box2[3]))
+         intersection_x_min = max(x1_min, x2_min)
+         intersection_x_max = min(x1_max, x2_max)
+         if intersection_x_max <= intersection_x_min:
+             return False
+
+         y1_min, y1_max = _max_min(float(box1[0]), float(box1[2]))
+         y2_min, y2_max = _max_min(float(box2[0]), float(box2[2]))
+         intersection_y_min = max(y1_min, y2_min)
+         intersection_y_max = min(y1_max, y2_max)
+         if intersection_y_max <= intersection_y_min:
+             return False
+     else:
+         box1_width_half = float(box1[2]) / 2.0
+         box1_height_half = float(box1[3]) / 2.0
+         box2_width_half = float(box2[2]) / 2.0
+         box2_height_half = float(box2[3]) / 2.0
+
+         x1_min = float(box1[0]) - box1_width_half
+         x1_max = float(box1[0]) + box1_width_half
+         x2_min = float(box2[0]) - box2_width_half
+         x2_max = float(box2[0]) + box2_width_half
+
+         y1_min = float(box1[1]) - box1_height_half
+         y1_max = float(box1[1]) + box1_height_half
+         y2_min = float(box2[1]) - box2_height_half
+         y2_max = float(box2[1]) + box2_height_half
+
+         intersection_x_min = max(x1_min, x2_min)
+         intersection_x_max = min(x1_max, x2_max)
+         if intersection_x_max <= intersection_x_min:
+             return False
+
+         intersection_y_min = max(y1_min, y2_min)
+         intersection_y_max = min(y1_max, y2_max)
+         if intersection_y_max <= intersection_y_min:
+             return False
+
+     intersection_area = (intersection_x_max - intersection_x_min) * (
+         intersection_y_max - intersection_y_min
+     )
+     if intersection_area <= 0:
+         return False
+
+     area1 = (x1_max - x1_min) * (y1_max - y1_min)
+     area2 = (x2_max - x2_min) * (y2_max - y2_min)
+     union_area = area1 + area2 - intersection_area
+
+     if area1 <= 0 or area2 <= 0 or union_area <= 0:
+         return False
+
+     intersection_over_union = intersection_area / union_area
+     return intersection_over_union > iou_threshold
+
+
  def _exclusive_cumsum(data: np.ndarray, axis: int) -> np.ndarray:
      result = np.zeros_like(data)
      if data.shape[axis] == 0:
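
_suppress_by_iou computes the intersection rectangle, bails out early on an empty overlap, and compares intersection-over-union against the threshold. A worked corner-format example (made-up boxes, [y1, x1, y2, x2] per the center_point_box == 0 branch):

    import numpy as np

    a = np.array([0.0, 0.0, 2.0, 2.0])   # 2x2 box, area 4
    b = np.array([1.0, 1.0, 3.0, 3.0])   # overlaps a in a 1x1 region
    inter = 1.0 * 1.0
    union = 4.0 + 4.0 - inter
    print(inter / union)  # 1/7 ≈ 0.143: b suppressed only if that exceeds iou_threshold
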
@@ -197,6 +331,100 @@ def _eval_cumsum(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[op.output] = result


+ @register_evaluator("NonMaxSuppression")
+ def _eval_nonmax_suppression(evaluator: Evaluator, node: Node) -> None:
+     op = lower_non_max_suppression(evaluator.graph, node)
+     boxes = evaluator.values[op.boxes]
+     scores = evaluator.values[op.scores]
+
+     max_output_boxes_per_class = 0
+     if op.max_output_boxes_per_class is not None:
+         max_output_values = evaluator.values[
+             op.max_output_boxes_per_class
+         ].astype(np.int64, copy=False)
+         max_output_values = max_output_values.reshape(-1)
+         if max_output_values.size != 1:
+             raise UnsupportedOpError(
+                 "NonMaxSuppression max_output_boxes_per_class must be scalar"
+             )
+         max_output_boxes_per_class = max(int(max_output_values[0]), 0)
+
+     iou_threshold = 0.0
+     if op.iou_threshold is not None:
+         iou_values = evaluator.values[op.iou_threshold].reshape(-1)
+         if iou_values.size != 1:
+             raise UnsupportedOpError(
+                 "NonMaxSuppression iou_threshold must be scalar"
+             )
+         iou_threshold = float(iou_values[0])
+
+     score_threshold = 0.0
+     score_threshold_enabled = op.score_threshold is not None
+     if op.score_threshold is not None:
+         score_values = evaluator.values[op.score_threshold].reshape(-1)
+         if score_values.size != 1:
+             raise UnsupportedOpError(
+                 "NonMaxSuppression score_threshold must be scalar"
+             )
+         score_threshold = float(score_values[0])
+
+     if max_output_boxes_per_class == 0:
+         evaluator.values[op.output] = np.empty((0, 3), dtype=np.int64)
+         return
+
+     num_batches = boxes.shape[0]
+     num_boxes = boxes.shape[1]
+     num_classes = scores.shape[1]
+
+     selected_indices: list[tuple[int, int, int]] = []
+     for batch_index in range(num_batches):
+         batch_boxes = boxes[batch_index]
+         for class_index in range(num_classes):
+             class_scores = scores[batch_index, class_index]
+             candidates: list[tuple[float, int]] = []
+             if score_threshold_enabled:
+                 for box_index in range(num_boxes):
+                     score = float(class_scores[box_index])
+                     if score > score_threshold:
+                         candidates.append((score, box_index))
+             else:
+                 for box_index in range(num_boxes):
+                     candidates.append(
+                         (float(class_scores[box_index]), box_index)
+                     )
+             candidates.sort(key=lambda item: (item[0], -item[1]))
+             selected_boxes: list[int] = []
+             while (
+                 candidates
+                 and len(selected_boxes) < max_output_boxes_per_class
+             ):
+                 _, box_index = candidates.pop()
+                 if any(
+                     _suppress_by_iou(
+                         batch_boxes,
+                         box_index,
+                         selected_index,
+                         center_point_box=op.center_point_box,
+                         iou_threshold=iou_threshold,
+                     )
+                     for selected_index in selected_boxes
+                 ):
+                     continue
+                 selected_boxes.append(box_index)
+                 selected_indices.append(
+                     (batch_index, class_index, box_index)
+                 )
+
+     result = np.empty((len(selected_indices), 3), dtype=np.int64)
+     for idx, (batch_index, class_index, box_index) in enumerate(
+         selected_indices
+     ):
+         result[idx, 0] = batch_index
+         result[idx, 1] = class_index
+         result[idx, 2] = box_index
+     evaluator.values[op.output] = result
+
+
  @register_evaluator("Pad")
  def _eval_pad(evaluator: Evaluator, node: Node) -> None:
      op = lower_pad(evaluator.graph, node)
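
The candidate ordering in the NonMaxSuppression evaluator is worth spelling out: the list is sorted ascending by (score, -box_index), so candidates.pop() yields the highest score and breaks score ties in favor of the lowest box index. Illustrative values:

    candidates = [(0.9, 0), (0.9, 2), (0.5, 1)]
    candidates.sort(key=lambda item: (item[0], -item[1]))
    print(candidates.pop())  # (0.9, 0): index 0 wins the tie
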
@@ -242,10 +470,11 @@ def _eval_pad(evaluator: Evaluator, node: Node) -> None:
      pads_begin = np.zeros(rank, dtype=np.int64)
      pads_end = np.zeros(rank, dtype=np.int64)
      for axis, pad_index in enumerate(op.pads_axis_map):
-         if pad_index is None:
-             continue
-         pads_begin[axis] = int(pads_values[pad_index])
-         pads_end[axis] = int(pads_values[pad_index + axis_count])
+         if pad_index is not None:
+             pads_begin[axis] = int(pads_values[pad_index])
+             pads_end[axis] = int(
+                 pads_values[pad_index + axis_count]
+             )
      pad_width = tuple(
          (int(pads_begin[index]), int(pads_end[index]))
          for index in range(rank)
@@ -270,6 +499,82 @@ def _eval_pad(evaluator: Evaluator, node: Node) -> None:
      )


+ @register_evaluator("ScatterND")
+ def _eval_scatternd(evaluator: Evaluator, node: Node) -> None:
+     op = lower_scatternd(evaluator.graph, node)
+     data = evaluator.values[op.data]
+     indices = evaluator.values[op.indices]
+     updates = evaluator.values[op.updates]
+     output = np.array(data, copy=True)
+     index_depth = op.indices_shape[-1]
+     update_indices_shape = op.indices_shape[:-1]
+     update_count = int(np.prod(update_indices_shape)) if update_indices_shape else 1
+     flat_indices = indices.astype(np.int64, copy=False).reshape(
+         update_count, index_depth
+     )
+     tail_shape = op.data_shape[index_depth:]
+     updates_reshaped = updates.reshape((update_count,) + tail_shape)
+     for index, index_values in enumerate(flat_indices):
+         output_index: list[int | slice] = []
+         for axis, value in enumerate(index_values):
+             axis_size = op.data_shape[axis]
+             idx = int(value)
+             if idx < 0:
+                 idx += axis_size
+             if idx < 0 or idx >= axis_size:
+                 raise UnsupportedOpError(
+                     "ScatterND indices must be within data bounds"
+                 )
+             output_index.append(idx)
+         output_index.extend([slice(None)] * len(tail_shape))
+         target = tuple(output_index)
+         update_value = updates_reshaped[index]
+         if op.reduction == "none":
+             output[target] = update_value
+         elif op.reduction == "add":
+             output[target] = output[target] + update_value
+         elif op.reduction == "mul":
+             output[target] = output[target] * update_value
+         elif op.reduction == "min":
+             output[target] = np.minimum(output[target], update_value)
+         elif op.reduction == "max":
+             output[target] = np.maximum(output[target], update_value)
+         else:
+             raise UnsupportedOpError(
+                 f"Unsupported ScatterND reduction {op.reduction}"
+             )
+     evaluator.values[op.output] = output
+
+
+ @register_evaluator("TensorScatter")
+ def _eval_tensor_scatter(evaluator: Evaluator, node: Node) -> None:
+     op = lower_tensor_scatter(evaluator.graph, node)
+     past_cache = evaluator.values[op.past_cache]
+     update = evaluator.values[op.update]
+     if op.write_indices is None:
+         write_indices = np.zeros((past_cache.shape[0],), dtype=np.int64)
+     else:
+         write_indices = evaluator.values[op.write_indices].astype(
+             np.int64, copy=False
+         )
+     axis = op.axis
+     max_sequence_length = past_cache.shape[axis]
+     sequence_length = update.shape[axis]
+     output = np.array(past_cache, copy=True)
+     for prefix_idx in np.ndindex(past_cache.shape[:axis]):
+         batch_idx = prefix_idx[0]
+         base_index = int(write_indices[batch_idx])
+         for sequence_idx in range(sequence_length):
+             target_index = base_index + sequence_idx
+             if op.mode == "circular":
+                 # Wrap only the sequence position; the batch/prefix
+                 # indices must stay fixed.
+                 target_index %= max_sequence_length
+             output[(*prefix_idx, target_index)] = update[
+                 (*prefix_idx, sequence_idx)
+             ]
+     evaluator.values[op.output] = output
+
+
  @register_evaluator("Celu")
  def _eval_celu(evaluator: Evaluator, node: Node) -> None:
      if len(node.inputs) != 1 or len(node.outputs) != 1:
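
ScatterND semantics in miniature (illustrative tensors, reduction="add"): each row of indices addresses one location in data, and updates that hit the same location accumulate.

    import numpy as np

    data = np.zeros(4, dtype=np.float32)
    indices = np.array([[1], [1], [3]], dtype=np.int64)
    updates = np.array([1.0, 2.0, 5.0], dtype=np.float32)
    out = data.copy()
    for idx, upd in zip(indices, updates):
        out[tuple(idx)] += upd
    print(out)  # [0. 3. 0. 5.]
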
@@ -676,8 +981,22 @@ def _eval_isinf(evaluator: Evaluator, node: Node) -> None:
      output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
      if output_dtype != ScalarType.BOOL:
          raise UnsupportedOpError("IsInf output must be bool")
+     detect_negative = int(node.attrs.get("detect_negative", 1))
+     detect_positive = int(node.attrs.get("detect_positive", 1))
+     if detect_negative not in {0, 1} or detect_positive not in {0, 1}:
+         raise UnsupportedOpError(
+             "IsInf detect_negative and detect_positive must be 0 or 1"
+         )
      x = evaluator.values[node.inputs[0]]
-     evaluator.values[node.outputs[0]] = np.isinf(x)
+     if detect_negative and detect_positive:
+         result = np.isinf(x)
+     elif detect_negative:
+         result = np.isneginf(x)
+     elif detect_positive:
+         result = np.isposinf(x)
+     else:
+         result = np.zeros(x.shape, dtype=bool)
+     evaluator.values[node.outputs[0]] = result


  @register_evaluator("IsNaN")
@@ -786,6 +1105,40 @@ def _eval_eye_like(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[node.outputs[0]] = output


+ @register_evaluator("Trilu")
+ def _eval_trilu(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
+         raise UnsupportedOpError("Trilu must have 1 or 2 inputs and 1 output")
+     value = evaluator.values[node.inputs[0]]
+     if value.ndim < 2:
+         raise UnsupportedOpError("Trilu expects input rank >= 2")
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     if output_dtype != input_dtype:
+         raise UnsupportedOpError(
+             "Trilu expects matching input/output dtypes, "
+             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     k = 0
+     if len(node.inputs) == 2 and node.inputs[1]:
+         k_value = np.array(evaluator.values[node.inputs[1]], dtype=np.int64)
+         if k_value.size != 1:
+             raise UnsupportedOpError("Trilu k input must be scalar")
+         k = int(k_value.reshape(-1)[0])
+     upper_attr = node.attrs.get("upper", 1)
+     upper = bool(int(upper_attr))
+     rows, cols = value.shape[-2], value.shape[-1]
+     batch_shape = value.shape[:-2]
+     batch_size = int(np.prod(batch_shape)) if batch_shape else 1
+     view = value.reshape(batch_size, rows, cols)
+     if upper:
+         mask = np.triu(np.ones((rows, cols), dtype=bool), k=k)
+     else:
+         mask = np.tril(np.ones((rows, cols), dtype=bool), k=k)
+     output = np.where(mask, view, np.zeros_like(view))
+     evaluator.values[node.outputs[0]] = output.reshape(value.shape)
+
+
  @register_evaluator("Tile")
  def _eval_tile(evaluator: Evaluator, node: Node) -> None:
      if len(node.inputs) != 2 or len(node.outputs) != 1:
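
The Trilu mask construction, shown on a concrete 3x3 upper-triangular case with k=1 (illustrative): True keeps the element, False zeroes it.

    import numpy as np

    mask = np.triu(np.ones((3, 3), dtype=bool), k=1)
    x = np.arange(1, 10).reshape(3, 3)
    print(np.where(mask, x, 0))
    # [[0 2 3]
    #  [0 0 6]
    #  [0 0 0]]
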
@@ -922,6 +1275,73 @@ def _eval_gather(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[node.outputs[0]] = np.take(data, indices, axis=axis)


+ @register_evaluator("GatherND")
+ def _eval_gather_nd(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("GatherND must have 2 inputs and 1 output")
+     data = evaluator.values[node.inputs[0]]
+     indices = evaluator.values[node.inputs[1]]
+     if indices.dtype.type not in {np.int32, np.int64}:
+         raise UnsupportedOpError(
+             f"GatherND indices must be int32 or int64, got {indices.dtype}"
+         )
+     if indices.ndim < 1:
+         raise UnsupportedOpError("GatherND indices must have rank >= 1")
+     batch_dims = int(node.attrs.get("batch_dims", 0))
+     if batch_dims < 0:
+         raise UnsupportedOpError(
+             f"GatherND batch_dims must be >= 0, got {batch_dims}"
+         )
+     if batch_dims > indices.ndim - 1:
+         raise UnsupportedOpError(
+             "GatherND batch_dims must be <= indices rank - 1, "
+             f"got {batch_dims} vs {indices.ndim - 1}"
+         )
+     if batch_dims > data.ndim:
+         raise UnsupportedOpError(
+             "GatherND batch_dims must be <= data rank, "
+             f"got {batch_dims} vs {data.ndim}"
+         )
+     if tuple(data.shape[:batch_dims]) != tuple(indices.shape[:batch_dims]):
+         raise UnsupportedOpError(
+             "GatherND batch_dims must match on data/indices, "
+             f"got {data.shape} vs {indices.shape}"
+         )
+     index_depth = indices.shape[-1]
+     if index_depth <= 0:
+         raise UnsupportedOpError(
+             "GatherND indices final dimension must be >= 1"
+         )
+     if index_depth > data.ndim - batch_dims:
+         raise UnsupportedOpError(
+             "GatherND indices final dimension must be <= data rank - "
+             f"batch_dims, got {index_depth} vs {data.ndim - batch_dims}"
+         )
+     tail_shape = data.shape[batch_dims + index_depth :]
+     output_shape = indices.shape[:-1] + tail_shape
+     output = np.empty(output_shape, dtype=data.dtype)
+     indices_prefix_shape = indices.shape[:-1]
+     prefix_iter = (
+         np.ndindex(*indices_prefix_shape) if indices_prefix_shape else [()]
+     )
+     for prefix in prefix_iter:
+         raw_index = indices[prefix]
+         if index_depth == 1:
+             index_values = [int(np.asarray(raw_index).item())]
+         else:
+             index_values = [int(value) for value in raw_index]
+         for dim_index, value in enumerate(index_values):
+             if value < 0:
+                 index_values[dim_index] = value + data.shape[
+                     batch_dims + dim_index
+                 ]
+         data_index = list(prefix[:batch_dims]) + index_values
+         data_index.extend([slice(None)] * len(tail_shape))
+         output_index = prefix + (slice(None),) * len(tail_shape)
+         output[output_index] = data[tuple(data_index)]
+     evaluator.values[node.outputs[0]] = output
+
+
  @register_evaluator("Slice")
  def _eval_slice(evaluator: Evaluator, node: Node) -> None:
      input_value = evaluator.values[node.inputs[0]]
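
GatherND treats the last indices axis as a coordinate tuple into data; any trailing data axes are copied whole. A minimal batch_dims=0 example (made-up data):

    import numpy as np

    data = np.arange(8).reshape(2, 2, 2)
    indices = np.array([[0, 1], [1, 0]])   # two rank-2 coordinate prefixes
    out = np.stack([data[tuple(i)] for i in indices])
    print(out)  # [[2 3] [4 5]], shape (2, 2)
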
@@ -1010,6 +1430,49 @@ def _eval_attention(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[qk_matmul_output_name] = qk_output


+ @register_evaluator("RotaryEmbedding")
+ def _eval_rotary_embedding(evaluator: Evaluator, node: Node) -> None:
+     op = lower_rotary_embedding(evaluator.graph, node)
+     x = evaluator.values[op.input0]
+     cos_cache = evaluator.values[op.cos_cache]
+     sin_cache = evaluator.values[op.sin_cache]
+     position_ids = (
+         evaluator.values[op.position_ids] if op.position_ids else None
+     )
+     original_shape = x.shape
+     if op.input_rank == 4:
+         x = np.transpose(x, (0, 2, 1, 3))
+     else:
+         x = x.reshape(op.batch, op.seq_len, op.num_heads, op.head_size)
+     x_rotate = x[..., : op.rotary_dim]
+     x_not_rotate = x[..., op.rotary_dim :]
+     if position_ids is not None:
+         cos_cache = cos_cache[position_ids]
+         sin_cache = sin_cache[position_ids]
+     cos_cache = np.expand_dims(cos_cache, axis=2)
+     sin_cache = np.expand_dims(sin_cache, axis=2)
+     if op.interleaved:
+         x1 = x_rotate[..., 0::2]
+         x2 = x_rotate[..., 1::2]
+     else:
+         x1, x2 = np.split(x_rotate, 2, axis=-1)
+     real = (cos_cache * x1) - (sin_cache * x2)
+     imag = (sin_cache * x1) + (cos_cache * x2)
+     if op.interleaved:
+         real = np.expand_dims(real, axis=-1)
+         imag = np.expand_dims(imag, axis=-1)
+         x_rotate_concat = np.concatenate((real, imag), axis=-1)
+         x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
+     else:
+         x_rotate = np.concatenate((real, imag), axis=-1)
+     output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
+     if op.input_rank == 4:
+         output = np.transpose(output, (0, 2, 1, 3))
+     else:
+         output = output.reshape(original_shape)
+     evaluator.values[node.outputs[0]] = output
+
+
  def _apply_lstm_activation(
      kind: int, value: np.ndarray, alpha: float, beta: float
  ) -> np.ndarray:
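
The RotaryEmbedding math above is complex multiplication by e^(i*theta): with the non-interleaved split x = [x1, x2], the rotated pair is (cos*x1 - sin*x2, sin*x1 + cos*x2). An illustrative scalar check:

    import numpy as np

    theta = 0.5
    x1, x2 = 1.0, 0.0
    real = np.cos(theta) * x1 - np.sin(theta) * x2
    imag = np.sin(theta) * x1 + np.cos(theta) * x2
    print(real, imag)  # (cos 0.5, sin 0.5): a unit vector rotated by 0.5 rad
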
@@ -1101,6 +1564,28 @@ def _eval_conv(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[node.outputs[0]] = _apply_conv(spec, data, weights, bias)


+ @register_evaluator("ConvTranspose")
+ def _eval_conv_transpose(evaluator: Evaluator, node: Node) -> None:
+     op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "ConvTranspose supports float16, float, and double inputs only"
+         )
+     spec = resolve_conv_transpose_spec(evaluator.graph, node)
+     data = evaluator.values[node.inputs[0]]
+     weights = evaluator.values[node.inputs[1]]
+     bias = evaluator.values[node.inputs[2]] if len(node.inputs) > 2 else None
+     evaluator.values[node.outputs[0]] = _apply_conv_transpose(
+         spec, data, weights, bias
+     )
+
+
  @register_evaluator("BatchNormalization")
  def _eval_batch_norm(evaluator: Evaluator, node: Node) -> None:
      op = lower_batch_normalization(evaluator.graph, node)
@@ -1133,6 +1618,94 @@ def _eval_lp_normalization(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[op.output] = data / denom


+ @register_evaluator("LpPool")
+ def _eval_lp_pool(evaluator: Evaluator, node: Node) -> None:
+     op = lower_lp_pool(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     output = np.zeros(
+         (op.batch, op.channels, op.out_h, op.out_w), dtype=data.dtype
+     )
+     for n in range(op.batch):
+         for c in range(op.channels):
+             for out_h in range(op.out_h):
+                 for out_w in range(op.out_w):
+                     h_start = out_h * op.stride_h - op.pad_top
+                     w_start = out_w * op.stride_w - op.pad_left
+                     acc = 0.0
+                     for kh in range(op.kernel_h):
+                         for kw in range(op.kernel_w):
+                             in_h = h_start + kh
+                             in_w = w_start + kw
+                             if (
+                                 0 <= in_h < op.in_h
+                                 and 0 <= in_w < op.in_w
+                             ):
+                                 value = data[(n, c, in_h, in_w)]
+                                 acc += abs(value) ** op.p
+                     output[(n, c, out_h, out_w)] = acc ** (1.0 / op.p)
+     evaluator.values[op.output] = output
+
+
+ @register_evaluator("QuantizeLinear")
+ def _eval_quantize_linear(evaluator: Evaluator, node: Node) -> None:
+     spec = resolve_quantize_spec(evaluator.graph, node)
+     data = evaluator.values[node.inputs[0]]
+     scale = evaluator.values[node.inputs[1]]
+     zero_point_name = optional_name(node.inputs, 2)
+     if zero_point_name is None:
+         zero_point = 0
+     else:
+         zero_point = evaluator.values[zero_point_name]
+     if spec.axis is None:
+         scaled = data / scale + zero_point
+     else:
+         shape = [1] * data.ndim
+         shape[spec.axis] = scale.shape[0]
+         scaled = data / scale.reshape(shape) + np.asarray(zero_point).reshape(
+             shape
+         )
+     rounded = np.rint(scaled)
+     info = np.iinfo(spec.output_dtype.np_dtype)
+     clipped = np.clip(rounded, info.min, info.max)
+     evaluator.values[node.outputs[0]] = clipped.astype(
+         spec.output_dtype.np_dtype, copy=False
+     )
+
+
+ @register_evaluator("QLinearMatMul")
+ def _eval_qlinear_matmul(evaluator: Evaluator, node: Node) -> None:
+     op = lower_qlinear_matmul(evaluator.graph, node)
+     input0 = evaluator.values[op.input0]
+     input1 = evaluator.values[op.input1]
+     input0_scale = evaluator.values[op.input0_scale]
+     input1_scale = evaluator.values[op.input1_scale]
+     output_scale = evaluator.values[op.output_scale]
+     input0_zero_point = evaluator.values[op.input0_zero_point]
+     input1_zero_point = evaluator.values[op.input1_zero_point]
+     output_zero_point = evaluator.values[op.output_zero_point]
+
+     def _scalar_value(array: np.ndarray) -> float:
+         return float(np.asarray(array).reshape(-1)[0])
+
+     def _scalar_int(array: np.ndarray) -> int:
+         return int(np.asarray(array).reshape(-1)[0])
+
+     input0_zero = _scalar_int(input0_zero_point)
+     input1_zero = _scalar_int(input1_zero_point)
+     output_zero = _scalar_int(output_zero_point)
+     scale = _scalar_value(input0_scale) * _scalar_value(
+         input1_scale
+     ) / _scalar_value(output_scale)
+     acc = _apply_matmul(
+         input0.astype(np.int32) - input0_zero,
+         input1.astype(np.int32) - input1_zero,
+     )
+     scaled = acc.astype(np.float64) * scale + output_zero
+     rounded = np.rint(scaled)
+     info = np.iinfo(op.dtype.np_dtype)
+     clipped = np.clip(rounded, info.min, info.max)
+     evaluator.values[op.output] = clipped.astype(op.dtype.np_dtype)
+
  @register_evaluator("InstanceNormalization")
  def _eval_instance_normalization(evaluator: Evaluator, node: Node) -> None:
      op = lower_instance_normalization(evaluator.graph, node)
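
The QLinearMatMul requantization pipeline, restated on tiny made-up inputs: integer accumulation on zero-point-adjusted values, then rescale by (s0 * s1 / s_out) and round half-to-even (np.rint), saturating to the output type.

    import numpy as np

    a = np.array([[2, 4]], dtype=np.uint8)
    b = np.array([[1], [3]], dtype=np.uint8)
    s0, s1, s_out = 0.1, 0.2, 0.05
    za, zb, z_out = 0, 0, 0
    acc = (a.astype(np.int32) - za) @ (b.astype(np.int32) - zb)   # [[14]]
    out = np.rint(acc * (s0 * s1 / s_out) + z_out).clip(0, 255).astype(np.uint8)
    print(out)  # [[6]] since 14 * 0.4 = 5.6 rounds to 6
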
@@ -1284,6 +1857,18 @@ def _eval_maxpool(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[indices_output] = indices


+ @register_evaluator("GlobalMaxPool")
+ def _eval_global_max_pool(evaluator: Evaluator, node: Node) -> None:
+     op = lower_global_max_pool(evaluator.graph, node)
+     value = evaluator.values[node.inputs[0]]
+     if not op.axes:
+         evaluator.values[node.outputs[0]] = value.copy()
+         return
+     evaluator.values[node.outputs[0]] = np.max(
+         value, axis=op.axes, keepdims=op.keepdims
+     )
+
+
  @register_evaluator("Softmax")
  def _eval_softmax(evaluator: Evaluator, node: Node) -> None:
      op = lower_softmax(evaluator.graph, node)
@@ -1298,6 +1883,19 @@ def _eval_logsoftmax(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[node.outputs[0]] = _apply_logsoftmax(value, op.axis)


+ @register_evaluator("Hardmax")
+ def _eval_hardmax(evaluator: Evaluator, node: Node) -> None:
+     op = lower_hardmax(evaluator.graph, node)
+     value = evaluator.values[node.inputs[0]]
+     max_values = np.max(value, axis=op.axis, keepdims=True)
+     is_max = value == max_values
+     max_index = np.argmax(is_max, axis=op.axis)
+     output = np.zeros_like(value)
+     ones = np.array(1.0, dtype=value.dtype)
+     np.put_along_axis(output, np.expand_dims(max_index, axis=op.axis), ones, axis=op.axis)
+     evaluator.values[node.outputs[0]] = output
+
+
  @register_evaluator("NegativeLogLikelihoodLoss")
  def _eval_negative_log_likelihood_loss(
      evaluator: Evaluator, node: Node
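
Hardmax as implemented above is a one-hot at the first occurrence of the maximum along the axis (np.argmax breaks ties toward the lower index). Illustrative row:

    import numpy as np

    x = np.array([[1.0, 3.0, 3.0]])
    idx = np.argmax(x, axis=-1)                      # [1], first max wins
    out = np.zeros_like(x)
    np.put_along_axis(out, idx[:, None], 1.0, axis=-1)
    print(out)  # [[0. 1. 0.]]
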
@@ -1409,6 +2007,16 @@ def _eval_size(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[op.output] = np.array(op.value, dtype=np.int64)


+ @register_evaluator("NonZero")
+ def _eval_nonzero(evaluator: Evaluator, node: Node) -> None:
+     op = lower_nonzero(evaluator.graph, node)
+     values = evaluator.values[op.input0]
+     indices = np.nonzero(values)
+     evaluator.values[op.output] = np.stack(indices, axis=0).astype(
+         np.int64, copy=False
+     )
+
+
  @register_evaluator("Expand")
  def _eval_expand(evaluator: Evaluator, node: Node) -> None:
      op = lower_expand(evaluator.graph, node)
@@ -1428,6 +2036,45 @@ def _eval_range(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[op.output] = output


+ @register_evaluator("OneHot")
+ def _eval_onehot(evaluator: Evaluator, node: Node) -> None:
+     op = lower_onehot(evaluator.graph, node)
+     indices = evaluator.values[op.indices].astype(np.int64, copy=False)
+     depth_values = evaluator.values[op.depth].reshape(-1)
+     if depth_values.size != 1:
+         raise UnsupportedOpError("OneHot depth input must be a scalar")
+     depth_value = int(depth_values[0])
+     if depth_value < 0:
+         raise UnsupportedOpError("OneHot depth must be non-negative")
+     values = evaluator.values[op.values].reshape(-1)
+     if values.size != 2:
+         raise UnsupportedOpError("OneHot values input must have 2 elements")
+     off_value, on_value = values[0], values[1]
+     if depth_value == 0:
+         evaluator.values[op.output] = np.full(
+             op.output_shape, off_value, dtype=values.dtype
+         )
+         return
+     axis = op.axis
+     rank = indices.ndim
+     if axis < 0:
+         axis += rank + 1
+     depth_range = np.arange(depth_value, dtype=np.int64)
+     new_shape = (1,) * axis + (depth_value,) + (1,) * (rank - axis)
+     targets = depth_range.reshape(new_shape)
+     adjusted = np.mod(indices, depth_value) if depth_value > 0 else indices
+     values_reshaped = np.reshape(
+         adjusted, indices.shape[:axis] + (1,) + indices.shape[axis:]
+     )
+     valid_mask = (indices >= -depth_value) & (indices < depth_value)
+     valid_mask = np.reshape(
+         valid_mask, indices.shape[:axis] + (1,) + indices.shape[axis:]
+     )
+     one_hot = (targets == values_reshaped) & valid_mask
+     output = np.where(one_hot, on_value, off_value).astype(values.dtype)
+     evaluator.values[op.output] = output
+
+
  @register_evaluator("Split")
  def _eval_split(evaluator: Evaluator, node: Node) -> None:
      op = lower_split(evaluator.graph, node)
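
Negative OneHot indices wrap modulo depth, which is what the np.mod above implements (illustrative values):

    import numpy as np

    indices = np.array([-1, 1])
    depth = 3
    print(np.mod(indices, depth))             # [2 1]
    print(np.eye(depth)[np.mod(indices, depth)])
    # [[0. 0. 1.]
    #  [0. 1. 0.]]
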
@@ -1550,6 +2197,39 @@ def _eval_arg_reduce(evaluator: Evaluator, node: Node) -> None:
      evaluator.values[op.output] = indices.astype(op.output_dtype.np_dtype)


+ @register_evaluator("TopK")
+ def _eval_topk(evaluator: Evaluator, node: Node) -> None:
+     op = lower_topk(evaluator.graph, node)
+     value = evaluator.values[op.input0]
+     moved = np.moveaxis(value, op.axis, -1)
+     axis_dim = moved.shape[-1]
+     flat = moved.reshape(-1, axis_dim)
+     values_out = np.empty((flat.shape[0], op.k), dtype=value.dtype)
+     indices_out = np.empty((flat.shape[0], op.k), dtype=np.int64)
+     for row_index in range(flat.shape[0]):
+         row = flat[row_index]
+         order = sorted(
+             range(axis_dim),
+             key=lambda idx: (
+                 -row[idx].item() if op.largest else row[idx].item(),
+                 idx,
+             ),
+         )
+         topk = order[: op.k]
+         indices_out[row_index] = topk
+         values_out[row_index] = row[topk]
+     values_out = values_out.reshape(moved.shape[:-1] + (op.k,))
+     indices_out = indices_out.reshape(moved.shape[:-1] + (op.k,))
+     values_out = np.moveaxis(values_out, -1, op.axis)
+     indices_out = np.moveaxis(indices_out, -1, op.axis)
+     evaluator.values[op.output_values] = values_out.astype(
+         op.output_values_dtype.np_dtype
+     )
+     evaluator.values[op.output_indices] = indices_out.astype(
+         op.output_indices_dtype.np_dtype
+     )
+
+
  def _eval_binary_unary(evaluator: Evaluator, node: Node) -> None:
      if node.op_type == "BitShift":
          if len(node.inputs) != 2 or len(node.outputs) != 1:
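
The TopK sort key yields descending values with ascending-index tie-breaks when largest=1 (illustrative row):

    row = [5.0, 7.0, 7.0, 1.0]
    order = sorted(range(len(row)), key=lambda i: (-row[i], i))
    print(order[:2])  # [1, 2]: equal values keep their original order
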
@@ -1671,9 +2351,8 @@ def _matmul_batch_broadcastable(
      left_padded = (1,) * (max_rank - len(left)) + left
      right_padded = (1,) * (max_rank - len(right)) + right
      for left_dim, right_dim in zip(left_padded, right_padded):
-         if left_dim == right_dim or left_dim == 1 or right_dim == 1:
-             continue
-         return False
+         if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+             return False
      return True

@@ -1916,7 +2595,9 @@ def _apply_attention(
      return output, key_total, value_total, qk_output


- def _apply_conv(spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None) -> np.ndarray:
+ def _apply_conv(
+     spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None
+ ) -> np.ndarray:
      output = np.zeros(
          (spec.batch, spec.out_channels, *spec.out_spatial),
          dtype=data.dtype,
@@ -1958,15 +2639,67 @@ def _apply_conv(spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray |
                                  valid = False
                                  break
                              in_index.append(in_dim)
-                         if not valid:
-                             continue
-                         acc += data[(n, ic_global, *in_index)] * weights[
-                             (oc_global, ic, *kernel_index)
-                         ]
+                         if valid:
+                             acc += data[(n, ic_global, *in_index)] * weights[
+                                 (oc_global, ic, *kernel_index)
+                             ]
                  output[(n, oc_global, *out_index)] = acc
      return output


+ def _apply_conv_transpose(
+     spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None
+ ) -> np.ndarray:
+     output = np.zeros(
+         (spec.batch, spec.out_channels, *spec.out_spatial), dtype=data.dtype
+     )
+     if bias is not None:
+         output += bias.reshape((1, spec.out_channels) + (1,) * spec.spatial_rank)
+     pad_begin = spec.pads[: spec.spatial_rank]
+     group_in_channels = spec.in_channels // spec.group
+     group_out_channels = spec.out_channels // spec.group
+     for n in range(spec.batch):
+         for g in range(spec.group):
+             oc_base = g * group_out_channels
+             ic_base = g * group_in_channels
+             for ic in range(group_in_channels):
+                 ic_global = ic_base + ic
+                 for in_index in np.ndindex(*spec.in_spatial):
+                     value = data[(n, ic_global, *in_index)]
+                     for oc in range(group_out_channels):
+                         oc_global = oc_base + oc
+                         for kernel_index in np.ndindex(*spec.kernel_shape):
+                             out_index = []
+                             valid = True
+                             for (
+                                 in_dim,
+                                 kernel_dim,
+                                 stride,
+                                 dilation,
+                                 pad,
+                                 out_size,
+                             ) in zip(
+                                 in_index,
+                                 kernel_index,
+                                 spec.strides,
+                                 spec.dilations,
+                                 pad_begin,
+                                 spec.out_spatial,
+                             ):
+                                 out_dim = (
+                                     in_dim * stride + kernel_dim * dilation - pad
+                                 )
+                                 if out_dim < 0 or out_dim >= out_size:
+                                     valid = False
+                                     break
+                                 out_index.append(out_dim)
+                             if valid:
+                                 output[(n, oc_global, *out_index)] += (
+                                     value * weights[(ic_global, oc, *kernel_index)]
+                                 )
+     return output
+
+
  def _apply_lrn(spec, data: np.ndarray) -> np.ndarray:
      output = np.empty_like(data)
      spatial_shape = spec.shape[2:]
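
The scatter loop in _apply_conv_transpose maps every input position to out = in*stride + k*dilation - pad_begin. A 1-D sanity check (made-up values) against the usual ConvTranspose output size (in-1)*stride + (k-1)*dilation + 1 - pads:

    import numpy as np

    x = np.array([1.0, 2.0])          # in = 2
    w = np.array([1.0, 1.0, 1.0])     # k = 3
    stride, dilation, pad = 2, 1, 0
    out = np.zeros((2 - 1) * stride + (3 - 1) * dilation + 1)  # size 5
    for i, xv in enumerate(x):
        for k, wv in enumerate(w):
            out[i * stride + k * dilation - pad] += xv * wv
    print(out)  # [1. 1. 3. 2. 2.]
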
@@ -2002,15 +2735,15 @@ def _apply_average_pool(op, data: np.ndarray) -> np.ndarray:
                      if ih < 0 or ih >= op.in_h:
                          if op.count_include_pad:
                              count += op.kernel_w
-                         continue
-                     for kw in range(op.kernel_w):
-                         iw = ow * op.stride_w + kw - op.pad_left
-                         if iw < 0 or iw >= op.in_w:
-                             if op.count_include_pad:
+                     else:
+                         for kw in range(op.kernel_w):
+                             iw = ow * op.stride_w + kw - op.pad_left
+                             if iw < 0 or iw >= op.in_w:
+                                 if op.count_include_pad:
+                                     count += 1
+                             else:
+                                 acc += data[n, c, ih, iw]
                                  count += 1
-                             continue
-                         acc += data[n, c, ih, iw]
-                         count += 1
                  output[n, c, oh, ow] = 0.0 if count == 0 else acc / float(count)
      return output

@@ -2059,25 +2792,30 @@ def _apply_maxpool(
                                  valid = False
                                  break
                              in_index.append(idx)
-                         if not valid:
-                             continue
-                         value = data[(n, c, *in_index)]
-                         if value > max_value or not has_value:
-                             max_value = value
-                             has_value = True
-                             if return_indices:
-                                 linear_index = n * spec.channels + c
-                                 if spec.storage_order == 0:
-                                     for idx, size in zip(in_index, spec.in_spatial):
-                                         linear_index = linear_index * size + idx
-                                 else:
-                                     spatial_index = 0
-                                     spatial_stride = 1
-                                     for idx, size in zip(in_index, spec.in_spatial):
-                                         spatial_index += idx * spatial_stride
-                                         spatial_stride *= size
-                                     linear_index = linear_index * spatial_stride + spatial_index
-                                 max_index = linear_index
+                         if valid:
+                             value = data[(n, c, *in_index)]
+                             if value > max_value or not has_value:
+                                 max_value = value
+                                 has_value = True
+                                 if return_indices:
+                                     linear_index = n * spec.channels + c
+                                     if spec.storage_order == 0:
+                                         for idx, size in zip(
+                                             in_index, spec.in_spatial
+                                         ):
+                                             linear_index = linear_index * size + idx
+                                     else:
+                                         spatial_index = 0
+                                         spatial_stride = 1
+                                         for idx, size in zip(
+                                             in_index, spec.in_spatial
+                                         ):
+                                             spatial_index += idx * spatial_stride
+                                             spatial_stride *= size
+                                         linear_index = (
+                                             linear_index * spatial_stride + spatial_index
+                                         )
+                                     max_index = linear_index
                      output[(n, c, *out_index)] = max_value
                      if return_indices and indices is not None:
                          indices[(n, c, *out_index)] = max_index
@@ -2162,8 +2900,12 @@ def _apply_lstm(
          beta_g = spec.activation_betas[act_offset + 1]
          beta_h = spec.activation_betas[act_offset + 2]
          for step in range(seq_length):
-             t_index = step if dir_kind == "forward" else seq_length - 1 - step
-             x_t = x[t_index]
+             if dir_kind == "forward":
+                 x_t = x[step]
+             else:
+                 t_indices = sequence_lens - 1 - step
+                 t_indices = np.clip(t_indices, 0, seq_length - 1)
+                 x_t = x[t_indices, np.arange(batch_size)]
              gates = x_t @ w_dir.T + h_prev @ r_dir.T + bias
              if spec.clip is not None and spec.clip > 0:
                  gates = np.clip(gates, -spec.clip, spec.clip)
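
The reverse-direction LSTM fix above gathers a per-batch time step from sequence_lens instead of using one shared index, so padded batch entries start from their own last valid step. Illustrative shapes (all values made up):

    import numpy as np

    seq_length, batch_size, input_size = 4, 2, 3
    x = np.arange(seq_length * batch_size * input_size).reshape(
        seq_length, batch_size, input_size
    )
    sequence_lens = np.array([4, 2])    # batch 1 is padded after step 2
    step = 0
    t_indices = np.clip(sequence_lens - 1 - step, 0, seq_length - 1)  # [3 1]
    x_t = x[t_indices, np.arange(batch_size)]  # batch 0 reads t=3, batch 1 reads t=1
    print(x_t.shape)  # (2, 3)
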