emx_onnx_cgen-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
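
Of the files above, the largest addition is emx_onnx_cgen/runtime/evaluator.py, a NumPy-based reference evaluator for the compiler's lowered graphs; its content is shown in the diff below. As orientation only, here is a minimal usage sketch. It assumes a Graph object obtained from the package's ONNX importer (emx_onnx_cgen/onnx_import.py); the importer's exact entry point is not visible in this diff, so it is left abstract rather than guessed.

    # Minimal sketch, not a documented API: `graph` is assumed to be an
    # emx_onnx_cgen.ir.model.Graph, e.g. built by the package's ONNX importer
    # (emx_onnx_cgen/onnx_import.py).
    import numpy as np

    from emx_onnx_cgen.runtime.evaluator import Evaluator

    def run_reference(graph, feeds: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        # Evaluator.run() seeds initializer values, walks graph.nodes in order,
        # dispatches each ONNX op to a registered NumPy handler, and returns
        # the graph outputs keyed by name (see the diff below).
        return Evaluator(graph).run(feeds)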
emx_onnx_cgen/runtime/evaluator.py
@@ -0,0 +1,2206 @@
+ from __future__ import annotations
+
+ from collections.abc import Callable, Mapping
+ import math
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..lowering.attention import resolve_attention_spec
+ from ..lowering.average_pool import lower_average_pool, lower_global_average_pool
+ from ..lowering.batch_normalization import lower_batch_normalization
+ from ..lowering.concat import lower_concat
+ from ..lowering.constant_of_shape import lower_constant_of_shape
+ from ..lowering.conv import resolve_conv_spec
+ from ..lowering.dropout import lower_dropout
+ from ..lowering.cumsum import lower_cumsum
+ from ..lowering.flatten import lower_flatten
+ from ..lowering.gemm import resolve_gemm_spec
+ from ..lowering.logsoftmax import lower_logsoftmax
+ from ..lowering.lp_normalization import lower_lp_normalization
+ from ..lowering.grid_sample import lower_grid_sample
+ from ..lowering.instance_normalization import lower_instance_normalization
+ from ..lowering.group_normalization import lower_group_normalization
+ from ..lowering.layer_normalization import lower_layer_normalization
+ from ..lowering.mean_variance_normalization import (
+     lower_mean_variance_normalization,
+ )
+ from ..lowering.negative_log_likelihood_loss import (
+     lower_negative_log_likelihood_loss,
+ )
+ from ..lowering.pad import lower_pad
+ from ..lowering.expand import lower_expand
+ from ..lowering.range import lower_range
+ from ..lowering.split import lower_split
+ from ..lowering.softmax_cross_entropy_loss import (
+     lower_softmax_cross_entropy_loss,
+ )
+ from ..lowering.arg_reduce import lower_arg_reduce
+ from ..lowering.lstm import ACTIVATION_KIND_BY_NAME, resolve_lstm_spec
+ from ..lowering.lrn import resolve_lrn_spec
+ from ..lowering.matmul import lower_matmul
+ from ..lowering.maxpool import resolve_maxpool_spec
+ from ..lowering.reduce import (
+     REDUCE_KIND_BY_OP,
+     REDUCE_OUTPUTS_FLOAT_ONLY,
+     normalize_reduce_axes,
+     resolve_reduce_axes,
+ )
+ from ..lowering.reshape import lower_reshape
+ from ..lowering.slice import _normalize_slices
+ from ..lowering.shape import lower_shape
+ from ..lowering.size import lower_size
+ from ..lowering.softmax import lower_softmax
+ from ..lowering.rms_normalization import lower_rms_normalization
+ from ..lowering.squeeze import lower_squeeze
+ from ..lowering.transpose import lower_transpose
+ from ..lowering.unsqueeze import lower_unsqueeze
+ from ..lowering.where import lower_where
+ from ..lowering.variadic import BINARY_ONLY_OPS, VARIADIC_OP_FUNCTIONS
+ from ..lowering.registry import resolve_dispatch
+ from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
+ from ..ops import (
+     BINARY_OP_TYPES,
+     COMPARE_FUNCTIONS,
+     UNARY_OP_TYPES,
+     apply_binary_op,
+     apply_unary_op,
+     binary_op_symbol,
+     unary_op_symbol,
+     validate_unary_attrs,
+ )
+ from shared.scalar_functions import ScalarFunction, ScalarFunctionError
+ from ..validation import normalize_axis
+
+ Handler = Callable[["Evaluator", Node], None]
+ _EVAL_REGISTRY: dict[str, Handler] = {}
+
+
+ def register_evaluator(op_type: str) -> Callable[[Handler], Handler]:
+     def decorator(func: Handler) -> Handler:
+         _EVAL_REGISTRY[op_type] = func
+         return func
+
+     return decorator
+
+
+ class Evaluator:
+     def __init__(self, graph: Graph) -> None:
+         self._graph = graph
+         self._values: dict[str, np.ndarray] = {}
+
+     @property
+     def graph(self) -> Graph:
+         return self._graph
+
+     @property
+     def values(self) -> dict[str, np.ndarray]:
+         return self._values
+
+     def run(self, feeds: Mapping[str, np.ndarray]) -> dict[str, np.ndarray]:
+         values = {
+             initializer.name: initializer.data
+             for initializer in self._graph.initializers
+         }
+         values.update(feeds)
+         self._values = values
+         for node in self._graph.nodes:
+             self._dispatch(node)
+         return {
+             output.name: self._values[output.name]
+             for output in self._graph.outputs
+         }
+
+     def _dispatch(self, node: Node) -> None:
+         handler = resolve_dispatch(
+             node.op_type,
+             _EVAL_REGISTRY,
+             binary_types=BINARY_OP_TYPES,
+             unary_types=UNARY_OP_TYPES,
+             binary_fallback=lambda: _eval_binary_unary,
+             unary_fallback=lambda: _eval_binary_unary,
+         )
+         handler(self, node)
+
+
+ @register_evaluator("MatMul")
+ def _eval_matmul(evaluator: Evaluator, node: Node) -> None:
+     lower_matmul(evaluator.graph, node)
+     left = evaluator.values[node.inputs[0]]
+     right = evaluator.values[node.inputs[1]]
+     evaluator.values[node.outputs[0]] = _apply_matmul(left, right)
+
+
+ @register_evaluator("Clip")
+ def _eval_clip(evaluator: Evaluator, node: Node) -> None:
+     if not node.inputs or len(node.outputs) != 1:
+         raise UnsupportedOpError("Clip must have 1 output")
+     input_name = node.inputs[0]
+     if not input_name:
+         raise UnsupportedOpError("Clip input must be provided")
+     x = evaluator.values[input_name]
+     min_name = optional_name(node.inputs, 1)
+     max_name = optional_name(node.inputs, 2)
+     dtype = value_dtype(evaluator.graph, input_name, node)
+     if min_name is None:
+         min_val = (
+             -np.inf
+             if dtype.is_float
+             else np.iinfo(dtype.np_dtype).min
+         )
+     else:
+         min_val = evaluator.values[min_name]
+     if max_name is None:
+         max_val = (
+             np.inf
+             if dtype.is_float
+             else np.iinfo(dtype.np_dtype).max
+         )
+     else:
+         max_val = evaluator.values[max_name]
+     evaluator.values[node.outputs[0]] = np.clip(x, min_val, max_val)
+
+
+ def _exclusive_cumsum(data: np.ndarray, axis: int) -> np.ndarray:
+     result = np.zeros_like(data)
+     if data.shape[axis] == 0:
+         return result
+     cumsum = np.cumsum(data, axis=axis, dtype=data.dtype)
+     src_slice = [slice(None)] * data.ndim
+     dst_slice = [slice(None)] * data.ndim
+     src_slice[axis] = slice(None, -1)
+     dst_slice[axis] = slice(1, None)
+     result[tuple(dst_slice)] = cumsum[tuple(src_slice)]
+     return result
+
+
+ @register_evaluator("CumSum")
+ def _eval_cumsum(evaluator: Evaluator, node: Node) -> None:
+     op = lower_cumsum(evaluator.graph, node)
+     x = evaluator.values[op.input0]
+     axis = op.axis
+     if axis is None:
+         axis_values = evaluator.values[op.axis_input].astype(np.int64, copy=False)
+         axis_values = axis_values.reshape(-1)
+         if axis_values.size != 1:
+             raise UnsupportedOpError("CumSum axis input must be scalar")
+         axis = normalize_axis(int(axis_values[0]), op.input_shape, node)
+     data = np.flip(x, axis=axis) if op.reverse else x
+     if op.exclusive:
+         result = _exclusive_cumsum(data, axis)
+     else:
+         result = np.cumsum(data, axis=axis, dtype=data.dtype)
+     if op.reverse:
+         result = np.flip(result, axis=axis)
+     evaluator.values[op.output] = result
+
+
+ @register_evaluator("Pad")
+ def _eval_pad(evaluator: Evaluator, node: Node) -> None:
+     op = lower_pad(evaluator.graph, node)
+     x = evaluator.values[op.input0]
+     if op.value_input is not None:
+         value_array = evaluator.values[op.value_input]
+         pad_value = np.array(value_array, dtype=op.dtype.np_dtype).reshape(-1)[0].item()
+     else:
+         pad_value = np.array(op.value, dtype=op.dtype.np_dtype).item()
+     rank = len(op.input_shape)
+     if op.axes_input is not None:
+         axes_values = evaluator.values[op.axes_input].astype(
+             np.int64, copy=False
+         )
+         axes_values = axes_values.reshape(-1)
+         if op.pads_input is not None:
+             pads_values = evaluator.values[op.pads_input].astype(
+                 np.int64, copy=False
+             )
+             pads_values = pads_values.reshape(-1)
+         else:
+             pads_values = np.array(op.pads_values, dtype=np.int64).reshape(-1)
+         axis_count = len(axes_values)
+         pads_begin = np.zeros(rank, dtype=np.int64)
+         pads_end = np.zeros(rank, dtype=np.int64)
+         for index, axis_value in enumerate(axes_values):
+             axis = int(axis_value)
+             if axis < 0:
+                 axis += rank
+             pads_begin[axis] = int(pads_values[index])
+             pads_end[axis] = int(pads_values[index + axis_count])
+         pad_width = tuple(
+             (int(pads_begin[index]), int(pads_end[index]))
+             for index in range(rank)
+         )
+     elif op.pads_input is not None:
+         pads_values = evaluator.values[op.pads_input].astype(np.int64, copy=False)
+         pads_values = pads_values.reshape(-1)
+         if op.pads_axis_map is not None:
+             axis_count = sum(
+                 1 for axis_index in op.pads_axis_map if axis_index is not None
+             )
+             pads_begin = np.zeros(rank, dtype=np.int64)
+             pads_end = np.zeros(rank, dtype=np.int64)
+             for axis, pad_index in enumerate(op.pads_axis_map):
+                 if pad_index is None:
+                     continue
+                 pads_begin[axis] = int(pads_values[pad_index])
+                 pads_end[axis] = int(pads_values[pad_index + axis_count])
+             pad_width = tuple(
+                 (int(pads_begin[index]), int(pads_end[index]))
+                 for index in range(rank)
+             )
+         else:
+             pads_begin = pads_values[:rank]
+             pads_end = pads_values[rank: rank * 2]
+             pad_width = tuple(
+                 (int(pads_begin[index]), int(pads_end[index]))
+                 for index in range(rank)
+             )
+     else:
+         pad_width = tuple(zip(op.pads_begin or (), op.pads_end or ()))
+     pad_kwargs = {}
+     if op.mode == "constant":
+         pad_kwargs["constant_values"] = pad_value
+     evaluator.values[op.output] = np.pad(
+         x,
+         pad_width,
+         mode=op.mode,
+         **pad_kwargs,
+     )
+
+
+ @register_evaluator("Celu")
+ def _eval_celu(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Celu must have 1 input and 1 output")
+     dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     if not dtype.is_float:
+         raise UnsupportedOpError("Celu only supports floating-point inputs")
+     alpha = float(node.attrs.get("alpha", 1.0))
+     x = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = np.where(
+         x > 0,
+         x,
+         alpha * (np.exp(x / alpha) - 1.0),
+     )
+
+
+ @register_evaluator("Swish")
+ def _eval_swish(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Swish must have 1 input and 1 output")
+     dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     if not dtype.is_float:
+         raise UnsupportedOpError("Swish only supports floating-point inputs")
+     alpha = float(node.attrs.get("alpha", 1.0))
+     x = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = x / (1.0 + np.exp(-alpha * x))
+
+
+ def _grid_sample_denormalize(
+     value: float, length: int, *, align_corners: bool
+ ) -> float:
+     if align_corners:
+         return (value + 1.0) * (length - 1) / 2.0
+     return ((value + 1.0) * length - 1.0) / 2.0
+
+
+ def _grid_sample_reflect(value: float, x_min: float, x_max: float) -> float:
+     rng = x_max - x_min
+     if rng == 0:
+         return x_min
+     if value < x_min:
+         dx = x_min - value
+         n = int(dx / rng)
+         r = dx - n * rng
+         return x_min + r if n % 2 == 0 else x_max - r
+     if value > x_max:
+         dx = value - x_max
+         n = int(dx / rng)
+         r = dx - n * rng
+         return x_max - r if n % 2 == 0 else x_min + r
+     return value
+
+
+ def _grid_sample_border(
+     dims: tuple[int, ...], *, align_corners: bool
+ ) -> tuple[list[float], list[float]]:
+     min_vals: list[float] = []
+     max_vals: list[float] = []
+     for dim in dims:
+         if align_corners:
+             min_vals.append(0.0)
+             max_vals.append(dim - 1.0)
+         else:
+             min_vals.append(-0.5)
+             max_vals.append(dim - 0.5)
+     return min_vals, max_vals
+
+
+ def _grid_sample_pixel_at(
+     data: np.ndarray,
+     indices: list[int],
+     border_min: list[float],
+     border_max: list[float],
+     padding_mode: str,
+ ) -> float:
+     if padding_mode == "zeros":
+         for idx, dim in zip(indices, data.shape):
+             if idx < 0 or idx >= dim:
+                 return data.dtype.type(0)
+         return data[tuple(indices)]
+     if padding_mode == "border":
+         clamped = [
+             0 if idx < 0 else dim - 1 if idx >= dim else idx
+             for idx, dim in zip(indices, data.shape)
+         ]
+         return data[tuple(clamped)]
+     reflected = [
+         int(_grid_sample_reflect(idx, border_min[i], border_max[i]))
+         for i, idx in enumerate(indices)
+     ]
+     return data[tuple(reflected)]
+
+
+ def _grid_sample_linear_1d(
+     data: np.ndarray,
+     coord: float,
+     border_min: float,
+     border_max: float,
+     padding_mode: str,
+ ) -> float:
+     base = int(np.floor(coord))
+     weight = coord - base
+     lower = _grid_sample_pixel_at(
+         data, [base], [border_min], [border_max], padding_mode
+     )
+     upper = _grid_sample_pixel_at(
+         data, [base + 1], [border_min], [border_max], padding_mode
+     )
+     return (1.0 - weight) * lower + weight * upper
+
+
+ def _grid_sample_cubic_coeffs(x: float) -> np.ndarray:
+     alpha = -0.75
+     abs_x = abs(x)
+     coeffs = np.empty((4,), dtype=np.float64)
+     coeffs[0] = (
+         (alpha * (abs_x + 1.0) - 5.0 * alpha) * (abs_x + 1.0) + 8.0 * alpha
+     ) * (abs_x + 1.0) - 4.0 * alpha
+     coeffs[1] = ((alpha + 2.0) * abs_x - (alpha + 3.0)) * abs_x * abs_x + 1.0
+     inv_x = 1.0 - abs_x
+     coeffs[2] = ((alpha + 2.0) * inv_x - (alpha + 3.0)) * inv_x * inv_x + 1.0
+     span = 2.0 - abs_x
+     coeffs[3] = (
+         (alpha * span - 5.0 * alpha) * span + 8.0 * alpha
+     ) * span - 4.0 * alpha
+     return coeffs
+
+
+ def _grid_sample_cubic_1d(
+     data: np.ndarray,
+     coord: float,
+     border_min: float,
+     border_max: float,
+     padding_mode: str,
+ ) -> float:
+     base = int(np.floor(coord))
+     coeffs = _grid_sample_cubic_coeffs(coord - base)
+     values = np.empty((4,), dtype=np.float64)
+     for offset in range(4):
+         values[offset] = _grid_sample_pixel_at(
+             data,
+             [base - 1 + offset],
+             [border_min],
+             [border_max],
+             padding_mode,
+         )
+     return float(coeffs @ values)
+
+
+ def _grid_sample_linear_nd(
+     data: np.ndarray,
+     coords: np.ndarray,
+     border_min: list[float],
+     border_max: list[float],
+     padding_mode: str,
+ ) -> float:
+     if data.ndim == 1:
+         return _grid_sample_linear_1d(
+             data, float(coords[0]), border_min[0], border_max[0], padding_mode
+         )
+     reduced = np.array(
+         [
+             _grid_sample_linear_nd(
+                 data[index],
+                 coords[1:],
+                 border_min[1:],
+                 border_max[1:],
+                 padding_mode,
+             )
+             for index in range(data.shape[0])
+         ],
+         dtype=np.float64,
+     )
+     return _grid_sample_linear_1d(
+         reduced, float(coords[0]), border_min[0], border_max[0], padding_mode
+     )
+
+
+ def _grid_sample_cubic_nd(
+     data: np.ndarray,
+     coords: np.ndarray,
+     border_min: list[float],
+     border_max: list[float],
+     padding_mode: str,
+ ) -> float:
+     if data.ndim == 1:
+         return _grid_sample_cubic_1d(
+             data, float(coords[0]), border_min[0], border_max[0], padding_mode
+         )
+     reduced = np.array(
+         [
+             _grid_sample_cubic_nd(
+                 data[index],
+                 coords[1:],
+                 border_min[1:],
+                 border_max[1:],
+                 padding_mode,
+             )
+             for index in range(data.shape[0])
+         ],
+         dtype=np.float64,
+     )
+     return _grid_sample_cubic_1d(
+         reduced, float(coords[0]), border_min[0], border_max[0], padding_mode
+     )
+
+
+ @register_evaluator("GridSample")
+ def _eval_grid_sample(evaluator: Evaluator, node: Node) -> None:
+     op = lower_grid_sample(evaluator.graph, node)
+     input_data = evaluator.values[op.input0]
+     grid_data = evaluator.values[op.grid]
+     output = np.empty(op.output_shape, dtype=input_data.dtype)
+     if output.size == 0:
+         evaluator.values[op.output] = output
+         return
+     dims = op.input_spatial
+     border_min, border_max = _grid_sample_border(
+         dims, align_corners=op.align_corners
+     )
+     for n in range(op.output_shape[0]):
+         grid_batch = grid_data[n]
+         for c in range(op.output_shape[1]):
+             input_slice = input_data[n, c]
+             for out_idx in np.ndindex(*op.output_spatial):
+                 coords = np.array(
+                     grid_batch[out_idx][::-1], dtype=np.float64
+                 )
+                 for i, dim in enumerate(dims):
+                     coords[i] = _grid_sample_denormalize(
+                         float(coords[i]), dim, align_corners=op.align_corners
+                     )
+                 if op.mode == "nearest":
+                     rounded = np.rint(coords).astype(int)
+                     if op.padding_mode != "zeros":
+                         for i, dim in enumerate(dims):
+                             if (
+                                 rounded[i] < border_min[i]
+                                 or rounded[i] > border_max[i]
+                             ):
+                                 if op.padding_mode == "border":
+                                     rounded[i] = min(
+                                         max(rounded[i], 0), dim - 1
+                                     )
+                                 else:
+                                     rounded[i] = int(
+                                         _grid_sample_reflect(
+                                             rounded[i],
+                                             border_min[i],
+                                             border_max[i],
+                                         )
+                                     )
+                     value = _grid_sample_pixel_at(
+                         input_slice,
+                         rounded.tolist(),
+                         border_min,
+                         border_max,
+                         op.padding_mode,
+                     )
+                 else:
+                     if op.padding_mode != "zeros":
+                         for i, dim in enumerate(dims):
+                             if (
+                                 coords[i] < border_min[i]
+                                 or coords[i] > border_max[i]
+                             ):
+                                 if op.padding_mode == "border":
+                                     coords[i] = min(
+                                         max(coords[i], 0.0), dim - 1.0
+                                     )
+                                 else:
+                                     coords[i] = _grid_sample_reflect(
+                                         coords[i],
+                                         border_min[i],
+                                         border_max[i],
+                                     )
+                     if op.mode == "linear":
+                         value = _grid_sample_linear_nd(
+                             input_slice,
+                             coords,
+                             border_min,
+                             border_max,
+                             op.padding_mode,
+                         )
+                     else:
+                         value = _grid_sample_cubic_nd(
+                             input_slice,
+                             coords,
+                             border_min,
+                             border_max,
+                             op.padding_mode,
+                         )
+                 output[(n, c, *out_idx)] = value
+     evaluator.values[op.output] = output
+
+
+ _VARIADIC_COMBINE_FUNCS: dict[
+     ScalarFunction, Callable[[np.ndarray, np.ndarray], np.ndarray]
+ ] = {
+     ScalarFunction.ADD: np.add,
+     ScalarFunction.MAXIMUM: np.maximum,
+     ScalarFunction.MINIMUM: np.minimum,
+     ScalarFunction.LOGICAL_AND: np.logical_and,
+     ScalarFunction.LOGICAL_OR: np.logical_or,
+     ScalarFunction.LOGICAL_XOR: np.logical_xor,
+     ScalarFunction.BITWISE_AND: np.bitwise_and,
+     ScalarFunction.BITWISE_OR: np.bitwise_or,
+     ScalarFunction.BITWISE_XOR: np.bitwise_xor,
+ }
+
+
+ def _validate_variadic_inputs(
+     evaluator: Evaluator, node: Node, *, function: ScalarFunction
+ ) -> tuple[ScalarType, tuple[int, ...]]:
+     if len(node.outputs) != 1:
+         raise UnsupportedOpError(f"{node.op_type} must have 1 output")
+     if node.op_type in BINARY_ONLY_OPS:
+         if len(node.inputs) != 2:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have exactly 2 inputs"
+             )
+     elif len(node.inputs) < 2:
+         raise UnsupportedOpError(
+             f"{node.op_type} must have at least 2 inputs"
+         )
+     for name in node.inputs:
+         if not name:
+             raise UnsupportedOpError(f"{node.op_type} input must be provided")
+     op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     output_shape = value_shape(evaluator.graph, node.outputs[0], node)
+     for name in node.inputs:
+         input_shape = value_shape(evaluator.graph, name, node)
+         if input_shape != output_shape:
+             raise UnsupportedOpError(
+                 f"{node.op_type} expects identical input/output shapes"
+             )
+     if function in {
+         ScalarFunction.LOGICAL_AND,
+         ScalarFunction.LOGICAL_OR,
+         ScalarFunction.LOGICAL_XOR,
+     } and op_dtype != ScalarType.BOOL:
+         raise UnsupportedOpError(f"{node.op_type} expects bool inputs")
+     if function in {
+         ScalarFunction.BITWISE_AND,
+         ScalarFunction.BITWISE_OR,
+         ScalarFunction.BITWISE_XOR,
+     } and not op_dtype.is_integer:
+         raise UnsupportedOpError(f"{node.op_type} expects integer inputs")
+     if function == ScalarFunction.MEAN and not op_dtype.is_float:
+         raise UnsupportedOpError(f"{node.op_type} expects floating-point inputs")
+     return op_dtype, output_shape
+
+
+ def _eval_variadic(evaluator: Evaluator, node: Node) -> None:
+     function = VARIADIC_OP_FUNCTIONS[node.op_type]
+     _validate_variadic_inputs(evaluator, node, function=function)
+     values = [evaluator.values[name] for name in node.inputs]
+     if function == ScalarFunction.MEAN:
+         combine_func = _VARIADIC_COMBINE_FUNCS[ScalarFunction.ADD]
+     else:
+         combine_func = _VARIADIC_COMBINE_FUNCS[function]
+     result = values[0]
+     for value in values[1:]:
+         result = combine_func(result, value)
+     if function == ScalarFunction.MEAN:
+         result = result / len(values)
+     evaluator.values[node.outputs[0]] = result
+
+
+ for _op_type in VARIADIC_OP_FUNCTIONS:
+     register_evaluator(_op_type)(_eval_variadic)
+
+
+ @register_evaluator("Shrink")
+ def _eval_shrink(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Shrink must have 1 input and 1 output")
+     bias = float(node.attrs.get("bias", 0.0))
+     lambd = float(node.attrs.get("lambd", 0.5))
+     x = evaluator.values[node.inputs[0]]
+     result = np.where(
+         x < -lambd,
+         x + bias,
+         np.where(x > lambd, x - bias, 0.0),
+     )
+     if result.dtype != x.dtype:
+         result = result.astype(x.dtype)
+     evaluator.values[node.outputs[0]] = result
+
+
+ @register_evaluator("IsInf")
+ def _eval_isinf(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("IsInf must have 1 input and 1 output")
+     input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     if not input_dtype.is_float:
+         raise UnsupportedOpError("IsInf only supports floating-point inputs")
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if output_dtype != ScalarType.BOOL:
+         raise UnsupportedOpError("IsInf output must be bool")
+     x = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = np.isinf(x)
+
+
+ @register_evaluator("IsNaN")
+ def _eval_isnan(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("IsNaN must have 1 input and 1 output")
+     input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     if not input_dtype.is_float:
+         raise UnsupportedOpError("IsNaN only supports floating-point inputs")
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if output_dtype != ScalarType.BOOL:
+         raise UnsupportedOpError("IsNaN output must be bool")
+     x = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = np.isnan(x)
+
+
+ @register_evaluator("Gemm")
+ def _eval_gemm(evaluator: Evaluator, node: Node) -> None:
+     op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     spec = resolve_gemm_spec(evaluator.graph, node, op_dtype)
+     left = evaluator.values[node.inputs[0]]
+     right = evaluator.values[node.inputs[1]]
+     if spec.trans_a:
+         left = left.T
+     if spec.trans_b:
+         right = right.T
+     result = _apply_matmul(left, right)
+     if op_dtype.is_float:
+         alpha = float(spec.alpha)
+         beta = float(spec.beta)
+     else:
+         alpha = int(spec.alpha)
+         beta = int(spec.beta)
+     if alpha != 1:
+         result = result * alpha
+     if len(node.inputs) == 3:
+         bias = evaluator.values[node.inputs[2]]
+         if beta != 1:
+             bias = bias * beta
+         result = result + bias
+     evaluator.values[node.outputs[0]] = result
+
+
+ @register_evaluator("Cast")
+ def _eval_cast(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Cast must have 1 input and 1 output")
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     input_value = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = input_value.astype(
+         output_dtype.np_dtype, copy=False
+     )
+
+
+ @register_evaluator("CastLike")
+ def _eval_castlike(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("CastLike must have 2 inputs and 1 output")
+     like_dtype = value_dtype(evaluator.graph, node.inputs[1], node)
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if output_dtype != like_dtype:
+         raise UnsupportedOpError(
+             "CastLike output dtype must match like input dtype, "
+             f"got {output_dtype.onnx_name} and {like_dtype.onnx_name}"
+         )
+     input_value = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = input_value.astype(
+         output_dtype.np_dtype, copy=False
+     )
+
+
+ @register_evaluator("Identity")
+ def _eval_identity(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Identity must have 1 input and 1 output")
+     value = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = np.array(value, copy=True)
+
+
+ @register_evaluator("EyeLike")
+ def _eval_eye_like(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("EyeLike must have 1 input and 1 output")
+     output_shape = value_shape(evaluator.graph, node.outputs[0], node)
+     if len(output_shape) < 2:
+         raise UnsupportedOpError("EyeLike expects input rank >= 2")
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     k = int(node.attrs.get("k", 0))
+     output = np.zeros(output_shape, dtype=output_dtype.np_dtype)
+     rows, cols = output_shape[-2], output_shape[-1]
+     row_start = 0 if k >= 0 else -k
+     col_start = k if k >= 0 else 0
+     if row_start < rows and col_start < cols:
+         diag_len = min(rows - row_start, cols - col_start)
+         batch_size = int(np.prod(output_shape[:-2])) if output_shape[:-2] else 1
+         view = output.reshape(batch_size, rows, cols)
+         diag_idx = np.arange(diag_len, dtype=np.int64)
+         one = output_dtype.np_dtype.type(1)
+         view[:, row_start + diag_idx, col_start + diag_idx] = one
+     evaluator.values[node.outputs[0]] = output
+
+
+ @register_evaluator("Tile")
+ def _eval_tile(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Tile must have 2 inputs and 1 output")
+     value = evaluator.values[node.inputs[0]]
+     repeats = evaluator.values[node.inputs[1]]
+     repeats = np.array(repeats, dtype=np.int64).reshape(-1)
+     if repeats.size != value.ndim:
+         raise UnsupportedOpError(
+             "Tile repeats must have the same rank as input shape"
+         )
+     evaluator.values[node.outputs[0]] = np.tile(value, repeats)
+
+
+ @register_evaluator("DepthToSpace")
+ def _eval_depth_to_space(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("DepthToSpace must have 1 input and 1 output")
+     data = evaluator.values[node.inputs[0]]
+     if data.ndim != 4:
+         raise UnsupportedOpError("DepthToSpace only supports 4D inputs")
+     blocksize = int(node.attrs.get("blocksize", 0))
+     if blocksize <= 0:
+         raise UnsupportedOpError(
+             f"DepthToSpace blocksize must be > 0, got {blocksize}"
+         )
+     mode_attr = node.attrs.get("mode", "DCR")
+     if isinstance(mode_attr, bytes):
+         mode = mode_attr.decode()
+     else:
+         mode = str(mode_attr)
+     if mode not in {"DCR", "CRD"}:
+         raise UnsupportedOpError("DepthToSpace only supports mode DCR or CRD")
+     b, c, h, w = data.shape
+     if mode == "DCR":
+         tmpshape = (
+             b,
+             blocksize,
+             blocksize,
+             c // (blocksize * blocksize),
+             h,
+             w,
+         )
+         reshaped = data.reshape(tmpshape)
+         transposed = np.transpose(reshaped, [0, 3, 4, 1, 5, 2])
+     else:
+         tmpshape = (
+             b,
+             c // (blocksize * blocksize),
+             blocksize,
+             blocksize,
+             h,
+             w,
+         )
+         reshaped = data.reshape(tmpshape)
+         transposed = np.transpose(reshaped, [0, 1, 4, 2, 5, 3])
+     finalshape = (
+         b,
+         c // (blocksize * blocksize),
+         h * blocksize,
+         w * blocksize,
+     )
+     evaluator.values[node.outputs[0]] = np.reshape(transposed, finalshape)
+
+
+ @register_evaluator("SpaceToDepth")
+ def _eval_space_to_depth(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("SpaceToDepth must have 1 input and 1 output")
+     data = evaluator.values[node.inputs[0]]
+     if data.ndim != 4:
+         raise UnsupportedOpError("SpaceToDepth only supports 4D inputs")
+     blocksize = int(node.attrs.get("blocksize", 0))
+     if blocksize <= 0:
+         raise UnsupportedOpError(
+             f"SpaceToDepth blocksize must be > 0, got {blocksize}"
+         )
+     b, c, h, w = data.shape
+     tmpshape = (
+         b,
+         c,
+         h // blocksize,
+         blocksize,
+         w // blocksize,
+         blocksize,
+     )
+     reshaped = np.reshape(data, tmpshape)
+     transposed = np.transpose(reshaped, [0, 3, 5, 1, 2, 4])
+     finalshape = (
+         b,
+         c * blocksize * blocksize,
+         h // blocksize,
+         w // blocksize,
+     )
+     evaluator.values[node.outputs[0]] = np.reshape(transposed, finalshape)
+
+
+ @register_evaluator("Where")
+ def _eval_where(evaluator: Evaluator, node: Node) -> None:
+     lower_where(evaluator.graph, node)
+     condition = evaluator.values[node.inputs[0]]
+     x_value = evaluator.values[node.inputs[1]]
+     y_value = evaluator.values[node.inputs[2]]
+     evaluator.values[node.outputs[0]] = np.where(condition, x_value, y_value)
+
+
+ @register_evaluator("GatherElements")
+ def _eval_gather_elements(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("GatherElements must have 2 inputs and 1 output")
+     data = evaluator.values[node.inputs[0]]
+     indices = evaluator.values[node.inputs[1]]
+     if indices.dtype.type not in {np.int32, np.int64}:
+         raise UnsupportedOpError(
+             f"GatherElements indices must be int32 or int64, got {indices.dtype}"
+         )
+     axis = normalize_axis(int(node.attrs.get("axis", 0)), data.shape, node)
+     evaluator.values[node.outputs[0]] = np.take_along_axis(
+         data, indices, axis=axis
+     )
+
+
+ @register_evaluator("Gather")
+ def _eval_gather(evaluator: Evaluator, node: Node) -> None:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
+     data = evaluator.values[node.inputs[0]]
+     indices = evaluator.values[node.inputs[1]]
+     if indices.dtype.type not in {np.int32, np.int64}:
+         raise UnsupportedOpError(
+             f"Gather indices must be int32 or int64, got {indices.dtype}"
+         )
+     axis = normalize_axis(int(node.attrs.get("axis", 0)), data.shape, node)
+     evaluator.values[node.outputs[0]] = np.take(data, indices, axis=axis)
+
+
+ @register_evaluator("Slice")
+ def _eval_slice(evaluator: Evaluator, node: Node) -> None:
+     input_value = evaluator.values[node.inputs[0]]
+     if "starts" in node.attrs or "ends" in node.attrs:
+         starts = [int(value) for value in node.attrs.get("starts", [])]
+         ends = [int(value) for value in node.attrs.get("ends", [])]
+         axes_attr = node.attrs.get("axes")
+         axes = [int(value) for value in axes_attr] if axes_attr else None
+         steps = None
+     else:
+         if len(node.inputs) < 3:
+             raise UnsupportedOpError(
+                 f"{node.op_type} expects at least 3 inputs"
+             )
+         starts_value = evaluator.values[node.inputs[1]]
+         ends_value = evaluator.values[node.inputs[2]]
+         if starts_value.dtype.type not in {np.int32, np.int64}:
+             raise UnsupportedOpError(
+                 f"{node.op_type} starts input must be int64 or int32"
+             )
+         if ends_value.dtype.type not in {np.int32, np.int64}:
+             raise UnsupportedOpError(
+                 f"{node.op_type} ends input must be int64 or int32"
+             )
+         starts = [int(value) for value in starts_value.reshape(-1)]
+         ends = [int(value) for value in ends_value.reshape(-1)]
+         axes = None
+         steps = None
+         if len(node.inputs) >= 4 and node.inputs[3]:
+             axes_value = evaluator.values[node.inputs[3]]
+             if axes_value.dtype.type not in {np.int32, np.int64}:
+                 raise UnsupportedOpError(
+                     f"{node.op_type} axes input must be int64 or int32"
+                 )
+             axes = [int(value) for value in axes_value.reshape(-1)]
+         if len(node.inputs) >= 5 and node.inputs[4]:
+             steps_value = evaluator.values[node.inputs[4]]
+             if steps_value.dtype.type not in {np.int32, np.int64}:
+                 raise UnsupportedOpError(
+                     f"{node.op_type} steps input must be int64 or int32"
+                 )
+             steps = [int(value) for value in steps_value.reshape(-1)]
+     normalized_starts, normalized_steps, output_shape = _normalize_slices(
+         input_value.shape, starts, ends, axes, steps, node
+     )
+     slices = tuple(
+         slice(start, start + step * size, step)
+         for start, step, size in zip(
+             normalized_starts, normalized_steps, output_shape
+         )
+     )
+     evaluator.values[node.outputs[0]] = input_value[slices]
+
+
+ @register_evaluator("Attention")
+ def _eval_attention(evaluator: Evaluator, node: Node) -> None:
+     input_q = node.inputs[0]
+     input_k = node.inputs[1]
+     input_v = node.inputs[2]
+     output_y = node.outputs[0]
+     op_dtype = node_dtype(evaluator.graph, node, input_q, input_k, input_v, output_y)
+     spec = resolve_attention_spec(evaluator.graph, node, op_dtype)
+     attn_mask_name = optional_name(node.inputs, 3)
+     past_key_name = optional_name(node.inputs, 4)
+     past_value_name = optional_name(node.inputs, 5)
+     nonpad_name = optional_name(node.inputs, 6)
+     present_key_name = optional_name(node.outputs, 1)
+     present_value_name = optional_name(node.outputs, 2)
+     qk_matmul_output_name = optional_name(node.outputs, 3)
+     output, present_key, present_value, qk_output = _apply_attention(
+         spec,
+         evaluator.values[input_q],
+         evaluator.values[input_k],
+         evaluator.values[input_v],
+         evaluator.values[attn_mask_name] if attn_mask_name else None,
+         evaluator.values[past_key_name] if past_key_name else None,
+         evaluator.values[past_value_name] if past_value_name else None,
+         evaluator.values[nonpad_name] if nonpad_name else None,
+     )
+     evaluator.values[output_y] = output
+     if present_key_name is not None:
+         evaluator.values[present_key_name] = present_key
+     if present_value_name is not None:
+         evaluator.values[present_value_name] = present_value
+     if qk_matmul_output_name is not None:
+         evaluator.values[qk_matmul_output_name] = qk_output
+
+
+ def _apply_lstm_activation(
+     kind: int, value: np.ndarray, alpha: float, beta: float
+ ) -> np.ndarray:
+     if kind == ACTIVATION_KIND_BY_NAME["Relu"]:
+         return np.maximum(value, 0)
+     if kind == ACTIVATION_KIND_BY_NAME["Tanh"]:
+         return np.tanh(value)
+     if kind == ACTIVATION_KIND_BY_NAME["Sigmoid"]:
+         return 1 / (1 + np.exp(-value))
+     if kind == ACTIVATION_KIND_BY_NAME["Affine"]:
+         return alpha * value + beta
+     if kind == ACTIVATION_KIND_BY_NAME["LeakyRelu"]:
+         return np.where(value < 0, alpha * value, value)
+     if kind == ACTIVATION_KIND_BY_NAME["ThresholdedRelu"]:
+         return np.where(value > alpha, value, 0)
+     if kind == ACTIVATION_KIND_BY_NAME["ScaledTanh"]:
+         return alpha * np.tanh(beta * value)
+     if kind == ACTIVATION_KIND_BY_NAME["HardSigmoid"]:
+         return np.clip(alpha * value + beta, 0, 1)
+     if kind == ACTIVATION_KIND_BY_NAME["Elu"]:
+         return np.where(value >= 0, value, alpha * (np.exp(value) - 1))
+     if kind == ACTIVATION_KIND_BY_NAME["Softsign"]:
+         return value / (1 + np.abs(value))
+     if kind == ACTIVATION_KIND_BY_NAME["Softplus"]:
+         return np.log1p(np.exp(value))
+     raise UnsupportedOpError(f"Unsupported LSTM activation kind {kind}")
+
+
+ @register_evaluator("LSTM")
+ def _eval_lstm(evaluator: Evaluator, node: Node) -> None:
+     spec = resolve_lstm_spec(evaluator.graph, node)
+     inputs = evaluator.values
+     x = inputs[spec.input_x]
+     w = inputs[spec.input_w]
+     r = inputs[spec.input_r]
+     b = inputs[spec.input_b] if spec.input_b is not None else None
+     sequence_lens = (
+         inputs[spec.input_sequence_lens]
+         if spec.input_sequence_lens is not None
+         else None
+     )
+     initial_h = (
+         inputs[spec.input_initial_h]
+         if spec.input_initial_h is not None
+         else None
+     )
+     initial_c = (
+         inputs[spec.input_initial_c]
+         if spec.input_initial_c is not None
+         else None
+     )
+     p = inputs[spec.input_p] if spec.input_p is not None else None
+     output_y, output_y_h, output_y_c = _apply_lstm(
+         spec,
+         x,
+         w,
+         r,
+         b,
+         sequence_lens,
+         initial_h,
+         initial_c,
+         p,
+     )
+     if spec.output_y is not None:
+         evaluator.values[spec.output_y] = output_y
+     if spec.output_y_h is not None:
+         evaluator.values[spec.output_y_h] = output_y_h
+     if spec.output_y_c is not None:
+         evaluator.values[spec.output_y_c] = output_y_c
+
+
+ @register_evaluator("Conv")
+ def _eval_conv(evaluator: Evaluator, node: Node) -> None:
+     op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "Conv supports float16, float, and double inputs only"
+         )
+     spec = resolve_conv_spec(evaluator.graph, node)
+     data = evaluator.values[node.inputs[0]]
+     weights = evaluator.values[node.inputs[1]]
+     bias = evaluator.values[node.inputs[2]] if len(node.inputs) > 2 else None
+     evaluator.values[node.outputs[0]] = _apply_conv(spec, data, weights, bias)
+
+
+ @register_evaluator("BatchNormalization")
+ def _eval_batch_norm(evaluator: Evaluator, node: Node) -> None:
+     op = lower_batch_normalization(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     scale = evaluator.values[op.scale].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     bias = evaluator.values[op.bias].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     mean = evaluator.values[op.mean].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     variance = evaluator.values[op.variance].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     evaluator.values[op.output] = (
+         (data - mean) / np.sqrt(variance + op.epsilon) * scale + bias
+     )
+
+
+ @register_evaluator("LpNormalization")
+ def _eval_lp_normalization(evaluator: Evaluator, node: Node) -> None:
+     op = lower_lp_normalization(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     if op.p == 1:
+         denom = np.sum(np.abs(data), axis=op.axis, keepdims=True)
+     else:
+         denom = np.sqrt(np.sum(data * data, axis=op.axis, keepdims=True))
+     evaluator.values[op.output] = data / denom
+
+
+ @register_evaluator("InstanceNormalization")
+ def _eval_instance_normalization(evaluator: Evaluator, node: Node) -> None:
+     op = lower_instance_normalization(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     axes = tuple(range(2, data.ndim))
+     mean = np.mean(data, axis=axes, keepdims=True)
+     var = np.mean((data - mean) ** 2, axis=axes, keepdims=True)
+     scale = evaluator.values[op.scale].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     bias = evaluator.values[op.bias].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     evaluator.values[op.output] = (
+         (data - mean) / np.sqrt(var + op.epsilon) * scale + bias
+     )
+
+
+ @register_evaluator("GroupNormalization")
+ def _eval_group_normalization(evaluator: Evaluator, node: Node) -> None:
+     op = lower_group_normalization(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     batch = data.shape[0]
+     spatial_shape = data.shape[2:]
+     grouped = data.reshape(
+         (batch, op.num_groups, op.group_size) + spatial_shape
+     )
+     axes = tuple(range(2, grouped.ndim))
+     mean = np.mean(grouped, axis=axes, keepdims=True)
+     var = np.mean((grouped - mean) ** 2, axis=axes, keepdims=True)
+     normalized = (grouped - mean) / np.sqrt(var + op.epsilon)
+     normalized = normalized.reshape(data.shape)
+     scale = evaluator.values[op.scale].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     bias = evaluator.values[op.bias].reshape(
+         (1, op.channels) + (1,) * (data.ndim - 2)
+     )
+     evaluator.values[op.output] = normalized * scale + bias
+
+
+ @register_evaluator("LayerNormalization")
+ def _eval_layer_normalization(evaluator: Evaluator, node: Node) -> None:
+     op = lower_layer_normalization(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     axes = tuple(range(op.axis, data.ndim))
+     mean = np.mean(data, axis=axes, keepdims=True)
+     var = np.mean((data - mean) ** 2, axis=axes, keepdims=True)
+     inv_std = 1.0 / np.sqrt(var + op.epsilon)
+     normalized = (data - mean) * inv_std
+     scale = evaluator.values[op.scale].reshape(
+         (1,) * op.axis + evaluator.values[op.scale].shape
+     )
+     normalized = normalized * scale
+     if op.bias is not None:
+         bias = evaluator.values[op.bias].reshape(
+             (1,) * op.axis + evaluator.values[op.bias].shape
+         )
+         normalized = normalized + bias
+     evaluator.values[op.output] = normalized
+     if op.mean_output is not None:
+         evaluator.values[op.mean_output] = mean
+     if op.invstd_output is not None:
+         evaluator.values[op.invstd_output] = inv_std
+
+
+ @register_evaluator("MeanVarianceNormalization")
+ def _eval_mean_variance_normalization(
+     evaluator: Evaluator, node: Node
+ ) -> None:
+     op = lower_mean_variance_normalization(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     mean = np.mean(data, axis=op.axes, keepdims=True)
+     variance = np.mean((data - mean) ** 2, axis=op.axes, keepdims=True)
+     evaluator.values[op.output] = (data - mean) / np.sqrt(
+         variance + op.epsilon
+     )
+
+
+ @register_evaluator("RMSNormalization")
+ def _eval_rms_normalization(evaluator: Evaluator, node: Node) -> None:
+     op = lower_rms_normalization(evaluator.graph, node)
+     data = evaluator.values[op.input0]
+     axes = tuple(range(op.axis, data.ndim))
+     mean_square = np.mean(data * data, axis=axes, keepdims=True)
+     rms = np.sqrt(mean_square + op.epsilon)
+     normalized = data / rms
+     scale = evaluator.values[op.scale].reshape(
+         (1,) * op.axis + evaluator.values[op.scale].shape
+     )
+     evaluator.values[op.output] = normalized * scale
+
+
+ @register_evaluator("LRN")
+ def _eval_lrn(evaluator: Evaluator, node: Node) -> None:
+     op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
+     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "LRN supports float16, float, and double inputs only"
+         )
+     spec = resolve_lrn_spec(evaluator.graph, node)
+     data = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = _apply_lrn(spec, data)
+
+
1247
+ @register_evaluator("AveragePool")
1248
+ def _eval_average_pool(evaluator: Evaluator, node: Node) -> None:
1249
+ op = lower_average_pool(evaluator.graph, node)
1250
+ data = evaluator.values[node.inputs[0]]
1251
+ evaluator.values[node.outputs[0]] = _apply_average_pool(op, data)
1252
+
1253
+
1254
+ @register_evaluator("GlobalAveragePool")
1255
+ def _eval_global_average_pool(evaluator: Evaluator, node: Node) -> None:
1256
+ op = lower_global_average_pool(evaluator.graph, node)
1257
+ data = evaluator.values[node.inputs[0]]
1258
+ evaluator.values[node.outputs[0]] = _apply_average_pool(op, data)
1259
+
1260
+
1261
+ @register_evaluator("MaxPool")
1262
+ def _eval_maxpool(evaluator: Evaluator, node: Node) -> None:
1263
+ op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1264
+ output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1265
+ if op_dtype != output_dtype:
1266
+ raise UnsupportedOpError(
1267
+ f"{node.op_type} expects matching input/output dtypes, "
1268
+ f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
1269
+ )
1270
+ indices_output = node.outputs[1] if len(node.outputs) > 1 else None
1271
+ if indices_output is not None:
1272
+ indices_dtype = value_dtype(evaluator.graph, indices_output, node)
1273
+ if indices_dtype != ScalarType.I64:
1274
+ raise UnsupportedOpError("MaxPool indices output must be int64")
1275
+ if op_dtype == ScalarType.BOOL:
1276
+ raise UnsupportedOpError("MaxPool supports numeric inputs only")
1277
+ spec = resolve_maxpool_spec(evaluator.graph, node)
1278
+ data = evaluator.values[node.inputs[0]]
1279
+ if indices_output is None:
1280
+ evaluator.values[node.outputs[0]] = _apply_maxpool(spec, data)
1281
+ else:
1282
+ values, indices = _apply_maxpool(spec, data, return_indices=True)
1283
+ evaluator.values[node.outputs[0]] = values
1284
+ evaluator.values[indices_output] = indices
1285
+
1286
+
1287
+ @register_evaluator("Softmax")
1288
+ def _eval_softmax(evaluator: Evaluator, node: Node) -> None:
1289
+ op = lower_softmax(evaluator.graph, node)
1290
+ value = evaluator.values[node.inputs[0]]
1291
+ evaluator.values[node.outputs[0]] = _apply_softmax(value, op.axis)
1292
+
1293
+
1294
+ @register_evaluator("LogSoftmax")
1295
+ def _eval_logsoftmax(evaluator: Evaluator, node: Node) -> None:
1296
+ op = lower_logsoftmax(evaluator.graph, node)
1297
+ value = evaluator.values[node.inputs[0]]
1298
+ evaluator.values[node.outputs[0]] = _apply_logsoftmax(value, op.axis)
1299
+
1300
+
1301
+ @register_evaluator("NegativeLogLikelihoodLoss")
1302
+ def _eval_negative_log_likelihood_loss(
1303
+ evaluator: Evaluator, node: Node
1304
+ ) -> None:
1305
+ op = lower_negative_log_likelihood_loss(evaluator.graph, node)
1306
+ input_value = evaluator.values[op.input0]
1307
+ target_value = evaluator.values[op.target]
1308
+ weight_value = evaluator.values[op.weight] if op.weight is not None else None
1309
+ evaluator.values[op.output] = _apply_negative_log_likelihood_loss(
1310
+ input_value,
1311
+ target_value,
1312
+ weight_value,
1313
+ reduction=op.reduction,
1314
+ ignore_index=op.ignore_index,
1315
+ )
1316
+
1317
+
1318
+ @register_evaluator("SoftmaxCrossEntropyLoss")
1319
+ def _eval_softmax_cross_entropy_loss(
1320
+ evaluator: Evaluator, node: Node
1321
+ ) -> None:
1322
+ op = lower_softmax_cross_entropy_loss(evaluator.graph, node)
1323
+ input_value = evaluator.values[op.input0]
1324
+ target_value = evaluator.values[op.target]
1325
+ weight_value = evaluator.values[op.weight] if op.weight is not None else None
1326
+ loss, log_prob = _apply_softmax_cross_entropy_loss(
1327
+ input_value,
1328
+ target_value,
1329
+ weight_value,
1330
+ reduction=op.reduction,
1331
+ ignore_index=op.ignore_index,
1332
+ return_log_prob=op.log_prob is not None,
1333
+ )
1334
+ evaluator.values[op.output] = loss
1335
+ if op.log_prob is not None and log_prob is not None:
1336
+ evaluator.values[op.log_prob] = log_prob
1337
+
1338
+
1339
+ @register_evaluator("Dropout")
1340
+ def _eval_dropout(evaluator: Evaluator, node: Node) -> None:
1341
+ op = lower_dropout(evaluator.graph, node)
1342
+ evaluator.values[op.output] = evaluator.values[op.input0].copy()
1343
+
1344
+
1345
+ @register_evaluator("Concat")
1346
+ def _eval_concat(evaluator: Evaluator, node: Node) -> None:
1347
+ op = lower_concat(evaluator.graph, node)
1348
+ tensors = [evaluator.values[name] for name in node.inputs]
1349
+ evaluator.values[op.output] = np.concatenate(tensors, axis=op.axis)
1350
+
1351
+
1352
+ @register_evaluator("Transpose")
1353
+ def _eval_transpose(evaluator: Evaluator, node: Node) -> None:
1354
+ op = lower_transpose(evaluator.graph, node)
1355
+ evaluator.values[op.output] = np.transpose(
1356
+ evaluator.values[op.input0], axes=tuple(op.perm)
1357
+ )
1358
+
1359
+
1360
+ @register_evaluator("Unsqueeze")
1361
+ def _eval_unsqueeze(evaluator: Evaluator, node: Node) -> None:
1362
+ op = lower_unsqueeze(evaluator.graph, node)
1363
+ evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1364
+ op.output_shape
1365
+ )
1366
+
1367
+
1368
+ @register_evaluator("Squeeze")
1369
+ def _eval_squeeze(evaluator: Evaluator, node: Node) -> None:
1370
+ op = lower_squeeze(evaluator.graph, node)
1371
+ evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1372
+ op.output_shape
1373
+ )
1374
+
1375
+
1376
+ @register_evaluator("Reshape")
1377
+ def _eval_reshape(evaluator: Evaluator, node: Node) -> None:
1378
+ op = lower_reshape(evaluator.graph, node)
1379
+ evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1380
+ op.output_shape
1381
+ )
1382
+
1383
+
1384
+ @register_evaluator("Flatten")
1385
+ def _eval_flatten(evaluator: Evaluator, node: Node) -> None:
1386
+ op = lower_flatten(evaluator.graph, node)
1387
+ evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1388
+ op.output_shape
1389
+ )
1390
+
1391
+
1392
+ @register_evaluator("ConstantOfShape")
1393
+ def _eval_constant_of_shape(evaluator: Evaluator, node: Node) -> None:
1394
+ op = lower_constant_of_shape(evaluator.graph, node)
1395
+ evaluator.values[op.output] = np.full(
1396
+ op.shape, op.value, dtype=op.dtype.np_dtype
1397
+ )
1398
+
1399
+
1400
+ @register_evaluator("Shape")
1401
+ def _eval_shape(evaluator: Evaluator, node: Node) -> None:
1402
+ op = lower_shape(evaluator.graph, node)
1403
+ evaluator.values[op.output] = np.array(op.values, dtype=np.int64)
1404
+
1405
+
1406
+ @register_evaluator("Size")
1407
+ def _eval_size(evaluator: Evaluator, node: Node) -> None:
1408
+ op = lower_size(evaluator.graph, node)
1409
+ evaluator.values[op.output] = np.array(op.value, dtype=np.int64)
1410
+
1411
+
1412
+ @register_evaluator("Expand")
1413
+ def _eval_expand(evaluator: Evaluator, node: Node) -> None:
1414
+ op = lower_expand(evaluator.graph, node)
1415
+ value = evaluator.values[op.input0]
1416
+ evaluator.values[op.output] = np.broadcast_to(
1417
+ value, op.output_shape
1418
+ ).copy()
1419
+
1420
+
1421
+ @register_evaluator("Range")
1422
+ def _eval_range(evaluator: Evaluator, node: Node) -> None:
1423
+ op = lower_range(evaluator.graph, node)
1424
+ start_value = evaluator.values[op.start].reshape(-1)[0]
1425
+ delta_value = evaluator.values[op.delta].reshape(-1)[0]
1426
+ indices = np.arange(op.length, dtype=op.dtype.np_dtype)
1427
+ output = start_value + indices * delta_value
1428
+ evaluator.values[op.output] = output
1429
+
1430
+
1431
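+ # Split: the lowered op carries per-output sizes along the split axis; the cut
+ # points handed to np.split are their running sums, e.g. split_sizes (2, 3, 5)
+ # yields split points (2, 5) on that axis.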
+ @register_evaluator("Split")
1432
+ def _eval_split(evaluator: Evaluator, node: Node) -> None:
1433
+ op = lower_split(evaluator.graph, node)
1434
+ data = evaluator.values[op.input0]
1435
+ split_points = np.cumsum(op.split_sizes)[:-1]
1436
+ outputs = np.split(data, split_points, axis=op.axis)
1437
+ for output_name, output_value in zip(op.outputs, outputs):
1438
+ evaluator.values[output_name] = output_value
1439
+
1440
+
1441
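+ # Reduce ops: the reduction axes may arrive either as a second input tensor or
+ # as an attribute, depending on the opset. An empty axes list either reduces
+ # over every axis or, when noop_with_empty_axes is set, copies the input
+ # through unchanged.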
+ @register_evaluator("ReduceMean")
1442
+ @register_evaluator("ReduceSum")
1443
+ @register_evaluator("ReduceMax")
1444
+ @register_evaluator("ReduceMin")
1445
+ @register_evaluator("ReduceProd")
1446
+ @register_evaluator("ReduceL1")
1447
+ @register_evaluator("ReduceL2")
1448
+ @register_evaluator("ReduceLogSum")
1449
+ @register_evaluator("ReduceLogSumExp")
1450
+ @register_evaluator("ReduceSumSquare")
1451
+ def _eval_reduce(evaluator: Evaluator, node: Node) -> None:
1452
+ if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
1453
+ raise UnsupportedOpError(
1454
+ f"{node.op_type} must have 1 or 2 inputs and 1 output"
1455
+ )
1456
+ op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1457
+ output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1458
+ if op_dtype != output_dtype:
1459
+ raise UnsupportedOpError(
1460
+ f"{node.op_type} expects matching input/output dtypes, "
1461
+ f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
1462
+ )
1463
+ if (
1464
+ node.op_type in REDUCE_OUTPUTS_FLOAT_ONLY
1465
+ and not op_dtype.is_float
1466
+ ):
1467
+ raise UnsupportedOpError(
1468
+ f"{node.op_type} supports float16, float, and double inputs only"
1469
+ )
1470
+ value = evaluator.values[node.inputs[0]]
1471
+ input_shape = value.shape
1472
+ if len(node.inputs) > 1 and node.inputs[1]:
1473
+ axes_value = evaluator.values[node.inputs[1]]
1474
+ if axes_value.dtype.type not in {np.int32, np.int64}:
1475
+ raise UnsupportedOpError(
1476
+ f"{node.op_type} axes input must be int64 or int32"
1477
+ )
1478
+ axes = tuple(int(axis) for axis in axes_value.ravel())
1479
+ noop_with_empty_axes = bool(int(node.attrs.get("noop_with_empty_axes", 0)))
1480
+ if not axes:
1481
+ if noop_with_empty_axes:
1482
+ evaluator.values[node.outputs[0]] = value.copy()
1483
+ return
1484
+ axes = tuple(range(len(input_shape)))
1485
+ axes = normalize_reduce_axes(axes, input_shape, node)
1486
+ else:
1487
+ axes_spec, noop = resolve_reduce_axes(evaluator.graph, node, input_shape)
1488
+ if noop:
1489
+ evaluator.values[node.outputs[0]] = value.copy()
1490
+ return
1491
+ if axes_spec is None or axes_spec.axes is None:
1492
+ raise UnsupportedOpError(
1493
+ f"{node.op_type} axes input must be constant for evaluator"
1494
+ )
1495
+ axes = axes_spec.axes
1496
+ keepdims = bool(int(node.attrs.get("keepdims", 1)))
1497
+ reduce_kind = REDUCE_KIND_BY_OP[node.op_type]
1498
+ if reduce_kind == "sum":
1499
+ result = np.sum(value, axis=axes, keepdims=keepdims)
1500
+ elif reduce_kind == "mean":
1501
+ result = np.mean(value, axis=axes, keepdims=keepdims)
1502
+ elif reduce_kind == "max":
1503
+ result = np.max(value, axis=axes, keepdims=keepdims)
1504
+ elif reduce_kind == "min":
1505
+ result = np.min(value, axis=axes, keepdims=keepdims)
1506
+ elif reduce_kind == "prod":
1507
+ result = np.prod(value, axis=axes, keepdims=keepdims)
1508
+ elif reduce_kind == "l1":
1509
+ result = np.sum(np.abs(value), axis=axes, keepdims=keepdims)
1510
+ elif reduce_kind == "l2":
1511
+ result = np.sqrt(np.sum(value * value, axis=axes, keepdims=keepdims))
1512
+ elif reduce_kind == "logsum":
1513
+ result = np.log(np.sum(value, axis=axes, keepdims=keepdims))
1514
+ elif reduce_kind == "logsumexp":
1515
+ result = np.log(np.sum(np.exp(value), axis=axes, keepdims=keepdims))
1516
+ elif reduce_kind == "sumsquare":
1517
+ result = np.sum(value * value, axis=axes, keepdims=keepdims)
1518
+ else:
1519
+ raise UnsupportedOpError(f"Unsupported reduce kind {reduce_kind}")
1520
+ evaluator.values[node.outputs[0]] = result
1521
+
1522
+
1523
+ @register_evaluator("ArgMax")
1524
+ @register_evaluator("ArgMin")
1525
+ def _eval_arg_reduce(evaluator: Evaluator, node: Node) -> None:
1526
+ op = lower_arg_reduce(evaluator.graph, node)
1527
+ value = evaluator.values[op.input0]
1528
+ if op.select_last_index:
1529
+ flipped = np.flip(value, axis=op.axis)
1530
+ if op.reduce_kind == "max":
1531
+ indices = np.argmax(flipped, axis=op.axis)
1532
+ elif op.reduce_kind == "min":
1533
+ indices = np.argmin(flipped, axis=op.axis)
1534
+ else:
1535
+ raise UnsupportedOpError(
1536
+ f"Unsupported arg reduce kind {op.reduce_kind}"
1537
+ )
1538
+ indices = value.shape[op.axis] - 1 - indices
1539
+ else:
1540
+ if op.reduce_kind == "max":
1541
+ indices = np.argmax(value, axis=op.axis)
1542
+ elif op.reduce_kind == "min":
1543
+ indices = np.argmin(value, axis=op.axis)
1544
+ else:
1545
+ raise UnsupportedOpError(
1546
+ f"Unsupported arg reduce kind {op.reduce_kind}"
1547
+ )
1548
+ if op.keepdims:
1549
+ indices = np.expand_dims(indices, axis=op.axis)
1550
+ evaluator.values[op.output] = indices.astype(op.output_dtype.np_dtype)
1551
+
1552
+
1553
+ def _eval_binary_unary(evaluator: Evaluator, node: Node) -> None:
+     if node.op_type == "BitShift":
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
+         direction_attr = node.attrs.get("direction", "LEFT")
+         if isinstance(direction_attr, bytes):
+             direction = direction_attr.decode()
+         else:
+             direction = str(direction_attr)
+         if direction not in {"LEFT", "RIGHT"}:
+             raise UnsupportedOpError(
+                 "BitShift direction must be LEFT or RIGHT"
+             )
+         op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
+         if not op_dtype.is_integer:
+             raise UnsupportedOpError("BitShift expects integer inputs")
+         function = (
+             ScalarFunction.BITWISE_LEFT_SHIFT
+             if direction == "LEFT"
+             else ScalarFunction.BITWISE_RIGHT_SHIFT
+         )
+         op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+         if op_spec is None:
+             raise UnsupportedOpError("Unsupported op BitShift")
+         left = evaluator.values[node.inputs[0]]
+         right = evaluator.values[node.inputs[1]]
+         evaluator.values[node.outputs[0]] = apply_binary_op(
+             op_spec, left, right
+         )
+         return
+     if node.op_type == "Mod":
+         fmod = int(node.attrs.get("fmod", 0))
+         if fmod not in {0, 1}:
+             raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
+         function = (
+             ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
+         )
+     else:
+         try:
+             function = ScalarFunction.from_onnx_op(node.op_type)
+         except ScalarFunctionError as exc:
+             raise UnsupportedOpError(
+                 f"Unsupported op {node.op_type}"
+             ) from exc
+     validate_unary_attrs(node.op_type, node.attrs)
+     if function in COMPARE_FUNCTIONS:
+         input_dtype = node_dtype(evaluator.graph, node, *node.inputs)
+         output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
+         if output_dtype != ScalarType.BOOL:
+             raise UnsupportedOpError(
+                 f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
+             )
+         op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
+         if op_spec is None:
+             raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have 2 inputs and 1 output"
+             )
+         left = evaluator.values[node.inputs[0]]
+         right = evaluator.values[node.inputs[1]]
+         evaluator.values[node.outputs[0]] = apply_binary_op(
+             op_spec, left, right
+         )
+         return
+     op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
+     op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+     unary_symbol = unary_op_symbol(function, dtype=op_dtype)
+     if op_spec is None and unary_symbol is None:
+         raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+     if op_spec is not None:
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have 2 inputs and 1 output"
+             )
+         left = evaluator.values[node.inputs[0]]
+         right = evaluator.values[node.inputs[1]]
+         evaluator.values[node.outputs[0]] = apply_binary_op(
+             op_spec, left, right
+         )
+         return
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} must have 1 input and 1 output"
+         )
+     value = evaluator.values[node.inputs[0]]
+     evaluator.values[node.outputs[0]] = apply_unary_op(
+         function, value, dtype=op_dtype
+     )
+
+
+ def _apply_matmul(left: np.ndarray, right: np.ndarray) -> np.ndarray:
+     if left.ndim < 1 or right.ndim < 1:
+         raise UnsupportedOpError(
+             "MatMul inputs must be at least 1D, "
+             f"got {left.shape} x {right.shape}"
+         )
+     left_dim = left.shape[-1]
+     right_dim = right.shape[0] if right.ndim == 1 else right.shape[-2]
+     if left_dim != right_dim:
+         raise ShapeInferenceError(
+             "MatMul inner dimensions must match, "
+             f"got {left_dim} and {right_dim}"
+         )
+     left_batch = left.shape[:-2] if left.ndim > 1 else ()
+     right_batch = right.shape[:-2] if right.ndim > 1 else ()
+     if not _matmul_batch_broadcastable(left_batch, right_batch):
+         raise ShapeInferenceError(
+             "MatMul batch dimensions must be broadcastable, "
+             f"got {left_batch} x {right_batch}"
+         )
+     return np.matmul(left, right)
+
+
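+ # Batch dimensions follow numpy broadcasting: the shapes are right-aligned,
+ # padded with 1s, and each pair of dims must match or contain a 1. For example
+ # (2, 1, 4) and (3, 4) broadcast, while (2, 4) and (3, 4) do not.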
+ def _matmul_batch_broadcastable(
+     left: tuple[int, ...], right: tuple[int, ...]
+ ) -> bool:
+     max_rank = max(len(left), len(right))
+     left_padded = (1,) * (max_rank - len(left)) + left
+     right_padded = (1,) * (max_rank - len(right)) + right
+     for left_dim, right_dim in zip(left_padded, right_padded):
+         if left_dim == right_dim or left_dim == 1 or right_dim == 1:
+             continue
+         return False
+     return True
+
+
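+ # Softmax and log-softmax subtract the per-axis maximum before exponentiating
+ # so that large logits cannot overflow; the shift cancels in the ratio, so the
+ # result is unchanged. E.g. softmax([1, 2, 3]) is roughly [0.090, 0.245, 0.665].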
+ def _apply_softmax(values: np.ndarray, axis: int) -> np.ndarray:
+     max_values = np.max(values, axis=axis, keepdims=True)
+     exp_values = np.exp(values - max_values)
+     sum_values = np.sum(exp_values, axis=axis, keepdims=True)
+     return exp_values / sum_values
+
+
+ def _apply_logsoftmax(values: np.ndarray, axis: int) -> np.ndarray:
+     max_values = np.max(values, axis=axis, keepdims=True)
+     shifted = values - max_values
+     logsum = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
+     return shifted - logsum
+
+
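+ # NegativeLogLikelihoodLoss: the input is viewed as (N, C, D), the target class
+ # score is gathered and negated per (N, D) entry, positions equal to
+ # ignore_index contribute zero, and optional class weights rescale the loss
+ # before the mean/sum reduction.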
+ def _apply_negative_log_likelihood_loss(
+     values: np.ndarray,
+     target: np.ndarray,
+     weight: np.ndarray | None,
+     *,
+     reduction: str,
+     ignore_index: int,
+ ) -> np.ndarray:
+     input_shape = values.shape
+     if len(input_shape) < 2:
+         raise UnsupportedOpError(
+             "NegativeLogLikelihoodLoss input must be at least 2D"
+         )
+     target_shape = target.shape
+     if input_shape[0] != target_shape[0]:
+         raise ShapeInferenceError(
+             "NegativeLogLikelihoodLoss target batch dimension must match input"
+         )
+     if input_shape[2:] != target_shape[1:]:
+         raise ShapeInferenceError(
+             "NegativeLogLikelihoodLoss target spatial dimensions must match input"
+         )
+     n = input_shape[0]
+     c = input_shape[1]
+     if weight is not None:
+         gather_weight = np.take(weight, target.astype(np.int32), mode="clip")
+         if ignore_index is not None:
+             gather_weight = np.where(target == ignore_index, 0, gather_weight).astype(
+                 dtype=values.dtype
+             )
+     elif ignore_index != -1:
+         gather_weight = np.where(target == ignore_index, 0, 1).astype(
+             dtype=values.dtype
+         )
+     else:
+         gather_weight = None
+     if len(input_shape) != 3:
+         values = values.reshape((n, c, -1))
+         target = target.reshape((n, -1))
+     d = values.shape[2]
+     loss = np.zeros((n, d), dtype=values.dtype)
+     for i in range(n):
+         for d_index in range(d):
+             if target[i][d_index] != ignore_index:
+                 loss[i][d_index] = -values[i][target[i][d_index]][d_index]
+     if len(input_shape) != 3:
+         loss = loss.reshape(target_shape)
+     if gather_weight is not None:
+         loss = gather_weight * loss
+         if reduction == "mean":
+             weight_sum = gather_weight.sum()
+             if weight_sum == 0:
+                 return np.array(0, dtype=values.dtype)
+             loss = loss.sum() / weight_sum
+             return loss.astype(values.dtype)
+     if reduction == "mean":
+         loss = np.mean(loss)
+     elif reduction == "sum":
+         loss = np.sum(loss)
+     return loss.astype(values.dtype)
+
+
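+ # SoftmaxCrossEntropyLoss is the negative log-likelihood of the log-softmax
+ # over the class axis (axis=1); the per-element gather, optional class weights,
+ # and reduction then follow the same pattern as the NLL helper above, with the
+ # optional second output returning the log-probabilities themselves.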
+ def _apply_softmax_cross_entropy_loss(
+     values: np.ndarray,
+     target: np.ndarray,
+     weight: np.ndarray | None,
+     *,
+     reduction: str,
+     ignore_index: int | None,
+     return_log_prob: bool,
+ ) -> tuple[np.ndarray, np.ndarray | None]:
+     input_shape = values.shape
+     if len(input_shape) < 2:
+         raise UnsupportedOpError(
+             "SoftmaxCrossEntropyLoss input must be at least 2D"
+         )
+     target_shape = target.shape
+     if input_shape[0] != target_shape[0]:
+         raise ShapeInferenceError(
+             "SoftmaxCrossEntropyLoss target batch dimension must match input"
+         )
+     if input_shape[2:] != target_shape[1:]:
+         raise ShapeInferenceError(
+             "SoftmaxCrossEntropyLoss target spatial dimensions must match input"
+         )
+     log_prob = _apply_logsoftmax(values, axis=1)
+     log_prob_output = log_prob if return_log_prob else None
+     if weight is not None:
+         gather_weight = np.take(weight, target.astype(np.int32), mode="clip")
+         if ignore_index is not None:
+             gather_weight = np.where(target == ignore_index, 0, gather_weight).astype(
+                 dtype=values.dtype
+             )
+     elif ignore_index is not None:
+         gather_weight = np.where(target == ignore_index, 0, 1).astype(
+             dtype=values.dtype
+         )
+     else:
+         gather_weight = None
+     n = input_shape[0]
+     c = input_shape[1]
+     if len(input_shape) != 3:
+         log_prob = log_prob.reshape((n, c, -1))
+         target = target.reshape((n, -1))
+     d = log_prob.shape[2]
+     loss = np.zeros((n, d), dtype=values.dtype)
+     for i in range(n):
+         for d_index in range(d):
+             if ignore_index is None or target[i][d_index] != ignore_index:
+                 loss[i][d_index] = -log_prob[i][target[i][d_index]][d_index]
+     if len(input_shape) != 3:
+         loss = loss.reshape(target_shape)
+     if gather_weight is not None:
+         loss = gather_weight * loss
+         if reduction == "mean":
+             loss = loss.sum() / gather_weight.sum()
+             loss = loss.astype(values.dtype)
+             if return_log_prob and log_prob_output is not None:
+                 # Return the log-probabilities in their original shape, not the
+                 # (n, c, -1) view used for the loss gather.
+                 return loss, log_prob_output.astype(values.dtype)
+             return loss, None
+     if reduction == "mean":
+         loss = np.mean(loss)
+     elif reduction == "sum":
+         loss = np.sum(loss)
+     loss = loss.astype(values.dtype)
+     if return_log_prob and log_prob_output is not None:
+         return loss, log_prob_output.astype(values.dtype)
+     return loss, None
+
+
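+ # Scaled dot-product attention: softmax(Q @ K^T * scale + bias) @ V, where the
+ # bias folds in the attention mask, per-batch valid-length masking, and the
+ # causal mask. 3-D inputs are first reshaped to (batch, heads, seq, head_size),
+ # and grouped-query K/V heads are repeated to match the query head count.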
+ def _apply_attention(
+     spec,
+     query: np.ndarray,
+     key: np.ndarray,
+     value: np.ndarray,
+     attn_mask: np.ndarray | None,
+     past_key: np.ndarray | None,
+     past_value: np.ndarray | None,
+     nonpad_kv_seqlen: np.ndarray | None,
+ ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None]:
+     if spec.q_rank == 3:
+         query_4d = query.reshape(
+             spec.batch, spec.q_seq, spec.q_heads, spec.qk_head_size
+         ).transpose(0, 2, 1, 3)
+         key_4d = key.reshape(
+             spec.batch, spec.kv_seq, spec.kv_heads, spec.qk_head_size
+         ).transpose(0, 2, 1, 3)
+         value_4d = value.reshape(
+             spec.batch, spec.kv_seq, spec.kv_heads, spec.v_head_size
+         ).transpose(0, 2, 1, 3)
+     else:
+         query_4d = query
+         key_4d = key
+         value_4d = value
+     if past_key is not None and past_value is not None:
+         key_total = np.concatenate([past_key, key_4d], axis=2)
+         value_total = np.concatenate([past_value, value_4d], axis=2)
+     else:
+         key_total = key_4d
+         value_total = value_4d
+     if spec.head_group_size > 1:
+         key_total_expanded = np.repeat(key_total, spec.head_group_size, axis=1)
+         value_total_expanded = np.repeat(
+             value_total, spec.head_group_size, axis=1
+         )
+     else:
+         key_total_expanded = key_total
+         value_total_expanded = value_total
+     k_transpose = np.transpose(key_total_expanded, (0, 1, 3, 2))
+     scores = np.matmul(query_4d, k_transpose) * spec.scale
+     bias = np.zeros_like(scores)
+     if spec.has_attn_mask and attn_mask is not None:
+         if spec.mask_is_bool:
+             bias_mask = np.where(attn_mask, 0.0, -np.inf)
+         else:
+             bias_mask = attn_mask.astype(scores.dtype)
+         if spec.mask_rank == 2:
+             bias_mask = bias_mask[None, None, ...]
+         elif spec.mask_rank == 3:
+             bias_mask = bias_mask[:, None, ...]
+         bias_mask = np.broadcast_to(
+             bias_mask, (spec.batch, spec.q_heads, spec.q_seq, spec.mask_kv_seq)
+         )
+         if spec.mask_kv_seq < spec.total_seq:
+             pad_width = spec.total_seq - spec.mask_kv_seq
+             bias_mask = np.pad(
+                 bias_mask,
+                 ((0, 0), (0, 0), (0, 0), (0, pad_width)),
+                 constant_values=-np.inf,
+             )
+         bias = bias + bias_mask
+     if spec.has_nonpad and nonpad_kv_seqlen is not None:
+         kv_range = np.arange(spec.total_seq)[None, None, None, :]
+         valid = kv_range < nonpad_kv_seqlen[:, None, None, None]
+         bias = bias + np.where(valid, 0.0, -np.inf)
+     if spec.is_causal:
+         kv_range = np.arange(spec.total_seq)[None, :]
+         q_range = np.arange(spec.q_seq)[:, None] + spec.past_seq
+         causal_mask = kv_range > q_range
+         bias = bias + np.where(causal_mask, -np.inf, 0.0)[None, None, :, :]
+     scores_with_bias = scores + bias
+     if spec.softcap != 0.0:
+         scores_softcap = spec.softcap * np.tanh(scores_with_bias / spec.softcap)
+     else:
+         scores_softcap = scores_with_bias
+     max_scores = np.max(scores_softcap, axis=-1, keepdims=True)
+     weights = np.exp(scores_softcap - max_scores)
+     weights /= np.sum(weights, axis=-1, keepdims=True)
+     output = np.matmul(weights, value_total_expanded)
+     if spec.q_rank == 3:
+         output = output.transpose(0, 2, 1, 3).reshape(
+             spec.batch, spec.q_seq, spec.q_heads * spec.v_head_size
+         )
+     qk_output = None
+     if spec.qk_matmul_output_mode == 0:
+         qk_output = scores
+     elif spec.qk_matmul_output_mode == 1:
+         qk_output = scores_with_bias
+     elif spec.qk_matmul_output_mode == 2:
+         qk_output = scores_softcap
+     else:
+         qk_output = weights
+     return output, key_total, value_total, qk_output
+
+
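+ # Direct convolution reference: for every output element, accumulate
+ # data[n, ic, out*stride + k*dilation - pad] * weights[oc, ic, k] over the input
+ # channels of the group and all kernel offsets, skipping taps that fall into
+ # the padding region.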
+ def _apply_conv(spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None) -> np.ndarray:
+     output = np.zeros(
+         (spec.batch, spec.out_channels, *spec.out_spatial),
+         dtype=data.dtype,
+     )
+     pad_begin = spec.pads[: spec.spatial_rank]
+     group_in_channels = spec.in_channels // spec.group
+     group_out_channels = spec.out_channels // spec.group
+     for n in range(spec.batch):
+         for g in range(spec.group):
+             oc_base = g * group_out_channels
+             ic_base = g * group_in_channels
+             for oc in range(group_out_channels):
+                 oc_global = oc_base + oc
+                 base = bias[oc_global] if bias is not None else 0.0
+                 for out_index in np.ndindex(*spec.out_spatial):
+                     acc = base
+                     for ic in range(group_in_channels):
+                         ic_global = ic_base + ic
+                         for kernel_index in np.ndindex(*spec.kernel_shape):
+                             in_index = []
+                             valid = True
+                             for (
+                                 out_dim,
+                                 kernel_dim,
+                                 stride,
+                                 dilation,
+                                 pad,
+                                 in_size,
+                             ) in zip(
+                                 out_index,
+                                 kernel_index,
+                                 spec.strides,
+                                 spec.dilations,
+                                 pad_begin,
+                                 spec.in_spatial,
+                             ):
+                                 in_dim = out_dim * stride + kernel_dim * dilation - pad
+                                 if in_dim < 0 or in_dim >= in_size:
+                                     valid = False
+                                     break
+                                 in_index.append(in_dim)
+                             if not valid:
+                                 continue
+                             acc += data[(n, ic_global, *in_index)] * weights[
+                                 (oc_global, ic, *kernel_index)
+                             ]
+                     output[(n, oc_global, *out_index)] = acc
+     return output
+
+
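+ # LRN normalizes each element by a window of neighbouring channels:
+ # out[n, c, ...] = x[n, c, ...] / (bias + alpha / size * sum_i x[n, i, ...]^2) ** beta,
+ # with i running over the size-wide channel window centred on c.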
+ def _apply_lrn(spec, data: np.ndarray) -> np.ndarray:
+     output = np.empty_like(data)
+     spatial_shape = spec.shape[2:]
+     spatial_indices = [()]
+     if spatial_shape:
+         spatial_indices = list(np.ndindex(*spatial_shape))
+     for n in range(spec.shape[0]):
+         for c in range(spec.channels):
+             start = max(0, c - spec.half)
+             end = min(spec.channels - 1, c + spec.half)
+             for index in spatial_indices:
+                 sum_val = 0.0
+                 for i in range(start, end + 1):
+                     value = data[(n, i, *index)]
+                     sum_val += value * value
+                 scale = spec.bias + (spec.alpha / spec.size) * sum_val
+                 output[(n, c, *index)] = data[(n, c, *index)] / math.pow(
+                     scale, spec.beta
+                 )
+     return output
+
+
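+ # AveragePool divides by the number of taps that actually land inside the
+ # input, unless count_include_pad is set, in which case padded positions are
+ # counted as well (but contribute zero to the sum).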
+ def _apply_average_pool(op, data: np.ndarray) -> np.ndarray:
+     output = np.zeros((op.batch, op.channels, op.out_h, op.out_w), dtype=data.dtype)
+     for n in range(op.batch):
+         for c in range(op.channels):
+             for oh in range(op.out_h):
+                 for ow in range(op.out_w):
+                     acc = 0.0
+                     count = 0
+                     for kh in range(op.kernel_h):
+                         ih = oh * op.stride_h + kh - op.pad_top
+                         if ih < 0 or ih >= op.in_h:
+                             if op.count_include_pad:
+                                 count += op.kernel_w
+                             continue
+                         for kw in range(op.kernel_w):
+                             iw = ow * op.stride_w + kw - op.pad_left
+                             if iw < 0 or iw >= op.in_w:
+                                 if op.count_include_pad:
+                                     count += 1
+                                 continue
+                             acc += data[n, c, ih, iw]
+                             count += 1
+                     output[n, c, oh, ow] = 0.0 if count == 0 else acc / float(count)
+     return output
+
+
+ def _maxpool_min_value(dtype: np.dtype) -> float | int:
+     if np.issubdtype(dtype, np.floating):
+         return -np.inf
+     if np.issubdtype(dtype, np.integer):
+         return np.iinfo(dtype).min
+     raise UnsupportedOpError("MaxPool supports numeric inputs only")
+
+
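+ # MaxPool's optional Indices output is linearised over (n, c, spatial): the
+ # spatial offset is row-major when storage_order == 0 and column-major
+ # otherwise, then prefixed by the flattened (n, c) position.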
+ def _apply_maxpool(
+     spec, data: np.ndarray, *, return_indices: bool = False
+ ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
+     min_value = _maxpool_min_value(data.dtype)
+     output = np.full(
+         (spec.batch, spec.channels, *spec.out_spatial),
+         min_value,
+         dtype=data.dtype,
+     )
+     indices = (
+         np.zeros((spec.batch, spec.channels, *spec.out_spatial), dtype=np.int64)
+         if return_indices
+         else None
+     )
+     pad_begin = spec.pads[: spec.spatial_rank]
+     for n in range(spec.batch):
+         for c in range(spec.channels):
+             for out_index in np.ndindex(*spec.out_spatial):
+                 max_value = min_value
+                 max_index = 0
+                 has_value = False
+                 for kernel_index in np.ndindex(*spec.kernel_shape):
+                     in_index = []
+                     valid = True
+                     for out_dim, kernel_dim, stride, dilation, pad in zip(
+                         out_index,
+                         kernel_index,
+                         spec.strides,
+                         spec.dilations,
+                         pad_begin,
+                     ):
+                         idx = out_dim * stride + kernel_dim * dilation - pad
+                         if idx < 0 or idx >= spec.in_spatial[len(in_index)]:
+                             valid = False
+                             break
+                         in_index.append(idx)
+                     if not valid:
+                         continue
+                     value = data[(n, c, *in_index)]
+                     if value > max_value or not has_value:
+                         max_value = value
+                         has_value = True
+                         if return_indices:
+                             linear_index = n * spec.channels + c
+                             if spec.storage_order == 0:
+                                 for idx, size in zip(in_index, spec.in_spatial):
+                                     linear_index = linear_index * size + idx
+                             else:
+                                 spatial_index = 0
+                                 spatial_stride = 1
+                                 for idx, size in zip(in_index, spec.in_spatial):
+                                     spatial_index += idx * spatial_stride
+                                     spatial_stride *= size
+                                 linear_index = linear_index * spatial_stride + spatial_index
+                             max_index = linear_index
+                 output[(n, c, *out_index)] = max_value
+                 if return_indices and indices is not None:
+                     indices[(n, c, *out_index)] = max_index
+     if return_indices:
+         if indices is None:
+             raise RuntimeError("MaxPool indices were not computed")
+         return output, indices
+     return output
+
+
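+ # LSTM reference: gates arrive in ONNX (i, o, f, c) order, the bias is the sum
+ # of the W and R bias halves, and peephole weights (p_i, p_o, p_f) feed the
+ # input/output/forget gates; sequence_lens freezes h/c once a batch entry runs
+ # past its length.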
+ def _apply_lstm(
+     spec,
+     x: np.ndarray,
+     w: np.ndarray,
+     r: np.ndarray,
+     b: np.ndarray | None,
+     sequence_lens: np.ndarray | None,
+     initial_h: np.ndarray | None,
+     initial_c: np.ndarray | None,
+     p: np.ndarray | None,
+ ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
+     if spec.layout == 1:
+         x = np.swapaxes(x, 0, 1)
+     seq_length = spec.seq_length
+     batch_size = spec.batch_size
+     hidden_size = spec.hidden_size
+     num_directions = spec.num_directions
+     if sequence_lens is None:
+         sequence_lens = np.full((batch_size,), seq_length, dtype=np.int64)
+     else:
+         sequence_lens = sequence_lens.astype(np.int64, copy=False)
+     if b is None:
+         b = np.zeros((num_directions, 8 * hidden_size), dtype=x.dtype)
+     if p is None:
+         p = np.zeros((num_directions, 3 * hidden_size), dtype=x.dtype)
+     if initial_h is None:
+         initial_h = np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
+     if initial_c is None:
+         initial_c = np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
+     if spec.layout == 1:
+         initial_h = np.swapaxes(initial_h, 0, 1)
+         initial_c = np.swapaxes(initial_c, 0, 1)
+     output_y = None
+     if spec.output_y is not None:
+         output_y = np.zeros(
+             (seq_length, num_directions, batch_size, hidden_size), dtype=x.dtype
+         )
+     output_y_h = (
+         np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
+         if spec.output_y_h is not None
+         else None
+     )
+     output_y_c = (
+         np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
+         if spec.output_y_c is not None
+         else None
+     )
+     directions = (
+         ("forward", "reverse")
+         if spec.direction == "bidirectional"
+         else (spec.direction,)
+     )
+     for dir_index, dir_kind in enumerate(directions):
+         w_dir = w[dir_index]
+         r_dir = r[dir_index]
+         b_dir = b[dir_index]
+         bias = b_dir[: 4 * hidden_size] + b_dir[4 * hidden_size :]
+         p_dir = p[dir_index]
+         p_i = p_dir[:hidden_size]
+         p_o = p_dir[hidden_size : 2 * hidden_size]
+         p_f = p_dir[2 * hidden_size :]
+         h_prev = initial_h[dir_index].copy()
+         c_prev = initial_c[dir_index].copy()
+         act_offset = dir_index * 3
+         act_f = spec.activation_kinds[act_offset]
+         act_g = spec.activation_kinds[act_offset + 1]
+         act_h = spec.activation_kinds[act_offset + 2]
+         alpha_f = spec.activation_alphas[act_offset]
+         alpha_g = spec.activation_alphas[act_offset + 1]
+         alpha_h = spec.activation_alphas[act_offset + 2]
+         beta_f = spec.activation_betas[act_offset]
+         beta_g = spec.activation_betas[act_offset + 1]
+         beta_h = spec.activation_betas[act_offset + 2]
+         for step in range(seq_length):
+             t_index = step if dir_kind == "forward" else seq_length - 1 - step
+             x_t = x[t_index]
+             gates = x_t @ w_dir.T + h_prev @ r_dir.T + bias
+             if spec.clip is not None and spec.clip > 0:
+                 gates = np.clip(gates, -spec.clip, spec.clip)
+             i, o, f, c = np.split(gates, 4, axis=1)
+             i = _apply_lstm_activation(act_f, i + p_i * c_prev, alpha_f, beta_f)
+             if spec.input_forget:
+                 f = 1 - i
+             else:
+                 f = _apply_lstm_activation(
+                     act_f, f + p_f * c_prev, alpha_f, beta_f
+                 )
+             c_tilde = _apply_lstm_activation(act_g, c, alpha_g, beta_g)
+             c_new = f * c_prev + i * c_tilde
+             o = _apply_lstm_activation(act_f, o + p_o * c_new, alpha_f, beta_f)
+             h_new = o * _apply_lstm_activation(act_h, c_new, alpha_h, beta_h)
+             active_mask = step < sequence_lens
+             if not np.all(active_mask):
+                 h_new = np.where(active_mask[:, None], h_new, h_prev)
+                 c_new = np.where(active_mask[:, None], c_new, c_prev)
+                 if output_y is not None:
+                     output_y[step, dir_index] = np.where(
+                         active_mask[:, None], h_new, 0
+                     )
+             else:
+                 if output_y is not None:
+                     output_y[step, dir_index] = h_new
+             h_prev = h_new
+             c_prev = c_new
+         if output_y_h is not None:
+             output_y_h[dir_index] = h_prev
+         if output_y_c is not None:
+             output_y_c[dir_index] = c_prev
+     if spec.layout == 1:
+         if output_y is not None:
+             output_y = np.transpose(output_y, (2, 0, 1, 3))
+         if output_y_h is not None:
+             output_y_h = np.swapaxes(output_y_h, 0, 1)
+         if output_y_c is not None:
+             output_y_c = np.swapaxes(output_y_c, 0, 1)
+     return output_y, output_y_h, output_y_c