emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen was flagged as possibly problematic by the registry.

Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
(The hunks below appear to come from emx_onnx_cgen/runtime/evaluator.py, the +460 -42 entry above; they register reference evaluators for the newly supported ops.)

@@ -14,12 +14,16 @@ from ..lowering.batch_normalization import lower_batch_normalization
 from ..lowering.concat import lower_concat
 from ..lowering.constant_of_shape import lower_constant_of_shape
 from ..lowering.conv import resolve_conv_spec
+from ..lowering.conv_transpose import resolve_conv_transpose_spec
 from ..lowering.dropout import lower_dropout
 from ..lowering.cumsum import lower_cumsum
+from ..lowering.einsum import lower_einsum
 from ..lowering.flatten import lower_flatten
 from ..lowering.gemm import resolve_gemm_spec
 from ..lowering.logsoftmax import lower_logsoftmax
+from ..lowering.hardmax import lower_hardmax
 from ..lowering.lp_normalization import lower_lp_normalization
+from ..lowering.lp_pool import lower_lp_pool
 from ..lowering.grid_sample import lower_grid_sample
 from ..lowering.instance_normalization import lower_instance_normalization
 from ..lowering.group_normalization import lower_group_normalization
@@ -27,17 +31,21 @@ from ..lowering.layer_normalization import lower_layer_normalization
 from ..lowering.mean_variance_normalization import (
     lower_mean_variance_normalization,
 )
+from ..lowering.global_max_pool import lower_global_max_pool
 from ..lowering.negative_log_likelihood_loss import (
     lower_negative_log_likelihood_loss,
 )
+from ..lowering.nonzero import lower_nonzero
 from ..lowering.pad import lower_pad
 from ..lowering.expand import lower_expand
 from ..lowering.range import lower_range
+from ..lowering.one_hot import lower_onehot
 from ..lowering.split import lower_split
 from ..lowering.softmax_cross_entropy_loss import (
     lower_softmax_cross_entropy_loss,
 )
 from ..lowering.arg_reduce import lower_arg_reduce
+from ..lowering.topk import lower_topk
 from ..lowering.lstm import ACTIVATION_KIND_BY_NAME, resolve_lstm_spec
 from ..lowering.lrn import resolve_lrn_spec
 from ..lowering.matmul import lower_matmul
@@ -49,6 +57,7 @@ from ..lowering.reduce import (
     resolve_reduce_axes,
 )
 from ..lowering.reshape import lower_reshape
+from ..lowering.scatter_nd import lower_scatternd
 from ..lowering.slice import _normalize_slices
 from ..lowering.shape import lower_shape
 from ..lowering.size import lower_size
@@ -58,6 +67,7 @@ from ..lowering.squeeze import lower_squeeze
 from ..lowering.transpose import lower_transpose
 from ..lowering.unsqueeze import lower_unsqueeze
 from ..lowering.where import lower_where
+from ..lowering.quantize_linear import resolve_quantize_spec
 from ..lowering.variadic import BINARY_ONLY_OPS, VARIADIC_OP_FUNCTIONS
 from ..lowering.registry import resolve_dispatch
 from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
@@ -133,6 +143,21 @@ def _eval_matmul(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[node.outputs[0]] = _apply_matmul(left, right)


+@register_evaluator("Einsum")
+def _eval_einsum(evaluator: Evaluator, node: Node) -> None:
+    lower_einsum(evaluator.graph, node)
+    equation_value = node.attrs.get("equation")
+    if equation_value is None:
+        raise UnsupportedOpError("Einsum equation attribute is required")
+    equation = (
+        equation_value.decode()
+        if isinstance(equation_value, (bytes, bytearray))
+        else str(equation_value)
+    )
+    inputs = [evaluator.values[name] for name in node.inputs]
+    evaluator.values[node.outputs[0]] = np.einsum(equation, *inputs)
+
+
 @register_evaluator("Clip")
 def _eval_clip(evaluator: Evaluator, node: Node) -> None:
     if not node.inputs or len(node.outputs) != 1:
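Review note: after lower_einsum validates the node, the reference evaluation above defers entirely to np.einsum, so its semantics are exactly NumPy's. A minimal sketch of that equivalence (shapes and equation chosen for illustration):

import numpy as np

# Batched matrix multiply written as an einsum equation, the same string
# form the evaluator pulls from the node's "equation" attribute.
a = np.random.rand(2, 3, 4)
b = np.random.rand(2, 4, 5)
out = np.einsum("bij,bjk->bik", a, b)
assert out.shape == (2, 3, 5)
assert np.allclose(out, a @ b)  # einsum matches the batched matmul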
@@ -242,10 +267,11 @@ def _eval_pad(evaluator: Evaluator, node: Node) -> None:
     pads_begin = np.zeros(rank, dtype=np.int64)
     pads_end = np.zeros(rank, dtype=np.int64)
     for axis, pad_index in enumerate(op.pads_axis_map):
-        if pad_index is None:
-            continue
-        pads_begin[axis] = int(pads_values[pad_index])
-        pads_end[axis] = int(pads_values[pad_index + axis_count])
+        if pad_index is not None:
+            pads_begin[axis] = int(pads_values[pad_index])
+            pads_end[axis] = int(
+                pads_values[pad_index + axis_count]
+            )
     pad_width = tuple(
         (int(pads_begin[index]), int(pads_end[index]))
         for index in range(rank)
@@ -270,6 +296,53 @@ def _eval_pad(evaluator: Evaluator, node: Node) -> None:
     )


+@register_evaluator("ScatterND")
+def _eval_scatternd(evaluator: Evaluator, node: Node) -> None:
+    op = lower_scatternd(evaluator.graph, node)
+    data = evaluator.values[op.data]
+    indices = evaluator.values[op.indices]
+    updates = evaluator.values[op.updates]
+    output = np.array(data, copy=True)
+    index_depth = op.indices_shape[-1]
+    update_indices_shape = op.indices_shape[:-1]
+    update_count = int(np.prod(update_indices_shape)) if update_indices_shape else 1
+    flat_indices = indices.astype(np.int64, copy=False).reshape(
+        update_count, index_depth
+    )
+    tail_shape = op.data_shape[index_depth:]
+    updates_reshaped = updates.reshape((update_count,) + tail_shape)
+    for index, index_values in enumerate(flat_indices):
+        output_index: list[int | slice] = []
+        for axis, value in enumerate(index_values):
+            axis_size = op.data_shape[axis]
+            idx = int(value)
+            if idx < 0:
+                idx += axis_size
+            if idx < 0 or idx >= axis_size:
+                raise UnsupportedOpError(
+                    "ScatterND indices must be within data bounds"
+                )
+            output_index.append(idx)
+        output_index.extend([slice(None)] * len(tail_shape))
+        target = tuple(output_index)
+        update_value = updates_reshaped[index]
+        if op.reduction == "none":
+            output[target] = update_value
+        elif op.reduction == "add":
+            output[target] = output[target] + update_value
+        elif op.reduction == "mul":
+            output[target] = output[target] * update_value
+        elif op.reduction == "min":
+            output[target] = np.minimum(output[target], update_value)
+        elif op.reduction == "max":
+            output[target] = np.maximum(output[target], update_value)
+        else:
+            raise UnsupportedOpError(
+                f"Unsupported ScatterND reduction {op.reduction}"
+            )
+    evaluator.values[op.output] = output
+
+
 @register_evaluator("Celu")
 def _eval_celu(evaluator: Evaluator, node: Node) -> None:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
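Review note: the loop above mirrors ONNX ScatterND semantics: each row of indices addresses a slice of data, and the matching updates slice replaces it or reduces into it. A small sketch of the reduction="add" case, assuming nothing beyond NumPy:

import numpy as np

data = np.zeros((4, 3))
indices = np.array([[1], [3], [1]])     # index_depth == 1, each row picks data[i, :]
updates = np.array([[1., 1., 1.],
                    [2., 2., 2.],
                    [4., 4., 4.]])
out = data.copy()
for row, upd in zip(indices, updates):
    out[tuple(row)] += upd              # reduction="add": duplicate rows accumulate
# out[1] == [5., 5., 5.] because rows 0 and 2 both target index 1
assert out[1].tolist() == [5.0, 5.0, 5.0]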
@@ -786,6 +859,40 @@ def _eval_eye_like(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[node.outputs[0]] = output


+@register_evaluator("Trilu")
+def _eval_trilu(evaluator: Evaluator, node: Node) -> None:
+    if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Trilu must have 1 or 2 inputs and 1 output")
+    value = evaluator.values[node.inputs[0]]
+    if value.ndim < 2:
+        raise UnsupportedOpError("Trilu expects input rank >= 2")
+    output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+    input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+    if output_dtype != input_dtype:
+        raise UnsupportedOpError(
+            "Trilu expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    k = 0
+    if len(node.inputs) == 2 and node.inputs[1]:
+        k_value = np.array(evaluator.values[node.inputs[1]], dtype=np.int64)
+        if k_value.size != 1:
+            raise UnsupportedOpError("Trilu k input must be scalar")
+        k = int(k_value.reshape(-1)[0])
+    upper_attr = node.attrs.get("upper", 1)
+    upper = bool(int(upper_attr))
+    rows, cols = value.shape[-2], value.shape[-1]
+    batch_shape = value.shape[:-2]
+    batch_size = int(np.prod(batch_shape)) if batch_shape else 1
+    view = value.reshape(batch_size, rows, cols)
+    if upper:
+        mask = np.triu(np.ones((rows, cols), dtype=bool), k=k)
+    else:
+        mask = np.tril(np.ones((rows, cols), dtype=bool), k=k)
+    output = np.where(mask, view, np.zeros_like(view))
+    evaluator.values[node.outputs[0]] = output.reshape(value.shape)
+
+
 @register_evaluator("Tile")
 def _eval_tile(evaluator: Evaluator, node: Node) -> None:
     if len(node.inputs) != 2 or len(node.outputs) != 1:
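Review note: the mask-based Trilu matches NumPy's own triu/tril on each 2-D slice; a quick check of that equivalence (values illustrative only):

import numpy as np

x = np.arange(1, 10, dtype=np.float32).reshape(3, 3)
k = 1
mask = np.triu(np.ones((3, 3), dtype=bool), k=k)
# Masking with zeros is the same as taking the upper triangle directly.
assert np.array_equal(np.where(mask, x, 0), np.triu(x, k=k))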
@@ -922,6 +1029,73 @@ def _eval_gather(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[node.outputs[0]] = np.take(data, indices, axis=axis)


+@register_evaluator("GatherND")
+def _eval_gather_nd(evaluator: Evaluator, node: Node) -> None:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("GatherND must have 2 inputs and 1 output")
+    data = evaluator.values[node.inputs[0]]
+    indices = evaluator.values[node.inputs[1]]
+    if indices.dtype.type not in {np.int32, np.int64}:
+        raise UnsupportedOpError(
+            f"GatherND indices must be int32 or int64, got {indices.dtype}"
+        )
+    if indices.ndim < 1:
+        raise UnsupportedOpError("GatherND indices must have rank >= 1")
+    batch_dims = int(node.attrs.get("batch_dims", 0))
+    if batch_dims < 0:
+        raise UnsupportedOpError(
+            f"GatherND batch_dims must be >= 0, got {batch_dims}"
+        )
+    if batch_dims > indices.ndim - 1:
+        raise UnsupportedOpError(
+            "GatherND batch_dims must be <= indices rank - 1, "
+            f"got {batch_dims} vs {indices.ndim - 1}"
+        )
+    if batch_dims > data.ndim:
+        raise UnsupportedOpError(
+            "GatherND batch_dims must be <= data rank, "
+            f"got {batch_dims} vs {data.ndim}"
+        )
+    if tuple(data.shape[:batch_dims]) != tuple(indices.shape[:batch_dims]):
+        raise UnsupportedOpError(
+            "GatherND batch_dims must match on data/indices, "
+            f"got {data.shape} vs {indices.shape}"
+        )
+    index_depth = indices.shape[-1]
+    if index_depth <= 0:
+        raise UnsupportedOpError(
+            "GatherND indices final dimension must be >= 1"
+        )
+    if index_depth > data.ndim - batch_dims:
+        raise UnsupportedOpError(
+            "GatherND indices final dimension must be <= data rank - "
+            f"batch_dims, got {index_depth} vs {data.ndim - batch_dims}"
+        )
+    tail_shape = data.shape[batch_dims + index_depth :]
+    output_shape = indices.shape[:-1] + tail_shape
+    output = np.empty(output_shape, dtype=data.dtype)
+    indices_prefix_shape = indices.shape[:-1]
+    prefix_iter = (
+        np.ndindex(*indices_prefix_shape) if indices_prefix_shape else [()]
+    )
+    for prefix in prefix_iter:
+        raw_index = indices[prefix]
+        if index_depth == 1:
+            index_values = [int(np.asarray(raw_index).item())]
+        else:
+            index_values = [int(value) for value in raw_index]
+        for dim_index, value in enumerate(index_values):
+            if value < 0:
+                index_values[dim_index] = value + data.shape[
+                    batch_dims + dim_index
+                ]
+        data_index = list(prefix[:batch_dims]) + index_values
+        data_index.extend([slice(None)] * len(tail_shape))
+        output_index = prefix + (slice(None),) * len(tail_shape)
+        output[output_index] = data[tuple(data_index)]
+    evaluator.values[node.outputs[0]] = output
+
+
 @register_evaluator("Slice")
 def _eval_slice(evaluator: Evaluator, node: Node) -> None:
     input_value = evaluator.values[node.inputs[0]]
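Review note on the batch_dims handling: with batch_dims=1 the leading axis of data and indices is iterated in lockstep, and only the remaining axes are indexed. A small worked case, assuming plain NumPy:

import numpy as np

data = np.array([[[0, 1], [2, 3]],
                 [[4, 5], [6, 7]]])        # shape (2, 2, 2)
indices = np.array([[1], [0]])             # shape (2, 1), batch_dims = 1
# For each batch b, index_depth 1 picks data[b, indices[b, 0], :]:
out = np.stack([data[b, indices[b, 0]] for b in range(2)])
assert np.array_equal(out, np.array([[2, 3], [4, 5]]))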
@@ -1101,6 +1275,28 @@ def _eval_conv(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[node.outputs[0]] = _apply_conv(spec, data, weights, bias)


+@register_evaluator("ConvTranspose")
+def _eval_conv_transpose(evaluator: Evaluator, node: Node) -> None:
+    op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+    output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            f"{node.op_type} expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "ConvTranspose supports float16, float, and double inputs only"
+        )
+    spec = resolve_conv_transpose_spec(evaluator.graph, node)
+    data = evaluator.values[node.inputs[0]]
+    weights = evaluator.values[node.inputs[1]]
+    bias = evaluator.values[node.inputs[2]] if len(node.inputs) > 2 else None
+    evaluator.values[node.outputs[0]] = _apply_conv_transpose(
+        spec, data, weights, bias
+    )
+
+
 @register_evaluator("BatchNormalization")
 def _eval_batch_norm(evaluator: Evaluator, node: Node) -> None:
     op = lower_batch_normalization(evaluator.graph, node)
@@ -1133,6 +1329,59 @@ def _eval_lp_normalization(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[op.output] = data / denom


+@register_evaluator("LpPool")
+def _eval_lp_pool(evaluator: Evaluator, node: Node) -> None:
+    op = lower_lp_pool(evaluator.graph, node)
+    data = evaluator.values[op.input0]
+    output = np.zeros(
+        (op.batch, op.channels, op.out_h, op.out_w), dtype=data.dtype
+    )
+    for n in range(op.batch):
+        for c in range(op.channels):
+            for out_h in range(op.out_h):
+                for out_w in range(op.out_w):
+                    h_start = out_h * op.stride_h - op.pad_top
+                    w_start = out_w * op.stride_w - op.pad_left
+                    acc = 0.0
+                    for kh in range(op.kernel_h):
+                        for kw in range(op.kernel_w):
+                            in_h = h_start + kh
+                            in_w = w_start + kw
+                            if (
+                                0 <= in_h < op.in_h
+                                and 0 <= in_w < op.in_w
+                            ):
+                                value = data[(n, c, in_h, in_w)]
+                                acc += abs(value) ** op.p
+                    output[(n, c, out_h, out_w)] = acc ** (1.0 / op.p)
+    evaluator.values[op.output] = output
+
+
+@register_evaluator("QuantizeLinear")
+def _eval_quantize_linear(evaluator: Evaluator, node: Node) -> None:
+    spec = resolve_quantize_spec(evaluator.graph, node)
+    data = evaluator.values[node.inputs[0]]
+    scale = evaluator.values[node.inputs[1]]
+    zero_point_name = optional_name(node.inputs, 2)
+    if zero_point_name is None:
+        zero_point = 0
+    else:
+        zero_point = evaluator.values[zero_point_name]
+    if spec.axis is None:
+        scaled = data / scale + zero_point
+    else:
+        shape = [1] * data.ndim
+        shape[spec.axis] = scale.shape[0]
+        scaled = data / scale.reshape(shape) + np.asarray(zero_point).reshape(
+            shape
+        )
+    rounded = np.rint(scaled)
+    info = np.iinfo(spec.output_dtype.np_dtype)
+    clipped = np.clip(rounded, info.min, info.max)
+    evaluator.values[node.outputs[0]] = clipped.astype(
+        spec.output_dtype.np_dtype, copy=False
+    )
+
 @register_evaluator("InstanceNormalization")
 def _eval_instance_normalization(evaluator: Evaluator, node: Node) -> None:
     op = lower_instance_normalization(evaluator.graph, node)
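Review note: LpPool computes (Σ|x|^p)^(1/p) over each window, and QuantizeLinear computes rint(x/scale + zero_point) then saturates, which matches the ONNX formula saturate(round(x/scale) + zero_point) whenever zero_point is an integer. A worked uint8 example (values chosen for illustration):

import numpy as np

x = np.array([-1.0, 0.0, 2.5, 300.0], dtype=np.float32)
scale, zero_point = 2.0, 128
info = np.iinfo(np.uint8)
y = np.clip(np.rint(x / scale + zero_point), info.min, info.max).astype(np.uint8)
# x/scale = [-0.5, 0, 1.25, 150]; after the shift rint gives [128, 128, 129, 278],
# and 278 saturates to the uint8 max of 255:
assert y.tolist() == [128, 128, 129, 255]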
@@ -1284,6 +1533,18 @@ def _eval_maxpool(evaluator: Evaluator, node: Node) -> None:
         evaluator.values[indices_output] = indices


+@register_evaluator("GlobalMaxPool")
+def _eval_global_max_pool(evaluator: Evaluator, node: Node) -> None:
+    op = lower_global_max_pool(evaluator.graph, node)
+    value = evaluator.values[node.inputs[0]]
+    if not op.axes:
+        evaluator.values[node.outputs[0]] = value.copy()
+        return
+    evaluator.values[node.outputs[0]] = np.max(
+        value, axis=op.axes, keepdims=op.keepdims
+    )
+
+
 @register_evaluator("Softmax")
 def _eval_softmax(evaluator: Evaluator, node: Node) -> None:
     op = lower_softmax(evaluator.graph, node)
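Review note: GlobalMaxPool reduces every spatial axis while keeping them as size-1 dims, so for an NCHW tensor the lowering's op.axes is presumably (2, 3) with keepdims set:

import numpy as np

x = np.random.rand(1, 3, 8, 8)              # NCHW input
y = np.max(x, axis=(2, 3), keepdims=True)   # one max per channel
assert y.shape == (1, 3, 1, 1)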
@@ -1298,6 +1559,19 @@ def _eval_logsoftmax(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[node.outputs[0]] = _apply_logsoftmax(value, op.axis)


+@register_evaluator("Hardmax")
+def _eval_hardmax(evaluator: Evaluator, node: Node) -> None:
+    op = lower_hardmax(evaluator.graph, node)
+    value = evaluator.values[node.inputs[0]]
+    max_values = np.max(value, axis=op.axis, keepdims=True)
+    is_max = value == max_values
+    max_index = np.argmax(is_max, axis=op.axis)
+    output = np.zeros_like(value)
+    ones = np.array(1.0, dtype=value.dtype)
+    np.put_along_axis(output, np.expand_dims(max_index, axis=op.axis), ones, axis=op.axis)
+    evaluator.values[node.outputs[0]] = output
+
+
 @register_evaluator("NegativeLogLikelihoodLoss")
 def _eval_negative_log_likelihood_loss(
     evaluator: Evaluator, node: Node
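Review note: np.argmax over the boolean is_max array returns the first maximum along the axis, so ties resolve to the lowest index, as the ONNX Hardmax spec requires. For example:

import numpy as np

x = np.array([[3.0, 1.0, 3.0]])             # tie between columns 0 and 2
is_max = x == np.max(x, axis=1, keepdims=True)
first = np.argmax(is_max, axis=1)           # argmax on bools -> first True
out = np.zeros_like(x)
np.put_along_axis(out, first[:, None], 1.0, axis=1)
assert out.tolist() == [[1.0, 0.0, 0.0]]    # only the first max is set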
@@ -1409,6 +1683,16 @@ def _eval_size(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[op.output] = np.array(op.value, dtype=np.int64)


+@register_evaluator("NonZero")
+def _eval_nonzero(evaluator: Evaluator, node: Node) -> None:
+    op = lower_nonzero(evaluator.graph, node)
+    values = evaluator.values[op.input0]
+    indices = np.nonzero(values)
+    evaluator.values[op.output] = np.stack(indices, axis=0).astype(
+        np.int64, copy=False
+    )
+
+
 @register_evaluator("Expand")
 def _eval_expand(evaluator: Evaluator, node: Node) -> None:
     op = lower_expand(evaluator.graph, node)
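Review note: np.nonzero returns one index array per axis; stacking them on axis 0 yields the (rank, num_nonzero) int64 layout ONNX NonZero requires:

import numpy as np

x = np.array([[1, 0], [0, 2]])
out = np.stack(np.nonzero(x), axis=0)
# Coordinates (0, 0) and (1, 1), with one row per axis of x.
assert out.tolist() == [[0, 1], [0, 1]]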
@@ -1428,6 +1712,45 @@ def _eval_range(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[op.output] = output


+@register_evaluator("OneHot")
+def _eval_onehot(evaluator: Evaluator, node: Node) -> None:
+    op = lower_onehot(evaluator.graph, node)
+    indices = evaluator.values[op.indices].astype(np.int64, copy=False)
+    depth_values = evaluator.values[op.depth].reshape(-1)
+    if depth_values.size != 1:
+        raise UnsupportedOpError("OneHot depth input must be a scalar")
+    depth_value = int(depth_values[0])
+    if depth_value < 0:
+        raise UnsupportedOpError("OneHot depth must be non-negative")
+    values = evaluator.values[op.values].reshape(-1)
+    if values.size != 2:
+        raise UnsupportedOpError("OneHot values input must have 2 elements")
+    off_value, on_value = values[0], values[1]
+    if depth_value == 0:
+        evaluator.values[op.output] = np.full(
+            op.output_shape, off_value, dtype=values.dtype
+        )
+        return
+    axis = op.axis
+    rank = indices.ndim
+    if axis < 0:
+        axis += rank + 1
+    depth_range = np.arange(depth_value, dtype=np.int64)
+    new_shape = (1,) * axis + (depth_value,) + (1,) * (rank - axis)
+    targets = depth_range.reshape(new_shape)
+    adjusted = np.mod(indices, depth_value) if depth_value > 0 else indices
+    values_reshaped = np.reshape(
+        adjusted, indices.shape[:axis] + (1,) + indices.shape[axis:]
+    )
+    valid_mask = (indices >= -depth_value) & (indices < depth_value)
+    valid_mask = np.reshape(
+        valid_mask, indices.shape[:axis] + (1,) + indices.shape[axis:]
+    )
+    one_hot = (targets == values_reshaped) & valid_mask
+    output = np.where(one_hot, on_value, off_value).astype(values.dtype)
+    evaluator.values[op.output] = output
+
+
 @register_evaluator("Split")
 def _eval_split(evaluator: Evaluator, node: Node) -> None:
     op = lower_split(evaluator.graph, node)
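Review note: the valid_mask/np.mod pair implements the spec's index range of [-depth, depth-1]: negative indices wrap around (index + depth), and anything outside the range yields a row of off-values. A small sketch:

import numpy as np

indices = np.array([1, -1, 5])              # depth 3: -1 wraps to 2, 5 is out of range
depth = 3
wrapped = np.mod(indices, depth)            # [1, 2, 2]
valid = (indices >= -depth) & (indices < depth)
out = np.where(valid[:, None],
               np.arange(depth) == wrapped[:, None], False).astype(np.float32)
# Rows: one-hot at 1, one-hot at 2, all zeros for the out-of-range 5.
assert out.tolist() == [[0, 1, 0], [0, 0, 1], [0, 0, 0]]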
@@ -1550,6 +1873,39 @@ def _eval_arg_reduce(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[op.output] = indices.astype(op.output_dtype.np_dtype)


+@register_evaluator("TopK")
+def _eval_topk(evaluator: Evaluator, node: Node) -> None:
+    op = lower_topk(evaluator.graph, node)
+    value = evaluator.values[op.input0]
+    moved = np.moveaxis(value, op.axis, -1)
+    axis_dim = moved.shape[-1]
+    flat = moved.reshape(-1, axis_dim)
+    values_out = np.empty((flat.shape[0], op.k), dtype=value.dtype)
+    indices_out = np.empty((flat.shape[0], op.k), dtype=np.int64)
+    for row_index in range(flat.shape[0]):
+        row = flat[row_index]
+        order = sorted(
+            range(axis_dim),
+            key=lambda idx: (
+                -row[idx].item() if op.largest else row[idx].item(),
+                idx,
+            ),
+        )
+        topk = order[: op.k]
+        indices_out[row_index] = topk
+        values_out[row_index] = row[topk]
+    values_out = values_out.reshape(moved.shape[:-1] + (op.k,))
+    indices_out = indices_out.reshape(moved.shape[:-1] + (op.k,))
+    values_out = np.moveaxis(values_out, -1, op.axis)
+    indices_out = np.moveaxis(indices_out, -1, op.axis)
+    evaluator.values[op.output_values] = values_out.astype(
+        op.output_values_dtype.np_dtype
+    )
+    evaluator.values[op.output_indices] = indices_out.astype(
+        op.output_indices_dtype.np_dtype
+    )
+
+
 def _eval_binary_unary(evaluator: Evaluator, node: Node) -> None:
     if node.op_type == "BitShift":
         if len(node.inputs) != 2 or len(node.outputs) != 1:
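Review note: the sort key (-value, idx) (or (value, idx) for smallest) makes the selection deterministic: equal values come back in ascending index order, the tie-break onnxruntime uses. For instance:

import numpy as np

row = np.array([2.0, 5.0, 5.0, 1.0])
order = sorted(range(len(row)), key=lambda i: (-row[i], i))
assert order[:2] == [1, 2]                  # the tied 5.0s keep index order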
@@ -1671,9 +2027,8 @@ def _matmul_batch_broadcastable(
     left_padded = (1,) * (max_rank - len(left)) + left
     right_padded = (1,) * (max_rank - len(right)) + right
     for left_dim, right_dim in zip(left_padded, right_padded):
-        if left_dim == right_dim or left_dim == 1 or right_dim == 1:
-            continue
-        return False
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            return False
     return True

@@ -1916,7 +2271,9 @@ def _apply_attention(
     return output, key_total, value_total, qk_output


-def _apply_conv(spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None) -> np.ndarray:
+def _apply_conv(
+    spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None
+) -> np.ndarray:
     output = np.zeros(
         (spec.batch, spec.out_channels, *spec.out_spatial),
         dtype=data.dtype,
@@ -1958,15 +2315,67 @@ def _apply_conv(spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray |
                             valid = False
                             break
                         in_index.append(in_dim)
-                    if not valid:
-                        continue
-                    acc += data[(n, ic_global, *in_index)] * weights[
-                        (oc_global, ic, *kernel_index)
-                    ]
+                    if valid:
+                        acc += data[(n, ic_global, *in_index)] * weights[
+                            (oc_global, ic, *kernel_index)
+                        ]
             output[(n, oc_global, *out_index)] = acc
     return output


+def _apply_conv_transpose(
+    spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None
+) -> np.ndarray:
+    output = np.zeros(
+        (spec.batch, spec.out_channels, *spec.out_spatial), dtype=data.dtype
+    )
+    if bias is not None:
+        output += bias.reshape((1, spec.out_channels) + (1,) * spec.spatial_rank)
+    pad_begin = spec.pads[: spec.spatial_rank]
+    group_in_channels = spec.in_channels // spec.group
+    group_out_channels = spec.out_channels // spec.group
+    for n in range(spec.batch):
+        for g in range(spec.group):
+            oc_base = g * group_out_channels
+            ic_base = g * group_in_channels
+            for ic in range(group_in_channels):
+                ic_global = ic_base + ic
+                for in_index in np.ndindex(*spec.in_spatial):
+                    value = data[(n, ic_global, *in_index)]
+                    for oc in range(group_out_channels):
+                        oc_global = oc_base + oc
+                        for kernel_index in np.ndindex(*spec.kernel_shape):
+                            out_index = []
+                            valid = True
+                            for (
+                                in_dim,
+                                kernel_dim,
+                                stride,
+                                dilation,
+                                pad,
+                                out_size,
+                            ) in zip(
+                                in_index,
+                                kernel_index,
+                                spec.strides,
+                                spec.dilations,
+                                pad_begin,
+                                spec.out_spatial,
+                            ):
+                                out_dim = (
+                                    in_dim * stride + kernel_dim * dilation - pad
+                                )
+                                if out_dim < 0 or out_dim >= out_size:
+                                    valid = False
+                                    break
+                                out_index.append(out_dim)
+                            if valid:
+                                output[(n, oc_global, *out_index)] += (
+                                    value * weights[(ic_global, oc, *kernel_index)]
+                                )
+    return output
+
+
 def _apply_lrn(spec, data: np.ndarray) -> np.ndarray:
     output = np.empty_like(data)
     spatial_shape = spec.shape[2:]
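Review note: the scatter loop in _apply_conv_transpose inverts the usual convolution gather: each input element contributes to output position in*stride + k*dilation - pad_begin, and only in-bounds contributions accumulate. A 1-D sketch of that index math (a hypothetical standalone helper, not the package's API):

import numpy as np

def conv_transpose_1d(x, w, stride=2, pad=0):
    # ConvTranspose output size (no output_padding): (in - 1) * stride + kernel - 2 * pad
    out = np.zeros((len(x) - 1) * stride + len(w) - 2 * pad)
    for i, xi in enumerate(x):
        for k, wk in enumerate(w):
            o = i * stride + k - pad
            if 0 <= o < len(out):
                out[o] += xi * wk   # scatter-accumulate, as in the nested loops above
    return out

out = conv_transpose_1d(np.array([1.0, 2.0]), np.array([1.0, 1.0, 1.0]))
assert out.tolist() == [1.0, 1.0, 3.0, 2.0, 2.0]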
@@ -2002,15 +2411,15 @@ def _apply_average_pool(op, data: np.ndarray) -> np.ndarray:
                     if ih < 0 or ih >= op.in_h:
                         if op.count_include_pad:
                             count += op.kernel_w
-                        continue
-                    for kw in range(op.kernel_w):
-                        iw = ow * op.stride_w + kw - op.pad_left
-                        if iw < 0 or iw >= op.in_w:
-                            if op.count_include_pad:
+                    else:
+                        for kw in range(op.kernel_w):
+                            iw = ow * op.stride_w + kw - op.pad_left
+                            if iw < 0 or iw >= op.in_w:
+                                if op.count_include_pad:
+                                    count += 1
+                            else:
+                                acc += data[n, c, ih, iw]
                                 count += 1
-                            continue
-                        acc += data[n, c, ih, iw]
-                        count += 1
                 output[n, c, oh, ow] = 0.0 if count == 0 else acc / float(count)
     return output

@@ -2059,25 +2468,30 @@ def _apply_maxpool(
                             valid = False
                             break
                         in_index.append(idx)
-                    if not valid:
-                        continue
-                    value = data[(n, c, *in_index)]
-                    if value > max_value or not has_value:
-                        max_value = value
-                        has_value = True
-                        if return_indices:
-                            linear_index = n * spec.channels + c
-                            if spec.storage_order == 0:
-                                for idx, size in zip(in_index, spec.in_spatial):
-                                    linear_index = linear_index * size + idx
-                            else:
-                                spatial_index = 0
-                                spatial_stride = 1
-                                for idx, size in zip(in_index, spec.in_spatial):
-                                    spatial_index += idx * spatial_stride
-                                    spatial_stride *= size
-                                linear_index = linear_index * spatial_stride + spatial_index
-                            max_index = linear_index
+                    if valid:
+                        value = data[(n, c, *in_index)]
+                        if value > max_value or not has_value:
+                            max_value = value
+                            has_value = True
+                            if return_indices:
+                                linear_index = n * spec.channels + c
+                                if spec.storage_order == 0:
+                                    for idx, size in zip(
+                                        in_index, spec.in_spatial
+                                    ):
+                                        linear_index = linear_index * size + idx
+                                else:
+                                    spatial_index = 0
+                                    spatial_stride = 1
+                                    for idx, size in zip(
+                                        in_index, spec.in_spatial
+                                    ):
+                                        spatial_index += idx * spatial_stride
+                                        spatial_stride *= size
+                                    linear_index = (
+                                        linear_index * spatial_stride + spatial_index
+                                    )
+                                max_index = linear_index
                 output[(n, c, *out_index)] = max_value
                 if return_indices and indices is not None:
                     indices[(n, c, *out_index)] = max_index
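Review note: the storage_order == 1 branch emits column-major linear indices within each (n, c) plane, while the default branch is the usual row-major C order. An illustration of the two formulas for a 2-D spatial index (mirrors the accumulation in the loop above):

import numpy as np

in_spatial = (4, 5)
in_index = (2, 3)                           # (h, w) position of the max
row_major = np.ravel_multi_index(in_index, in_spatial)             # h * 5 + w = 13
col_major = np.ravel_multi_index(in_index, in_spatial, order="F")  # h + w * 4 = 14
assert (row_major, col_major) == (13, 14)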
@@ -2162,8 +2576,12 @@ def _apply_lstm(
             beta_g = spec.activation_betas[act_offset + 1]
             beta_h = spec.activation_betas[act_offset + 2]
             for step in range(seq_length):
-                t_index = step if dir_kind == "forward" else seq_length - 1 - step
-                x_t = x[t_index]
+                if dir_kind == "forward":
+                    x_t = x[step]
+                else:
+                    t_indices = sequence_lens - 1 - step
+                    t_indices = np.clip(t_indices, 0, seq_length - 1)
+                    x_t = x[t_indices, np.arange(batch_size)]
                 gates = x_t @ w_dir.T + h_prev @ r_dir.T + bias
                 if spec.clip is not None and spec.clip > 0:
                     gates = np.clip(gates, -spec.clip, spec.clip)
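Review note: this change makes the reverse LSTM direction respect per-sequence lengths: instead of one shared time index, each batch element b reads x[sequence_lens[b] - 1 - step, b], clipped into range. A sketch of the gather with illustrative shapes:

import numpy as np

seq_length, batch_size, input_size = 4, 2, 1
x = np.arange(seq_length * batch_size).reshape(seq_length, batch_size, input_size)
sequence_lens = np.array([4, 2])            # batch 1 only has 2 valid steps
step = 0
t = np.clip(sequence_lens - 1 - step, 0, seq_length - 1)   # [3, 1]
x_t = x[t, np.arange(batch_size)]           # batch 0 starts at t=3, batch 1 at t=1
assert x_t.tolist() == [[6], [3]]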