emx-onnx-cgen 0.3.8-py3-none-any.whl → 0.4.2.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of emx-onnx-cgen might be problematic.

Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
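
Among the removals, emx_onnx_cgen/runtime/evaluator.py (item 134) drops the NumPy-based reference evaluator whose diff is excerpted below. That module dispatched each ONNX node to a per-op handler registered through a decorator (register_evaluator / _EVAL_REGISTRY). For orientation only, here is a minimal, self-contained sketch of that dispatch pattern — simplified types and hypothetical names, not the package's actual API:

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field

import numpy as np


@dataclass
class Node:
    """Hypothetical, simplified stand-in for the package's IR node type."""
    op_type: str
    inputs: list[str]
    outputs: list[str]


@dataclass
class Evaluator:
    values: dict[str, np.ndarray] = field(default_factory=dict)

    def run(self, nodes: list[Node], feeds: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        self.values.update(feeds)
        for node in nodes:
            _REGISTRY[node.op_type](self, node)  # dispatch on op type
        return self.values


Handler = Callable[[Evaluator, Node], None]
_REGISTRY: dict[str, Handler] = {}


def register(op_type: str) -> Callable[[Handler], Handler]:
    """Decorator mapping an ONNX op type to its evaluation handler."""
    def decorator(func: Handler) -> Handler:
        _REGISTRY[op_type] = func
        return func
    return decorator


@register("Relu")
def _eval_relu(ev: Evaluator, node: Node) -> None:
    ev.values[node.outputs[0]] = np.maximum(ev.values[node.inputs[0]], 0.0)


# Usage: evaluate a one-node graph.
out = Evaluator().run([Node("Relu", ["x"], ["y"])], {"x": np.array([-1.0, 2.0])})
print(out["y"])  # [0. 2.]
```

The removed module followed this shape at a much larger scale: it registered handlers for a broad set of ONNX ops and routed generic unary/binary ops through shared fallbacks, as the excerpt below shows.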
emx_onnx_cgen/runtime/evaluator.py
@@ -1,2955 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections.abc import Callable, Mapping
4
- import math
5
-
6
- import numpy as np
7
-
8
- from shared.scalar_types import ScalarType
9
- from ..errors import ShapeInferenceError, UnsupportedOpError
10
- from ..ir.context import GraphContext
11
- from ..ir.model import Graph, Node
12
- from ..ir.op_context import OpContext
13
- from ..lowering.attention import resolve_attention_spec
14
- from ..lowering.average_pool import lower_average_pool, lower_global_average_pool
15
- from ..lowering.adagrad import lower_adagrad
16
- from ..lowering.batch_normalization import lower_batch_normalization
17
- from ..lowering.concat import lower_concat
18
- from ..lowering.constant_of_shape import lower_constant_of_shape
19
- from ..lowering.conv import resolve_conv_spec
20
- from ..lowering.conv_transpose import resolve_conv_transpose_spec
21
- from ..lowering.dropout import lower_dropout
22
- from ..lowering.cumsum import lower_cumsum
23
- from ..lowering.einsum import lower_einsum
24
- from ..lowering.flatten import lower_flatten
25
- from ..lowering.gemm import resolve_gemm_spec
26
- from ..lowering.logsoftmax import lower_logsoftmax
27
- from ..lowering.hardmax import lower_hardmax
28
- from ..lowering.lp_normalization import lower_lp_normalization
29
- from ..lowering.lp_pool import lower_lp_pool
30
- from ..lowering.grid_sample import lower_grid_sample
31
- from ..lowering.instance_normalization import lower_instance_normalization
32
- from ..lowering.group_normalization import lower_group_normalization
33
- from ..lowering.layer_normalization import lower_layer_normalization
34
- from ..lowering.non_max_suppression import lower_non_max_suppression
35
- from ..lowering.mean_variance_normalization import (
36
- lower_mean_variance_normalization,
37
- )
38
- from ..lowering.global_max_pool import lower_global_max_pool
39
- from ..lowering.negative_log_likelihood_loss import (
40
- lower_negative_log_likelihood_loss,
41
- )
42
- from ..lowering.nonzero import lower_nonzero
43
- from ..lowering.pad import lower_pad
44
- from ..lowering.expand import lower_expand
45
- from ..lowering.range import lower_range
46
- from ..lowering.one_hot import lower_onehot
47
- from ..lowering.split import lower_split
48
- from ..lowering.softmax_cross_entropy_loss import (
49
- lower_softmax_cross_entropy_loss,
50
- )
51
- from ..lowering.arg_reduce import lower_arg_reduce
52
- from ..lowering.topk import lower_topk
53
- from ..lowering.lstm import ACTIVATION_KIND_BY_NAME, resolve_lstm_spec
54
- from ..lowering.lrn import resolve_lrn_spec
55
- from ..lowering.matmul import lower_matmul
56
- from ..lowering.qlinear_matmul import lower_qlinear_matmul
57
- from ..lowering.maxpool import resolve_maxpool_spec
58
- from ..lowering.reduce import (
59
- REDUCE_KIND_BY_OP,
60
- REDUCE_OUTPUTS_FLOAT_ONLY,
61
- normalize_reduce_axes,
62
- resolve_reduce_axes,
63
- )
64
- from ..lowering.reshape import lower_reshape
65
- from ..lowering.scatter_nd import lower_scatternd
66
- from ..lowering.tensor_scatter import lower_tensor_scatter
67
- from ..lowering.slice import _normalize_slices
68
- from ..lowering.shape import lower_shape
69
- from ..lowering.size import lower_size
70
- from ..lowering.softmax import lower_softmax
71
- from ..lowering.rms_normalization import lower_rms_normalization
72
- from ..lowering.rotary_embedding import lower_rotary_embedding
73
- from ..lowering.squeeze import lower_squeeze
74
- from ..lowering.transpose import lower_transpose
75
- from ..lowering.unsqueeze import lower_unsqueeze
76
- from ..lowering.where import lower_where
77
- from ..lowering.quantize_linear import resolve_quantize_spec
78
- from ..lowering.variadic import BINARY_ONLY_OPS, VARIADIC_OP_FUNCTIONS
79
- from ..lowering.registry import resolve_dispatch
80
- from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
81
- from ..ops import (
82
- BINARY_OP_TYPES,
83
- COMPARE_FUNCTIONS,
84
- UNARY_OP_TYPES,
85
- apply_binary_op,
86
- apply_unary_op,
87
- binary_op_symbol,
88
- unary_op_symbol,
89
- validate_unary_attrs,
90
- )
91
- from shared.scalar_functions import ScalarFunction, ScalarFunctionError
92
- from ..validation import normalize_axis
93
-
94
- Handler = Callable[["Evaluator", Node], None]
95
- _EVAL_REGISTRY: dict[str, Handler] = {}
96
-
97
-
98
- def register_evaluator(op_type: str) -> Callable[[Handler], Handler]:
99
- def decorator(func: Handler) -> Handler:
100
- _EVAL_REGISTRY[op_type] = func
101
- return func
102
-
103
- return decorator
104
-
105
-
106
- class Evaluator:
107
- def __init__(self, graph: Graph) -> None:
108
- self._graph = graph
109
- self._values: dict[str, np.ndarray] = {}
110
-
111
- @property
112
- def graph(self) -> Graph:
113
- return self._graph
114
-
115
- @property
116
- def values(self) -> dict[str, np.ndarray]:
117
- return self._values
118
-
119
- def run(self, feeds: Mapping[str, np.ndarray]) -> dict[str, np.ndarray]:
120
- values = {
121
- initializer.name: initializer.data
122
- for initializer in self._graph.initializers
123
- }
124
- values.update(feeds)
125
- self._values = values
126
- for node in self._graph.nodes:
127
- self._dispatch(node)
128
- return {
129
- output.name: self._values[output.name]
130
- for output in self._graph.outputs
131
- }
132
-
133
- def _dispatch(self, node: Node) -> None:
134
- handler = resolve_dispatch(
135
- node.op_type,
136
- _EVAL_REGISTRY,
137
- binary_types=BINARY_OP_TYPES,
138
- unary_types=UNARY_OP_TYPES,
139
- binary_fallback=lambda: _eval_binary_unary,
140
- unary_fallback=lambda: _eval_binary_unary,
141
- )
142
- handler(self, node)
143
-
144
-
145
- @register_evaluator("MatMul")
146
- def _eval_matmul(evaluator: Evaluator, node: Node) -> None:
147
- lower_matmul(evaluator.graph, node)
148
- left = evaluator.values[node.inputs[0]]
149
- right = evaluator.values[node.inputs[1]]
150
- evaluator.values[node.outputs[0]] = _apply_matmul(left, right)
151
-
152
-
153
- @register_evaluator("Einsum")
154
- def _eval_einsum(evaluator: Evaluator, node: Node) -> None:
155
- lower_einsum(evaluator.graph, node)
156
- equation_value = node.attrs.get("equation")
157
- if equation_value is None:
158
- raise UnsupportedOpError("Einsum equation attribute is required")
159
- equation = (
160
- equation_value.decode()
161
- if isinstance(equation_value, (bytes, bytearray))
162
- else str(equation_value)
163
- )
164
- inputs = [evaluator.values[name] for name in node.inputs]
165
- evaluator.values[node.outputs[0]] = np.einsum(equation, *inputs)
166
-
167
-
168
- @register_evaluator("Adagrad")
169
- def _eval_adagrad(evaluator: Evaluator, node: Node) -> None:
170
- op = lower_adagrad(evaluator.graph, node)
171
- rate = evaluator.values[op.rate]
172
- timestep = evaluator.values[op.timestep]
173
- rate_value = (
174
- np.array(rate, dtype=op.dtype.np_dtype).reshape(-1)[0].item()
175
- )
176
- timestep_value = (
177
- np.array(timestep, dtype=np.int64).reshape(-1)[0].item()
178
- )
179
- r = op.dtype.np_dtype.type(
180
- rate_value / (1.0 + float(timestep_value) * op.decay_factor)
181
- )
182
- for x_name, g_name, h_name, out_name, h_out_name in zip(
183
- op.inputs,
184
- op.gradients,
185
- op.accumulators,
186
- op.outputs,
187
- op.accumulator_outputs,
188
- ):
189
- x = evaluator.values[x_name]
190
- g = evaluator.values[g_name]
191
- h = evaluator.values[h_name]
192
- g_regularized = op.norm_coefficient * x + g
193
- h_new = h + g_regularized * g_regularized
194
- h_adaptive = np.sqrt(h_new) + op.epsilon
195
- evaluator.values[out_name] = x - r * g_regularized / h_adaptive
196
- evaluator.values[h_out_name] = h_new
197
-
198
-
199
- @register_evaluator("Clip")
200
- def _eval_clip(evaluator: Evaluator, node: Node) -> None:
201
- if not node.inputs or len(node.outputs) != 1:
202
- raise UnsupportedOpError("Clip must have 1 output")
203
- input_name = node.inputs[0]
204
- if not input_name:
205
- raise UnsupportedOpError("Clip input must be provided")
206
- x = evaluator.values[input_name]
207
- min_name = optional_name(node.inputs, 1)
208
- max_name = optional_name(node.inputs, 2)
209
- dtype = value_dtype(evaluator.graph, input_name, node)
210
- if min_name is None:
211
- min_val = (
212
- -np.inf
213
- if dtype.is_float
214
- else np.iinfo(dtype.np_dtype).min
215
- )
216
- else:
217
- min_val = evaluator.values[min_name]
218
- if max_name is None:
219
- max_val = (
220
- np.inf
221
- if dtype.is_float
222
- else np.iinfo(dtype.np_dtype).max
223
- )
224
- else:
225
- max_val = evaluator.values[max_name]
226
- evaluator.values[node.outputs[0]] = np.clip(x, min_val, max_val)
227
-
228
-
229
- def _max_min(lhs: float, rhs: float) -> tuple[float, float]:
230
- if lhs >= rhs:
231
- return rhs, lhs
232
- return lhs, rhs
233
-
234
-
235
- def _suppress_by_iou(
236
- boxes: np.ndarray,
237
- box_index1: int,
238
- box_index2: int,
239
- *,
240
- center_point_box: int,
241
- iou_threshold: float,
242
- ) -> bool:
243
- box1 = boxes[box_index1]
244
- box2 = boxes[box_index2]
245
- if center_point_box == 0:
246
- x1_min, x1_max = _max_min(float(box1[1]), float(box1[3]))
247
- x2_min, x2_max = _max_min(float(box2[1]), float(box2[3]))
248
- intersection_x_min = max(x1_min, x2_min)
249
- intersection_x_max = min(x1_max, x2_max)
250
- if intersection_x_max <= intersection_x_min:
251
- return False
252
-
253
- y1_min, y1_max = _max_min(float(box1[0]), float(box1[2]))
254
- y2_min, y2_max = _max_min(float(box2[0]), float(box2[2]))
255
- intersection_y_min = max(y1_min, y2_min)
256
- intersection_y_max = min(y1_max, y2_max)
257
- if intersection_y_max <= intersection_y_min:
258
- return False
259
- else:
260
- box1_width_half = float(box1[2]) / 2.0
261
- box1_height_half = float(box1[3]) / 2.0
262
- box2_width_half = float(box2[2]) / 2.0
263
- box2_height_half = float(box2[3]) / 2.0
264
-
265
- x1_min = float(box1[0]) - box1_width_half
266
- x1_max = float(box1[0]) + box1_width_half
267
- x2_min = float(box2[0]) - box2_width_half
268
- x2_max = float(box2[0]) + box2_width_half
269
-
270
- y1_min = float(box1[1]) - box1_height_half
271
- y1_max = float(box1[1]) + box1_height_half
272
- y2_min = float(box2[1]) - box2_height_half
273
- y2_max = float(box2[1]) + box2_height_half
274
-
275
- intersection_x_min = max(x1_min, x2_min)
276
- intersection_x_max = min(x1_max, x2_max)
277
- if intersection_x_max <= intersection_x_min:
278
- return False
279
-
280
- intersection_y_min = max(y1_min, y2_min)
281
- intersection_y_max = min(y1_max, y2_max)
282
- if intersection_y_max <= intersection_y_min:
283
- return False
284
-
285
- intersection_area = (intersection_x_max - intersection_x_min) * (
286
- intersection_y_max - intersection_y_min
287
- )
288
- if intersection_area <= 0:
289
- return False
290
-
291
- area1 = (x1_max - x1_min) * (y1_max - y1_min)
292
- area2 = (x2_max - x2_min) * (y2_max - y2_min)
293
- union_area = area1 + area2 - intersection_area
294
-
295
- if area1 <= 0 or area2 <= 0 or union_area <= 0:
296
- return False
297
-
298
- intersection_over_union = intersection_area / union_area
299
- return intersection_over_union > iou_threshold
300
-
301
-
302
- def _exclusive_cumsum(data: np.ndarray, axis: int) -> np.ndarray:
303
- result = np.zeros_like(data)
304
- if data.shape[axis] == 0:
305
- return result
306
- cumsum = np.cumsum(data, axis=axis, dtype=data.dtype)
307
- src_slice = [slice(None)] * data.ndim
308
- dst_slice = [slice(None)] * data.ndim
309
- src_slice[axis] = slice(None, -1)
310
- dst_slice[axis] = slice(1, None)
311
- result[tuple(dst_slice)] = cumsum[tuple(src_slice)]
312
- return result
313
-
314
-
315
- @register_evaluator("CumSum")
316
- def _eval_cumsum(evaluator: Evaluator, node: Node) -> None:
317
- op = lower_cumsum(evaluator.graph, node)
318
- x = evaluator.values[op.input0]
319
- axis = op.axis
320
- if axis is None:
321
- axis_values = evaluator.values[op.axis_input].astype(np.int64, copy=False)
322
- axis_values = axis_values.reshape(-1)
323
- if axis_values.size != 1:
324
- raise UnsupportedOpError("CumSum axis input must be scalar")
325
- axis = normalize_axis(int(axis_values[0]), op.input_shape, node)
326
- data = np.flip(x, axis=axis) if op.reverse else x
327
- if op.exclusive:
328
- result = _exclusive_cumsum(data, axis)
329
- else:
330
- result = np.cumsum(data, axis=axis, dtype=data.dtype)
331
- if op.reverse:
332
- result = np.flip(result, axis=axis)
333
- evaluator.values[op.output] = result
334
-
335
-
336
- @register_evaluator("NonMaxSuppression")
337
- def _eval_nonmax_suppression(evaluator: Evaluator, node: Node) -> None:
338
- op = lower_non_max_suppression(evaluator.graph, node)
339
- boxes = evaluator.values[op.boxes]
340
- scores = evaluator.values[op.scores]
341
-
342
- max_output_boxes_per_class = 0
343
- if op.max_output_boxes_per_class is not None:
344
- max_output_values = evaluator.values[
345
- op.max_output_boxes_per_class
346
- ].astype(np.int64, copy=False)
347
- max_output_values = max_output_values.reshape(-1)
348
- if max_output_values.size != 1:
349
- raise UnsupportedOpError(
350
- "NonMaxSuppression max_output_boxes_per_class must be scalar"
351
- )
352
- max_output_boxes_per_class = max(int(max_output_values[0]), 0)
353
-
354
- iou_threshold = 0.0
355
- if op.iou_threshold is not None:
356
- iou_values = evaluator.values[op.iou_threshold].reshape(-1)
357
- if iou_values.size != 1:
358
- raise UnsupportedOpError(
359
- "NonMaxSuppression iou_threshold must be scalar"
360
- )
361
- iou_threshold = float(iou_values[0])
362
-
363
- score_threshold = 0.0
364
- score_threshold_enabled = op.score_threshold is not None
365
- if op.score_threshold is not None:
366
- score_values = evaluator.values[op.score_threshold].reshape(-1)
367
- if score_values.size != 1:
368
- raise UnsupportedOpError(
369
- "NonMaxSuppression score_threshold must be scalar"
370
- )
371
- score_threshold = float(score_values[0])
372
-
373
- if max_output_boxes_per_class == 0:
374
- evaluator.values[op.output] = np.empty((0, 3), dtype=np.int64)
375
- return
376
-
377
- num_batches = boxes.shape[0]
378
- num_boxes = boxes.shape[1]
379
- num_classes = scores.shape[1]
380
-
381
- selected_indices: list[tuple[int, int, int]] = []
382
- for batch_index in range(num_batches):
383
- batch_boxes = boxes[batch_index]
384
- for class_index in range(num_classes):
385
- class_scores = scores[batch_index, class_index]
386
- candidates: list[tuple[float, int]] = []
387
- if score_threshold_enabled:
388
- for box_index in range(num_boxes):
389
- score = float(class_scores[box_index])
390
- if score > score_threshold:
391
- candidates.append((score, box_index))
392
- else:
393
- for box_index in range(num_boxes):
394
- candidates.append(
395
- (float(class_scores[box_index]), box_index)
396
- )
397
- candidates.sort(key=lambda item: (item[0], -item[1]))
398
- selected_boxes: list[int] = []
399
- while (
400
- candidates
401
- and len(selected_boxes) < max_output_boxes_per_class
402
- ):
403
- _, box_index = candidates.pop()
404
- if any(
405
- _suppress_by_iou(
406
- batch_boxes,
407
- box_index,
408
- selected_index,
409
- center_point_box=op.center_point_box,
410
- iou_threshold=iou_threshold,
411
- )
412
- for selected_index in selected_boxes
413
- ):
414
- continue
415
- selected_boxes.append(box_index)
416
- selected_indices.append(
417
- (batch_index, class_index, box_index)
418
- )
419
-
420
- result = np.empty((len(selected_indices), 3), dtype=np.int64)
421
- for idx, (batch_index, class_index, box_index) in enumerate(
422
- selected_indices
423
- ):
424
- result[idx, 0] = batch_index
425
- result[idx, 1] = class_index
426
- result[idx, 2] = box_index
427
- evaluator.values[op.output] = result
428
-
429
-
430
- @register_evaluator("Pad")
431
- def _eval_pad(evaluator: Evaluator, node: Node) -> None:
432
- op = lower_pad(evaluator.graph, node)
433
- x = evaluator.values[op.input0]
434
- if op.value_input is not None:
435
- value_array = evaluator.values[op.value_input]
436
- pad_value = np.array(value_array, dtype=op.dtype.np_dtype).reshape(-1)[0].item()
437
- else:
438
- pad_value = np.array(op.value, dtype=op.dtype.np_dtype).item()
439
- rank = len(op.input_shape)
440
- if op.axes_input is not None:
441
- axes_values = evaluator.values[op.axes_input].astype(
442
- np.int64, copy=False
443
- )
444
- axes_values = axes_values.reshape(-1)
445
- if op.pads_input is not None:
446
- pads_values = evaluator.values[op.pads_input].astype(
447
- np.int64, copy=False
448
- )
449
- pads_values = pads_values.reshape(-1)
450
- else:
451
- pads_values = np.array(op.pads_values, dtype=np.int64).reshape(-1)
452
- axis_count = len(axes_values)
453
- pads_begin = np.zeros(rank, dtype=np.int64)
454
- pads_end = np.zeros(rank, dtype=np.int64)
455
- for index, axis_value in enumerate(axes_values):
456
- axis = int(axis_value)
457
- if axis < 0:
458
- axis += rank
459
- pads_begin[axis] = int(pads_values[index])
460
- pads_end[axis] = int(pads_values[index + axis_count])
461
- pad_width = tuple(
462
- (int(pads_begin[index]), int(pads_end[index]))
463
- for index in range(rank)
464
- )
465
- elif op.pads_input is not None:
466
- pads_values = evaluator.values[op.pads_input].astype(np.int64, copy=False)
467
- pads_values = pads_values.reshape(-1)
468
- if op.pads_axis_map is not None:
469
- axis_count = sum(
470
- 1 for axis_index in op.pads_axis_map if axis_index is not None
471
- )
472
- pads_begin = np.zeros(rank, dtype=np.int64)
473
- pads_end = np.zeros(rank, dtype=np.int64)
474
- for axis, pad_index in enumerate(op.pads_axis_map):
475
- if pad_index is not None:
476
- pads_begin[axis] = int(pads_values[pad_index])
477
- pads_end[axis] = int(
478
- pads_values[pad_index + axis_count]
479
- )
480
- pad_width = tuple(
481
- (int(pads_begin[index]), int(pads_end[index]))
482
- for index in range(rank)
483
- )
484
- else:
485
- pads_begin = pads_values[:rank]
486
- pads_end = pads_values[rank: rank * 2]
487
- pad_width = tuple(
488
- (int(pads_begin[index]), int(pads_end[index]))
489
- for index in range(rank)
490
- )
491
- else:
492
- pad_width = tuple(zip(op.pads_begin or (), op.pads_end or ()))
493
- pad_kwargs = {}
494
- if op.mode == "constant":
495
- pad_kwargs["constant_values"] = pad_value
496
- evaluator.values[op.output] = np.pad(
497
- x,
498
- pad_width,
499
- mode=op.mode,
500
- **pad_kwargs,
501
- )
502
-
503
-
504
- @register_evaluator("ScatterND")
505
- def _eval_scatternd(evaluator: Evaluator, node: Node) -> None:
506
- op = lower_scatternd(evaluator.graph, node)
507
- data = evaluator.values[op.data]
508
- indices = evaluator.values[op.indices]
509
- updates = evaluator.values[op.updates]
510
- output = np.array(data, copy=True)
511
- index_depth = op.indices_shape[-1]
512
- update_indices_shape = op.indices_shape[:-1]
513
- update_count = int(np.prod(update_indices_shape)) if update_indices_shape else 1
514
- flat_indices = indices.astype(np.int64, copy=False).reshape(
515
- update_count, index_depth
516
- )
517
- tail_shape = op.data_shape[index_depth:]
518
- updates_reshaped = updates.reshape((update_count,) + tail_shape)
519
- for index, index_values in enumerate(flat_indices):
520
- output_index: list[int | slice] = []
521
- for axis, value in enumerate(index_values):
522
- axis_size = op.data_shape[axis]
523
- idx = int(value)
524
- if idx < 0:
525
- idx += axis_size
526
- if idx < 0 or idx >= axis_size:
527
- raise UnsupportedOpError(
528
- "ScatterND indices must be within data bounds"
529
- )
530
- output_index.append(idx)
531
- output_index.extend([slice(None)] * len(tail_shape))
532
- target = tuple(output_index)
533
- update_value = updates_reshaped[index]
534
- if op.reduction == "none":
535
- output[target] = update_value
536
- elif op.reduction == "add":
537
- output[target] = output[target] + update_value
538
- elif op.reduction == "mul":
539
- output[target] = output[target] * update_value
540
- elif op.reduction == "min":
541
- output[target] = np.minimum(output[target], update_value)
542
- elif op.reduction == "max":
543
- output[target] = np.maximum(output[target], update_value)
544
- else:
545
- raise UnsupportedOpError(
546
- f"Unsupported ScatterND reduction {op.reduction}"
547
- )
548
- evaluator.values[op.output] = output
549
-
550
-
551
- @register_evaluator("TensorScatter")
552
- def _eval_tensor_scatter(evaluator: Evaluator, node: Node) -> None:
553
- op = lower_tensor_scatter(evaluator.graph, node)
554
- past_cache = evaluator.values[op.past_cache]
555
- update = evaluator.values[op.update]
556
- if op.write_indices is None:
557
- write_indices = np.zeros((past_cache.shape[0],), dtype=np.int64)
558
- else:
559
- write_indices = evaluator.values[op.write_indices].astype(
560
- np.int64, copy=False
561
- )
562
- axis = op.axis
563
- max_sequence_length = past_cache.shape[axis]
564
- sequence_length = update.shape[axis]
565
- output = np.array(past_cache, copy=True)
566
- for prefix_idx in np.ndindex(past_cache.shape[:axis]):
567
- batch_idx = prefix_idx[0]
568
- base_index = int(write_indices[batch_idx])
569
- for sequence_idx in range(sequence_length):
570
- cache_idx = (*prefix_idx, base_index + sequence_idx)
571
- if op.mode == "circular":
572
- cache_idx = tuple(
573
- np.mod(np.asarray(cache_idx), max_sequence_length)
574
- )
575
- update_idx = (*prefix_idx, sequence_idx)
576
- output[cache_idx] = update[update_idx]
577
- evaluator.values[op.output] = output
578
-
579
-
580
- @register_evaluator("Celu")
581
- def _eval_celu(evaluator: Evaluator, node: Node) -> None:
582
- if len(node.inputs) != 1 or len(node.outputs) != 1:
583
- raise UnsupportedOpError("Celu must have 1 input and 1 output")
584
- dtype = value_dtype(evaluator.graph, node.inputs[0], node)
585
- if not dtype.is_float:
586
- raise UnsupportedOpError("Celu only supports floating-point inputs")
587
- alpha = float(node.attrs.get("alpha", 1.0))
588
- x = evaluator.values[node.inputs[0]]
589
- evaluator.values[node.outputs[0]] = np.where(
590
- x > 0,
591
- x,
592
- alpha * (np.exp(x / alpha) - 1.0),
593
- )
594
-
595
-
596
- @register_evaluator("Swish")
597
- def _eval_swish(evaluator: Evaluator, node: Node) -> None:
598
- if len(node.inputs) != 1 or len(node.outputs) != 1:
599
- raise UnsupportedOpError("Swish must have 1 input and 1 output")
600
- dtype = value_dtype(evaluator.graph, node.inputs[0], node)
601
- if not dtype.is_float:
602
- raise UnsupportedOpError("Swish only supports floating-point inputs")
603
- alpha = float(node.attrs.get("alpha", 1.0))
604
- x = evaluator.values[node.inputs[0]]
605
- evaluator.values[node.outputs[0]] = x / (1.0 + np.exp(-alpha * x))
606
-
607
-
608
- def _grid_sample_denormalize(
609
- value: float, length: int, *, align_corners: bool
610
- ) -> float:
611
- if align_corners:
612
- return (value + 1.0) * (length - 1) / 2.0
613
- return ((value + 1.0) * length - 1.0) / 2.0
614
-
615
-
616
- def _grid_sample_reflect(value: float, x_min: float, x_max: float) -> float:
617
- rng = x_max - x_min
618
- if rng == 0:
619
- return x_min
620
- if value < x_min:
621
- dx = x_min - value
622
- n = int(dx / rng)
623
- r = dx - n * rng
624
- return x_min + r if n % 2 == 0 else x_max - r
625
- if value > x_max:
626
- dx = value - x_max
627
- n = int(dx / rng)
628
- r = dx - n * rng
629
- return x_max - r if n % 2 == 0 else x_min + r
630
- return value
631
-
632
-
633
- def _grid_sample_border(
634
- dims: tuple[int, ...], *, align_corners: bool
635
- ) -> tuple[list[float], list[float]]:
636
- min_vals: list[float] = []
637
- max_vals: list[float] = []
638
- for dim in dims:
639
- if align_corners:
640
- min_vals.append(0.0)
641
- max_vals.append(dim - 1.0)
642
- else:
643
- min_vals.append(-0.5)
644
- max_vals.append(dim - 0.5)
645
- return min_vals, max_vals
646
-
647
-
648
- def _grid_sample_pixel_at(
649
- data: np.ndarray,
650
- indices: list[int],
651
- border_min: list[float],
652
- border_max: list[float],
653
- padding_mode: str,
654
- ) -> float:
655
- if padding_mode == "zeros":
656
- for idx, dim in zip(indices, data.shape):
657
- if idx < 0 or idx >= dim:
658
- return data.dtype.type(0)
659
- return data[tuple(indices)]
660
- if padding_mode == "border":
661
- clamped = [
662
- 0 if idx < 0 else dim - 1 if idx >= dim else idx
663
- for idx, dim in zip(indices, data.shape)
664
- ]
665
- return data[tuple(clamped)]
666
- reflected = [
667
- int(_grid_sample_reflect(idx, border_min[i], border_max[i]))
668
- for i, idx in enumerate(indices)
669
- ]
670
- return data[tuple(reflected)]
671
-
672
-
673
- def _grid_sample_linear_1d(
674
- data: np.ndarray,
675
- coord: float,
676
- border_min: float,
677
- border_max: float,
678
- padding_mode: str,
679
- ) -> float:
680
- base = int(np.floor(coord))
681
- weight = coord - base
682
- lower = _grid_sample_pixel_at(
683
- data, [base], [border_min], [border_max], padding_mode
684
- )
685
- upper = _grid_sample_pixel_at(
686
- data, [base + 1], [border_min], [border_max], padding_mode
687
- )
688
- return (1.0 - weight) * lower + weight * upper
689
-
690
-
691
- def _grid_sample_cubic_coeffs(x: float) -> np.ndarray:
692
- alpha = -0.75
693
- abs_x = abs(x)
694
- coeffs = np.empty((4,), dtype=np.float64)
695
- coeffs[0] = (
696
- (alpha * (abs_x + 1.0) - 5.0 * alpha) * (abs_x + 1.0) + 8.0 * alpha
697
- ) * (abs_x + 1.0) - 4.0 * alpha
698
- coeffs[1] = ((alpha + 2.0) * abs_x - (alpha + 3.0)) * abs_x * abs_x + 1.0
699
- inv_x = 1.0 - abs_x
700
- coeffs[2] = ((alpha + 2.0) * inv_x - (alpha + 3.0)) * inv_x * inv_x + 1.0
701
- span = 2.0 - abs_x
702
- coeffs[3] = (
703
- (alpha * span - 5.0 * alpha) * span + 8.0 * alpha
704
- ) * span - 4.0 * alpha
705
- return coeffs
706
-
707
-
708
- def _grid_sample_cubic_1d(
709
- data: np.ndarray,
710
- coord: float,
711
- border_min: float,
712
- border_max: float,
713
- padding_mode: str,
714
- ) -> float:
715
- base = int(np.floor(coord))
716
- coeffs = _grid_sample_cubic_coeffs(coord - base)
717
- values = np.empty((4,), dtype=np.float64)
718
- for offset in range(4):
719
- values[offset] = _grid_sample_pixel_at(
720
- data,
721
- [base - 1 + offset],
722
- [border_min],
723
- [border_max],
724
- padding_mode,
725
- )
726
- return float(coeffs @ values)
727
-
728
-
729
- def _grid_sample_linear_nd(
730
- data: np.ndarray,
731
- coords: np.ndarray,
732
- border_min: list[float],
733
- border_max: list[float],
734
- padding_mode: str,
735
- ) -> float:
736
- if data.ndim == 1:
737
- return _grid_sample_linear_1d(
738
- data, float(coords[0]), border_min[0], border_max[0], padding_mode
739
- )
740
- reduced = np.array(
741
- [
742
- _grid_sample_linear_nd(
743
- data[index],
744
- coords[1:],
745
- border_min[1:],
746
- border_max[1:],
747
- padding_mode,
748
- )
749
- for index in range(data.shape[0])
750
- ],
751
- dtype=np.float64,
752
- )
753
- return _grid_sample_linear_1d(
754
- reduced, float(coords[0]), border_min[0], border_max[0], padding_mode
755
- )
756
-
757
-
758
- def _grid_sample_cubic_nd(
759
- data: np.ndarray,
760
- coords: np.ndarray,
761
- border_min: list[float],
762
- border_max: list[float],
763
- padding_mode: str,
764
- ) -> float:
765
- if data.ndim == 1:
766
- return _grid_sample_cubic_1d(
767
- data, float(coords[0]), border_min[0], border_max[0], padding_mode
768
- )
769
- reduced = np.array(
770
- [
771
- _grid_sample_cubic_nd(
772
- data[index],
773
- coords[1:],
774
- border_min[1:],
775
- border_max[1:],
776
- padding_mode,
777
- )
778
- for index in range(data.shape[0])
779
- ],
780
- dtype=np.float64,
781
- )
782
- return _grid_sample_cubic_1d(
783
- reduced, float(coords[0]), border_min[0], border_max[0], padding_mode
784
- )
785
-
786
-
787
- @register_evaluator("GridSample")
788
- def _eval_grid_sample(evaluator: Evaluator, node: Node) -> None:
789
- op = lower_grid_sample(evaluator.graph, node)
790
- input_data = evaluator.values[op.input0]
791
- grid_data = evaluator.values[op.grid]
792
- output = np.empty(op.output_shape, dtype=input_data.dtype)
793
- if output.size == 0:
794
- evaluator.values[op.output] = output
795
- return
796
- dims = op.input_spatial
797
- border_min, border_max = _grid_sample_border(
798
- dims, align_corners=op.align_corners
799
- )
800
- for n in range(op.output_shape[0]):
801
- grid_batch = grid_data[n]
802
- for c in range(op.output_shape[1]):
803
- input_slice = input_data[n, c]
804
- for out_idx in np.ndindex(*op.output_spatial):
805
- coords = np.array(
806
- grid_batch[out_idx][::-1], dtype=np.float64
807
- )
808
- for i, dim in enumerate(dims):
809
- coords[i] = _grid_sample_denormalize(
810
- float(coords[i]), dim, align_corners=op.align_corners
811
- )
812
- if op.mode == "nearest":
813
- rounded = np.rint(coords).astype(int)
814
- if op.padding_mode != "zeros":
815
- for i, dim in enumerate(dims):
816
- if (
817
- rounded[i] < border_min[i]
818
- or rounded[i] > border_max[i]
819
- ):
820
- if op.padding_mode == "border":
821
- rounded[i] = min(
822
- max(rounded[i], 0), dim - 1
823
- )
824
- else:
825
- rounded[i] = int(
826
- _grid_sample_reflect(
827
- rounded[i],
828
- border_min[i],
829
- border_max[i],
830
- )
831
- )
832
- value = _grid_sample_pixel_at(
833
- input_slice,
834
- rounded.tolist(),
835
- border_min,
836
- border_max,
837
- op.padding_mode,
838
- )
839
- else:
840
- if op.padding_mode != "zeros":
841
- for i, dim in enumerate(dims):
842
- if (
843
- coords[i] < border_min[i]
844
- or coords[i] > border_max[i]
845
- ):
846
- if op.padding_mode == "border":
847
- coords[i] = min(
848
- max(coords[i], 0.0), dim - 1.0
849
- )
850
- else:
851
- coords[i] = _grid_sample_reflect(
852
- coords[i],
853
- border_min[i],
854
- border_max[i],
855
- )
856
- if op.mode == "linear":
857
- value = _grid_sample_linear_nd(
858
- input_slice,
859
- coords,
860
- border_min,
861
- border_max,
862
- op.padding_mode,
863
- )
864
- else:
865
- value = _grid_sample_cubic_nd(
866
- input_slice,
867
- coords,
868
- border_min,
869
- border_max,
870
- op.padding_mode,
871
- )
872
- output[(n, c, *out_idx)] = value
873
- evaluator.values[op.output] = output
874
-
875
-
876
- _VARIADIC_COMBINE_FUNCS: dict[
877
- ScalarFunction, Callable[[np.ndarray, np.ndarray], np.ndarray]
878
- ] = {
879
- ScalarFunction.ADD: np.add,
880
- ScalarFunction.MAXIMUM: np.maximum,
881
- ScalarFunction.MINIMUM: np.minimum,
882
- ScalarFunction.LOGICAL_AND: np.logical_and,
883
- ScalarFunction.LOGICAL_OR: np.logical_or,
884
- ScalarFunction.LOGICAL_XOR: np.logical_xor,
885
- ScalarFunction.BITWISE_AND: np.bitwise_and,
886
- ScalarFunction.BITWISE_OR: np.bitwise_or,
887
- ScalarFunction.BITWISE_XOR: np.bitwise_xor,
888
- }
889
-
890
-
891
- def _validate_variadic_inputs(
892
- evaluator: Evaluator, node: Node, *, function: ScalarFunction
893
- ) -> tuple[ScalarType, tuple[int, ...]]:
894
- if len(node.outputs) != 1:
895
- raise UnsupportedOpError(f"{node.op_type} must have 1 output")
896
- if node.op_type in BINARY_ONLY_OPS:
897
- if len(node.inputs) != 2:
898
- raise UnsupportedOpError(
899
- f"{node.op_type} must have exactly 2 inputs"
900
- )
901
- elif len(node.inputs) < 2:
902
- raise UnsupportedOpError(
903
- f"{node.op_type} must have at least 2 inputs"
904
- )
905
- for name in node.inputs:
906
- if not name:
907
- raise UnsupportedOpError(f"{node.op_type} input must be provided")
908
- op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
909
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
910
- if op_dtype != output_dtype:
911
- raise UnsupportedOpError(
912
- f"{node.op_type} expects matching input/output dtypes, "
913
- f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
914
- )
915
- output_shape = value_shape(evaluator.graph, node.outputs[0], node)
916
- for name in node.inputs:
917
- input_shape = value_shape(evaluator.graph, name, node)
918
- if input_shape != output_shape:
919
- raise UnsupportedOpError(
920
- f"{node.op_type} expects identical input/output shapes"
921
- )
922
- if function in {
923
- ScalarFunction.LOGICAL_AND,
924
- ScalarFunction.LOGICAL_OR,
925
- ScalarFunction.LOGICAL_XOR,
926
- } and op_dtype != ScalarType.BOOL:
927
- raise UnsupportedOpError(f"{node.op_type} expects bool inputs")
928
- if function in {
929
- ScalarFunction.BITWISE_AND,
930
- ScalarFunction.BITWISE_OR,
931
- ScalarFunction.BITWISE_XOR,
932
- } and not op_dtype.is_integer:
933
- raise UnsupportedOpError(f"{node.op_type} expects integer inputs")
934
- if function == ScalarFunction.MEAN and not op_dtype.is_float:
935
- raise UnsupportedOpError(f"{node.op_type} expects floating-point inputs")
936
- return op_dtype, output_shape
937
-
938
-
939
- def _eval_variadic(evaluator: Evaluator, node: Node) -> None:
940
- function = VARIADIC_OP_FUNCTIONS[node.op_type]
941
- _validate_variadic_inputs(evaluator, node, function=function)
942
- values = [evaluator.values[name] for name in node.inputs]
943
- if function == ScalarFunction.MEAN:
944
- combine_func = _VARIADIC_COMBINE_FUNCS[ScalarFunction.ADD]
945
- else:
946
- combine_func = _VARIADIC_COMBINE_FUNCS[function]
947
- result = values[0]
948
- for value in values[1:]:
949
- result = combine_func(result, value)
950
- if function == ScalarFunction.MEAN:
951
- result = result / len(values)
952
- evaluator.values[node.outputs[0]] = result
953
-
954
-
955
- for _op_type in VARIADIC_OP_FUNCTIONS:
956
- register_evaluator(_op_type)(_eval_variadic)
957
-
958
-
959
- @register_evaluator("Shrink")
960
- def _eval_shrink(evaluator: Evaluator, node: Node) -> None:
961
- if len(node.inputs) != 1 or len(node.outputs) != 1:
962
- raise UnsupportedOpError("Shrink must have 1 input and 1 output")
963
- bias = float(node.attrs.get("bias", 0.0))
964
- lambd = float(node.attrs.get("lambd", 0.5))
965
- x = evaluator.values[node.inputs[0]]
966
- result = np.where(
967
- x < -lambd,
968
- x + bias,
969
- np.where(x > lambd, x - bias, 0.0),
970
- )
971
- if result.dtype != x.dtype:
972
- result = result.astype(x.dtype)
973
- evaluator.values[node.outputs[0]] = result
974
-
975
-
976
- @register_evaluator("IsInf")
977
- def _eval_isinf(evaluator: Evaluator, node: Node) -> None:
978
- if len(node.inputs) != 1 or len(node.outputs) != 1:
979
- raise UnsupportedOpError("IsInf must have 1 input and 1 output")
980
- input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
981
- if not input_dtype.is_float:
982
- raise UnsupportedOpError("IsInf only supports floating-point inputs")
983
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
984
- if output_dtype != ScalarType.BOOL:
985
- raise UnsupportedOpError("IsInf output must be bool")
986
- detect_negative = int(node.attrs.get("detect_negative", 1))
987
- detect_positive = int(node.attrs.get("detect_positive", 1))
988
- if detect_negative not in {0, 1} or detect_positive not in {0, 1}:
989
- raise UnsupportedOpError(
990
- "IsInf detect_negative and detect_positive must be 0 or 1"
991
- )
992
- x = evaluator.values[node.inputs[0]]
993
- if detect_negative and detect_positive:
994
- result = np.isinf(x)
995
- elif detect_negative:
996
- result = np.isneginf(x)
997
- elif detect_positive:
998
- result = np.isposinf(x)
999
- else:
1000
- result = np.zeros(x.shape, dtype=bool)
1001
- evaluator.values[node.outputs[0]] = result
1002
-
1003
-
1004
- @register_evaluator("IsNaN")
1005
- def _eval_isnan(evaluator: Evaluator, node: Node) -> None:
1006
- if len(node.inputs) != 1 or len(node.outputs) != 1:
1007
- raise UnsupportedOpError("IsNaN must have 1 input and 1 output")
1008
- input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1009
- if not input_dtype.is_float:
1010
- raise UnsupportedOpError("IsNaN only supports floating-point inputs")
1011
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1012
- if output_dtype != ScalarType.BOOL:
1013
- raise UnsupportedOpError("IsNaN output must be bool")
1014
- x = evaluator.values[node.inputs[0]]
1015
- evaluator.values[node.outputs[0]] = np.isnan(x)
1016
-
1017
-
1018
- @register_evaluator("Gemm")
1019
- def _eval_gemm(evaluator: Evaluator, node: Node) -> None:
1020
- op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1021
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1022
- if op_dtype != output_dtype:
1023
- raise UnsupportedOpError(
1024
- f"{node.op_type} expects matching input/output dtypes, "
1025
- f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
1026
- )
1027
- spec = resolve_gemm_spec(evaluator.graph, node, op_dtype)
1028
- left = evaluator.values[node.inputs[0]]
1029
- right = evaluator.values[node.inputs[1]]
1030
- if spec.trans_a:
1031
- left = left.T
1032
- if spec.trans_b:
1033
- right = right.T
1034
- result = _apply_matmul(left, right)
1035
- if op_dtype.is_float:
1036
- alpha = float(spec.alpha)
1037
- beta = float(spec.beta)
1038
- else:
1039
- alpha = int(spec.alpha)
1040
- beta = int(spec.beta)
1041
- if alpha != 1:
1042
- result = result * alpha
1043
- if len(node.inputs) == 3:
1044
- bias = evaluator.values[node.inputs[2]]
1045
- if beta != 1:
1046
- bias = bias * beta
1047
- result = result + bias
1048
- evaluator.values[node.outputs[0]] = result
1049
-
1050
-
1051
- @register_evaluator("Cast")
1052
- def _eval_cast(evaluator: Evaluator, node: Node) -> None:
1053
- if len(node.inputs) != 1 or len(node.outputs) != 1:
1054
- raise UnsupportedOpError("Cast must have 1 input and 1 output")
1055
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1056
- input_value = evaluator.values[node.inputs[0]]
1057
- evaluator.values[node.outputs[0]] = input_value.astype(
1058
- output_dtype.np_dtype, copy=False
1059
- )
1060
-
1061
-
1062
- @register_evaluator("CastLike")
1063
- def _eval_castlike(evaluator: Evaluator, node: Node) -> None:
1064
- if len(node.inputs) != 2 or len(node.outputs) != 1:
1065
- raise UnsupportedOpError("CastLike must have 2 inputs and 1 output")
1066
- like_dtype = value_dtype(evaluator.graph, node.inputs[1], node)
1067
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1068
- if output_dtype != like_dtype:
1069
- raise UnsupportedOpError(
1070
- "CastLike output dtype must match like input dtype, "
1071
- f"got {output_dtype.onnx_name} and {like_dtype.onnx_name}"
1072
- )
1073
- input_value = evaluator.values[node.inputs[0]]
1074
- evaluator.values[node.outputs[0]] = input_value.astype(
1075
- output_dtype.np_dtype, copy=False
1076
- )
1077
-
1078
-
1079
- @register_evaluator("Identity")
1080
- def _eval_identity(evaluator: Evaluator, node: Node) -> None:
1081
- if len(node.inputs) != 1 or len(node.outputs) != 1:
1082
- raise UnsupportedOpError("Identity must have 1 input and 1 output")
1083
- value = evaluator.values[node.inputs[0]]
1084
- evaluator.values[node.outputs[0]] = np.array(value, copy=True)
1085
-
1086
-
1087
- @register_evaluator("EyeLike")
1088
- def _eval_eye_like(evaluator: Evaluator, node: Node) -> None:
1089
- if len(node.inputs) != 1 or len(node.outputs) != 1:
1090
- raise UnsupportedOpError("EyeLike must have 1 input and 1 output")
1091
- output_shape = value_shape(evaluator.graph, node.outputs[0], node)
1092
- if len(output_shape) < 2:
1093
- raise UnsupportedOpError("EyeLike expects input rank >= 2")
1094
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1095
- k = int(node.attrs.get("k", 0))
1096
- output = np.zeros(output_shape, dtype=output_dtype.np_dtype)
1097
- rows, cols = output_shape[-2], output_shape[-1]
1098
- row_start = 0 if k >= 0 else -k
1099
- col_start = k if k >= 0 else 0
1100
- if row_start < rows and col_start < cols:
1101
- diag_len = min(rows - row_start, cols - col_start)
1102
- batch_size = int(np.prod(output_shape[:-2])) if output_shape[:-2] else 1
1103
- view = output.reshape(batch_size, rows, cols)
1104
- diag_idx = np.arange(diag_len, dtype=np.int64)
1105
- one = output_dtype.np_dtype.type(1)
1106
- view[:, row_start + diag_idx, col_start + diag_idx] = one
1107
- evaluator.values[node.outputs[0]] = output
1108
-
1109
-
1110
- @register_evaluator("Trilu")
1111
- def _eval_trilu(evaluator: Evaluator, node: Node) -> None:
1112
- if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
1113
- raise UnsupportedOpError("Trilu must have 1 or 2 inputs and 1 output")
1114
- value = evaluator.values[node.inputs[0]]
1115
- if value.ndim < 2:
1116
- raise UnsupportedOpError("Trilu expects input rank >= 2")
1117
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1118
- input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1119
- if output_dtype != input_dtype:
1120
- raise UnsupportedOpError(
1121
- "Trilu expects matching input/output dtypes, "
1122
- f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
1123
- )
1124
- k = 0
1125
- if len(node.inputs) == 2 and node.inputs[1]:
1126
- k_value = np.array(evaluator.values[node.inputs[1]], dtype=np.int64)
1127
- if k_value.size != 1:
1128
- raise UnsupportedOpError("Trilu k input must be scalar")
1129
- k = int(k_value.reshape(-1)[0])
1130
- upper_attr = node.attrs.get("upper", 1)
1131
- upper = bool(int(upper_attr))
1132
- rows, cols = value.shape[-2], value.shape[-1]
1133
- batch_shape = value.shape[:-2]
1134
- batch_size = int(np.prod(batch_shape)) if batch_shape else 1
1135
- view = value.reshape(batch_size, rows, cols)
1136
- if upper:
1137
- mask = np.triu(np.ones((rows, cols), dtype=bool), k=k)
1138
- else:
1139
- mask = np.tril(np.ones((rows, cols), dtype=bool), k=k)
1140
- output = np.where(mask, view, np.zeros_like(view))
1141
- evaluator.values[node.outputs[0]] = output.reshape(value.shape)
1142
-
1143
-
1144
- @register_evaluator("Tile")
1145
- def _eval_tile(evaluator: Evaluator, node: Node) -> None:
1146
- if len(node.inputs) != 2 or len(node.outputs) != 1:
1147
- raise UnsupportedOpError("Tile must have 2 inputs and 1 output")
1148
- value = evaluator.values[node.inputs[0]]
1149
- repeats = evaluator.values[node.inputs[1]]
1150
- repeats = np.array(repeats, dtype=np.int64).reshape(-1)
1151
- if repeats.size != value.ndim:
1152
- raise UnsupportedOpError(
1153
- "Tile repeats must have the same rank as input shape"
1154
- )
1155
- evaluator.values[node.outputs[0]] = np.tile(value, repeats)
1156
-
1157
-
1158
- @register_evaluator("DepthToSpace")
1159
- def _eval_depth_to_space(evaluator: Evaluator, node: Node) -> None:
1160
- if len(node.inputs) != 1 or len(node.outputs) != 1:
1161
- raise UnsupportedOpError("DepthToSpace must have 1 input and 1 output")
1162
- data = evaluator.values[node.inputs[0]]
1163
- if data.ndim != 4:
1164
- raise UnsupportedOpError("DepthToSpace only supports 4D inputs")
1165
- blocksize = int(node.attrs.get("blocksize", 0))
1166
- if blocksize <= 0:
1167
- raise UnsupportedOpError(
1168
- f"DepthToSpace blocksize must be > 0, got {blocksize}"
1169
- )
1170
- mode_attr = node.attrs.get("mode", "DCR")
1171
- if isinstance(mode_attr, bytes):
1172
- mode = mode_attr.decode()
1173
- else:
1174
- mode = str(mode_attr)
1175
- if mode not in {"DCR", "CRD"}:
1176
- raise UnsupportedOpError("DepthToSpace only supports mode DCR or CRD")
1177
- b, c, h, w = data.shape
1178
- if mode == "DCR":
1179
- tmpshape = (
1180
- b,
1181
- blocksize,
1182
- blocksize,
1183
- c // (blocksize * blocksize),
1184
- h,
1185
- w,
1186
- )
1187
- reshaped = data.reshape(tmpshape)
1188
- transposed = np.transpose(reshaped, [0, 3, 4, 1, 5, 2])
1189
- else:
1190
- tmpshape = (
1191
- b,
1192
- c // (blocksize * blocksize),
1193
- blocksize,
1194
- blocksize,
1195
- h,
1196
- w,
1197
- )
1198
- reshaped = data.reshape(tmpshape)
1199
- transposed = np.transpose(reshaped, [0, 1, 4, 2, 5, 3])
1200
- finalshape = (
1201
- b,
1202
- c // (blocksize * blocksize),
1203
- h * blocksize,
1204
- w * blocksize,
1205
- )
1206
- evaluator.values[node.outputs[0]] = np.reshape(transposed, finalshape)
1207
-
1208
-
1209
- @register_evaluator("SpaceToDepth")
1210
- def _eval_space_to_depth(evaluator: Evaluator, node: Node) -> None:
1211
- if len(node.inputs) != 1 or len(node.outputs) != 1:
1212
- raise UnsupportedOpError("SpaceToDepth must have 1 input and 1 output")
1213
- data = evaluator.values[node.inputs[0]]
1214
- if data.ndim != 4:
1215
- raise UnsupportedOpError("SpaceToDepth only supports 4D inputs")
1216
- blocksize = int(node.attrs.get("blocksize", 0))
1217
- if blocksize <= 0:
1218
- raise UnsupportedOpError(
1219
- f"SpaceToDepth blocksize must be > 0, got {blocksize}"
1220
- )
1221
- b, c, h, w = data.shape
1222
- tmpshape = (
1223
- b,
1224
- c,
1225
- h // blocksize,
1226
- blocksize,
1227
- w // blocksize,
1228
- blocksize,
1229
- )
1230
- reshaped = np.reshape(data, tmpshape)
1231
- transposed = np.transpose(reshaped, [0, 3, 5, 1, 2, 4])
1232
- finalshape = (
1233
- b,
1234
- c * blocksize * blocksize,
1235
- h // blocksize,
1236
- w // blocksize,
1237
- )
1238
- evaluator.values[node.outputs[0]] = np.reshape(transposed, finalshape)
1239
-
1240
-
1241
- @register_evaluator("Where")
1242
- def _eval_where(evaluator: Evaluator, node: Node) -> None:
1243
- lower_where(evaluator.graph, node)
1244
- condition = evaluator.values[node.inputs[0]]
1245
- x_value = evaluator.values[node.inputs[1]]
1246
- y_value = evaluator.values[node.inputs[2]]
1247
- evaluator.values[node.outputs[0]] = np.where(condition, x_value, y_value)
1248
-
1249
-
1250
- @register_evaluator("GatherElements")
1251
- def _eval_gather_elements(evaluator: Evaluator, node: Node) -> None:
1252
- if len(node.inputs) != 2 or len(node.outputs) != 1:
1253
- raise UnsupportedOpError("GatherElements must have 2 inputs and 1 output")
1254
- data = evaluator.values[node.inputs[0]]
1255
- indices = evaluator.values[node.inputs[1]]
1256
- if indices.dtype.type not in {np.int32, np.int64}:
1257
- raise UnsupportedOpError(
1258
- f"GatherElements indices must be int32 or int64, got {indices.dtype}"
1259
- )
1260
- axis = normalize_axis(int(node.attrs.get("axis", 0)), data.shape, node)
1261
- evaluator.values[node.outputs[0]] = np.take_along_axis(
1262
- data, indices, axis=axis
1263
- )
1264
-
1265
-
1266
- @register_evaluator("Gather")
1267
- def _eval_gather(evaluator: Evaluator, node: Node) -> None:
1268
- if len(node.inputs) != 2 or len(node.outputs) != 1:
1269
- raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
1270
- data = evaluator.values[node.inputs[0]]
1271
- indices = evaluator.values[node.inputs[1]]
1272
- if indices.dtype.type not in {np.int32, np.int64}:
1273
- raise UnsupportedOpError(
1274
- f"Gather indices must be int32 or int64, got {indices.dtype}"
1275
- )
1276
- axis = normalize_axis(int(node.attrs.get("axis", 0)), data.shape, node)
1277
- evaluator.values[node.outputs[0]] = np.take(data, indices, axis=axis)
1278
-
1279
-
1280
- @register_evaluator("GatherND")
1281
- def _eval_gather_nd(evaluator: Evaluator, node: Node) -> None:
1282
- if len(node.inputs) != 2 or len(node.outputs) != 1:
1283
- raise UnsupportedOpError("GatherND must have 2 inputs and 1 output")
1284
- data = evaluator.values[node.inputs[0]]
1285
- indices = evaluator.values[node.inputs[1]]
1286
- if indices.dtype.type not in {np.int32, np.int64}:
1287
- raise UnsupportedOpError(
1288
- f"GatherND indices must be int32 or int64, got {indices.dtype}"
1289
- )
1290
- if indices.ndim < 1:
1291
- raise UnsupportedOpError("GatherND indices must have rank >= 1")
1292
- batch_dims = int(node.attrs.get("batch_dims", 0))
1293
- if batch_dims < 0:
1294
- raise UnsupportedOpError(
1295
- f"GatherND batch_dims must be >= 0, got {batch_dims}"
1296
- )
1297
- if batch_dims > indices.ndim - 1:
1298
- raise UnsupportedOpError(
1299
- "GatherND batch_dims must be <= indices rank - 1, "
1300
- f"got {batch_dims} vs {indices.ndim - 1}"
1301
- )
1302
- if batch_dims > data.ndim:
1303
- raise UnsupportedOpError(
1304
- "GatherND batch_dims must be <= data rank, "
1305
- f"got {batch_dims} vs {data.ndim}"
1306
- )
1307
- if tuple(data.shape[:batch_dims]) != tuple(indices.shape[:batch_dims]):
1308
- raise UnsupportedOpError(
1309
- "GatherND batch_dims must match on data/indices, "
1310
- f"got {data.shape} vs {indices.shape}"
1311
- )
1312
- index_depth = indices.shape[-1]
1313
- if index_depth <= 0:
1314
- raise UnsupportedOpError(
1315
- "GatherND indices final dimension must be >= 1"
1316
- )
1317
- if index_depth > data.ndim - batch_dims:
1318
- raise UnsupportedOpError(
1319
- "GatherND indices final dimension must be <= data rank - "
1320
- f"batch_dims, got {index_depth} vs {data.ndim - batch_dims}"
1321
- )
1322
- tail_shape = data.shape[batch_dims + index_depth :]
1323
- output_shape = indices.shape[:-1] + tail_shape
1324
- output = np.empty(output_shape, dtype=data.dtype)
1325
- indices_prefix_shape = indices.shape[:-1]
1326
- prefix_iter = (
1327
- np.ndindex(*indices_prefix_shape) if indices_prefix_shape else [()]
1328
- )
1329
- for prefix in prefix_iter:
1330
- raw_index = indices[prefix]
1331
- if index_depth == 1:
1332
- index_values = [int(np.asarray(raw_index).item())]
1333
- else:
1334
- index_values = [int(value) for value in raw_index]
1335
- for dim_index, value in enumerate(index_values):
1336
- if value < 0:
1337
- index_values[dim_index] = value + data.shape[
1338
- batch_dims + dim_index
1339
- ]
1340
- data_index = list(prefix[:batch_dims]) + index_values
1341
- data_index.extend([slice(None)] * len(tail_shape))
1342
- output_index = prefix + (slice(None),) * len(tail_shape)
1343
- output[output_index] = data[tuple(data_index)]
1344
- evaluator.values[node.outputs[0]] = output
1345
-
1346
-
1347
- @register_evaluator("Slice")
1348
- def _eval_slice(evaluator: Evaluator, node: Node) -> None:
1349
- input_value = evaluator.values[node.inputs[0]]
1350
- if "starts" in node.attrs or "ends" in node.attrs:
1351
- starts = [int(value) for value in node.attrs.get("starts", [])]
1352
- ends = [int(value) for value in node.attrs.get("ends", [])]
1353
- axes_attr = node.attrs.get("axes")
1354
- axes = [int(value) for value in axes_attr] if axes_attr else None
1355
- steps = None
1356
- else:
1357
- if len(node.inputs) < 3:
1358
- raise UnsupportedOpError(
1359
- f"{node.op_type} expects at least 3 inputs"
1360
- )
1361
- starts_value = evaluator.values[node.inputs[1]]
1362
- ends_value = evaluator.values[node.inputs[2]]
1363
- if starts_value.dtype.type not in {np.int32, np.int64}:
1364
- raise UnsupportedOpError(
1365
- f"{node.op_type} starts input must be int64 or int32"
1366
- )
1367
- if ends_value.dtype.type not in {np.int32, np.int64}:
1368
- raise UnsupportedOpError(
1369
- f"{node.op_type} ends input must be int64 or int32"
1370
- )
1371
- starts = [int(value) for value in starts_value.reshape(-1)]
1372
- ends = [int(value) for value in ends_value.reshape(-1)]
1373
- axes = None
1374
- steps = None
1375
- if len(node.inputs) >= 4 and node.inputs[3]:
1376
- axes_value = evaluator.values[node.inputs[3]]
1377
- if axes_value.dtype.type not in {np.int32, np.int64}:
1378
- raise UnsupportedOpError(
1379
- f"{node.op_type} axes input must be int64 or int32"
1380
- )
1381
- axes = [int(value) for value in axes_value.reshape(-1)]
1382
- if len(node.inputs) >= 5 and node.inputs[4]:
1383
- steps_value = evaluator.values[node.inputs[4]]
1384
- if steps_value.dtype.type not in {np.int32, np.int64}:
1385
- raise UnsupportedOpError(
1386
- f"{node.op_type} steps input must be int64 or int32"
1387
- )
1388
- steps = [int(value) for value in steps_value.reshape(-1)]
1389
- normalized_starts, normalized_steps, output_shape = _normalize_slices(
1390
- input_value.shape, starts, ends, axes, steps, node
1391
- )
1392
- slices = tuple(
1393
- slice(start, start + step * size, step)
1394
- for start, step, size in zip(
1395
- normalized_starts, normalized_steps, output_shape
1396
- )
1397
- )
1398
- evaluator.values[node.outputs[0]] = input_value[slices]
1399
-
1400
-
1401
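The Slice evaluator above leans on _normalize_slices to clamp the starts/ends and derive the output size before building ordinary Python slices. A simplified standalone sketch for the positive-step case only (the helper name one_axis_slice is made up for illustration; the real helper also handles negative steps and per-axis selection):

import numpy as np

def one_axis_slice(dim, start, end, step):
    # Resolve negative positions, clamp to [0, dim], then size the output.
    start = start + dim if start < 0 else start
    end = end + dim if end < 0 else end
    start = min(max(start, 0), dim)
    end = min(max(end, 0), dim)
    size = max(0, -(-(end - start) // step))  # ceil((end - start) / step)
    return slice(start, start + step * size, step), size

x = np.arange(10)
sl, size = one_axis_slice(x.shape[0], -7, 1000, 2)
print(x[sl], size)  # [3 5 7 9] 4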
- @register_evaluator("Attention")
1402
- def _eval_attention(evaluator: Evaluator, node: Node) -> None:
1403
- input_q = node.inputs[0]
1404
- input_k = node.inputs[1]
1405
- input_v = node.inputs[2]
1406
- output_y = node.outputs[0]
1407
- op_dtype = node_dtype(evaluator.graph, node, input_q, input_k, input_v, output_y)
1408
- spec = resolve_attention_spec(evaluator.graph, node, op_dtype)
1409
- attn_mask_name = optional_name(node.inputs, 3)
1410
- past_key_name = optional_name(node.inputs, 4)
1411
- past_value_name = optional_name(node.inputs, 5)
1412
- nonpad_name = optional_name(node.inputs, 6)
1413
- present_key_name = optional_name(node.outputs, 1)
1414
- present_value_name = optional_name(node.outputs, 2)
1415
- qk_matmul_output_name = optional_name(node.outputs, 3)
1416
- output, present_key, present_value, qk_output = _apply_attention(
1417
- spec,
1418
- evaluator.values[input_q],
1419
- evaluator.values[input_k],
1420
- evaluator.values[input_v],
1421
- evaluator.values[attn_mask_name] if attn_mask_name else None,
1422
- evaluator.values[past_key_name] if past_key_name else None,
1423
- evaluator.values[past_value_name] if past_value_name else None,
1424
- evaluator.values[nonpad_name] if nonpad_name else None,
1425
- )
1426
- evaluator.values[output_y] = output
1427
- if present_key_name is not None:
1428
- evaluator.values[present_key_name] = present_key
1429
- if present_value_name is not None:
1430
- evaluator.values[present_value_name] = present_value
1431
- if qk_matmul_output_name is not None:
1432
- evaluator.values[qk_matmul_output_name] = qk_output
1433
-
1434
-
1435
- @register_evaluator("RotaryEmbedding")
1436
- def _eval_rotary_embedding(evaluator: Evaluator, node: Node) -> None:
1437
- op = lower_rotary_embedding(evaluator.graph, node)
1438
- x = evaluator.values[op.input0]
1439
- cos_cache = evaluator.values[op.cos_cache]
1440
- sin_cache = evaluator.values[op.sin_cache]
1441
- position_ids = (
1442
- evaluator.values[op.position_ids] if op.position_ids else None
1443
- )
1444
- original_shape = x.shape
1445
- if op.input_rank == 4:
1446
- x = np.transpose(x, (0, 2, 1, 3))
1447
- else:
1448
- x = x.reshape(op.batch, op.seq_len, op.num_heads, op.head_size)
1449
- x_rotate = x[..., : op.rotary_dim]
1450
- x_not_rotate = x[..., op.rotary_dim :]
1451
- if position_ids is not None:
1452
- cos_cache = cos_cache[position_ids]
1453
- sin_cache = sin_cache[position_ids]
1454
- cos_cache = np.expand_dims(cos_cache, axis=2)
1455
- sin_cache = np.expand_dims(sin_cache, axis=2)
1456
- if op.interleaved:
1457
- x1 = x_rotate[..., 0::2]
1458
- x2 = x_rotate[..., 1::2]
1459
- else:
1460
- x1, x2 = np.split(x_rotate, 2, axis=-1)
1461
- real = (cos_cache * x1) - (sin_cache * x2)
1462
- imag = (sin_cache * x1) + (cos_cache * x2)
1463
- if op.interleaved:
1464
- real = np.expand_dims(real, axis=-1)
1465
- imag = np.expand_dims(imag, axis=-1)
1466
- x_rotate_concat = np.concatenate((real, imag), axis=-1)
1467
- x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
1468
- else:
1469
- x_rotate = np.concatenate((real, imag), axis=-1)
1470
- output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
1471
- if op.input_rank == 4:
1472
- output = np.transpose(output, (0, 2, 1, 3))
1473
- else:
1474
- output = output.reshape(original_shape)
1475
- evaluator.values[node.outputs[0]] = output
1476
-
1477
-
1478
-def _apply_lstm_activation(
-    kind: int, value: np.ndarray, alpha: float, beta: float
-) -> np.ndarray:
-    if kind == ACTIVATION_KIND_BY_NAME["Relu"]:
-        return np.maximum(value, 0)
-    if kind == ACTIVATION_KIND_BY_NAME["Tanh"]:
-        return np.tanh(value)
-    if kind == ACTIVATION_KIND_BY_NAME["Sigmoid"]:
-        return 1 / (1 + np.exp(-value))
-    if kind == ACTIVATION_KIND_BY_NAME["Affine"]:
-        return alpha * value + beta
-    if kind == ACTIVATION_KIND_BY_NAME["LeakyRelu"]:
-        return np.where(value < 0, alpha * value, value)
-    if kind == ACTIVATION_KIND_BY_NAME["ThresholdedRelu"]:
-        return np.where(value > alpha, value, 0)
-    if kind == ACTIVATION_KIND_BY_NAME["ScaledTanh"]:
-        return alpha * np.tanh(beta * value)
-    if kind == ACTIVATION_KIND_BY_NAME["HardSigmoid"]:
-        return np.clip(alpha * value + beta, 0, 1)
-    if kind == ACTIVATION_KIND_BY_NAME["Elu"]:
-        return np.where(value >= 0, value, alpha * (np.exp(value) - 1))
-    if kind == ACTIVATION_KIND_BY_NAME["Softsign"]:
-        return value / (1 + np.abs(value))
-    if kind == ACTIVATION_KIND_BY_NAME["Softplus"]:
-        return np.log1p(np.exp(value))
-    raise UnsupportedOpError(f"Unsupported LSTM activation kind {kind}")
-
-
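Most of the activation kinds above are one-line numpy formulas. A quick standalone check of a few of them (the alpha/beta values are arbitrary, and the real code dispatches through ACTIVATION_KIND_BY_NAME rather than by name):

import numpy as np

x = np.linspace(-2.0, 2.0, 5)

softsign = x / (1 + np.abs(x))
softplus = np.log1p(np.exp(x))
hard_sigmoid = np.clip(0.2 * x + 0.5, 0, 1)   # alpha=0.2, beta=0.5 assumed

print(softsign)
print(softplus)
print(hard_sigmoid)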
- @register_evaluator("LSTM")
1507
- def _eval_lstm(evaluator: Evaluator, node: Node) -> None:
1508
- spec = resolve_lstm_spec(evaluator.graph, node)
1509
- inputs = evaluator.values
1510
- x = inputs[spec.input_x]
1511
- w = inputs[spec.input_w]
1512
- r = inputs[spec.input_r]
1513
- b = inputs[spec.input_b] if spec.input_b is not None else None
1514
- sequence_lens = (
1515
- inputs[spec.input_sequence_lens]
1516
- if spec.input_sequence_lens is not None
1517
- else None
1518
- )
1519
- initial_h = (
1520
- inputs[spec.input_initial_h]
1521
- if spec.input_initial_h is not None
1522
- else None
1523
- )
1524
- initial_c = (
1525
- inputs[spec.input_initial_c]
1526
- if spec.input_initial_c is not None
1527
- else None
1528
- )
1529
- p = inputs[spec.input_p] if spec.input_p is not None else None
1530
- output_y, output_y_h, output_y_c = _apply_lstm(
1531
- spec,
1532
- x,
1533
- w,
1534
- r,
1535
- b,
1536
- sequence_lens,
1537
- initial_h,
1538
- initial_c,
1539
- p,
1540
- )
1541
- if spec.output_y is not None:
1542
- evaluator.values[spec.output_y] = output_y
1543
- if spec.output_y_h is not None:
1544
- evaluator.values[spec.output_y_h] = output_y_h
1545
- if spec.output_y_c is not None:
1546
- evaluator.values[spec.output_y_c] = output_y_c
1547
-
1548
-
1549
- @register_evaluator("Conv")
1550
- def _eval_conv(evaluator: Evaluator, node: Node) -> None:
1551
- op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1552
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1553
- if op_dtype != output_dtype:
1554
- raise UnsupportedOpError(
1555
- f"{node.op_type} expects matching input/output dtypes, "
1556
- f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
1557
- )
1558
- if not op_dtype.is_float:
1559
- raise UnsupportedOpError(
1560
- "Conv supports float16, float, and double inputs only"
1561
- )
1562
- spec = resolve_conv_spec(evaluator.graph, node)
1563
- data = evaluator.values[node.inputs[0]]
1564
- weights = evaluator.values[node.inputs[1]]
1565
- bias = evaluator.values[node.inputs[2]] if len(node.inputs) > 2 else None
1566
- evaluator.values[node.outputs[0]] = _apply_conv(spec, data, weights, bias)
1567
-
1568
-
1569
- @register_evaluator("ConvTranspose")
1570
- def _eval_conv_transpose(evaluator: Evaluator, node: Node) -> None:
1571
- op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1572
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1573
- if op_dtype != output_dtype:
1574
- raise UnsupportedOpError(
1575
- f"{node.op_type} expects matching input/output dtypes, "
1576
- f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
1577
- )
1578
- if not op_dtype.is_float:
1579
- raise UnsupportedOpError(
1580
- "ConvTranspose supports float16, float, and double inputs only"
1581
- )
1582
- spec = resolve_conv_transpose_spec(evaluator.graph, node)
1583
- data = evaluator.values[node.inputs[0]]
1584
- weights = evaluator.values[node.inputs[1]]
1585
- bias = evaluator.values[node.inputs[2]] if len(node.inputs) > 2 else None
1586
- evaluator.values[node.outputs[0]] = _apply_conv_transpose(
1587
- spec, data, weights, bias
1588
- )
1589
-
1590
-
1591
- @register_evaluator("BatchNormalization")
1592
- def _eval_batch_norm(evaluator: Evaluator, node: Node) -> None:
1593
- op = lower_batch_normalization(evaluator.graph, node)
1594
- data = evaluator.values[op.input0]
1595
- scale = evaluator.values[op.scale].reshape(
1596
- (1, op.channels) + (1,) * (data.ndim - 2)
1597
- )
1598
- bias = evaluator.values[op.bias].reshape(
1599
- (1, op.channels) + (1,) * (data.ndim - 2)
1600
- )
1601
- mean = evaluator.values[op.mean].reshape(
1602
- (1, op.channels) + (1,) * (data.ndim - 2)
1603
- )
1604
- variance = evaluator.values[op.variance].reshape(
1605
- (1, op.channels) + (1,) * (data.ndim - 2)
1606
- )
1607
- evaluator.values[op.output] = (
1608
- (data - mean) / np.sqrt(variance + op.epsilon) * scale + bias
1609
- )
1610
-
1611
-
1612
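The BatchNormalization evaluator above is the plain inference-time formula, with the per-channel parameters reshaped to (1, C, 1, ..., 1) so they broadcast over the batch and spatial dimensions. A standalone sketch with made-up shapes and values:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4)).astype(np.float32)   # NCHW
scale = np.array([1.0, 2.0, 0.5], dtype=np.float32)
bias = np.array([0.0, -1.0, 0.1], dtype=np.float32)
mean = np.array([0.1, 0.2, 0.3], dtype=np.float32)
var = np.array([1.0, 0.5, 2.0], dtype=np.float32)
eps = 1e-5

shape = (1, 3, 1, 1)   # broadcast per-channel parameters over N, H, W
y = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
y = y * scale.reshape(shape) + bias.reshape(shape)
print(y.shape)  # (2, 3, 4, 4)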
- @register_evaluator("LpNormalization")
1613
- def _eval_lp_normalization(evaluator: Evaluator, node: Node) -> None:
1614
- op = lower_lp_normalization(evaluator.graph, node)
1615
- data = evaluator.values[op.input0]
1616
- if op.p == 1:
1617
- denom = np.sum(np.abs(data), axis=op.axis, keepdims=True)
1618
- else:
1619
- denom = np.sqrt(np.sum(data * data, axis=op.axis, keepdims=True))
1620
- evaluator.values[op.output] = data / denom
1621
-
1622
-
1623
- @register_evaluator("LpPool")
1624
- def _eval_lp_pool(evaluator: Evaluator, node: Node) -> None:
1625
- op = lower_lp_pool(evaluator.graph, node)
1626
- data = evaluator.values[op.input0]
1627
- output = np.zeros(
1628
- (op.batch, op.channels, op.out_h, op.out_w), dtype=data.dtype
1629
- )
1630
- for n in range(op.batch):
1631
- for c in range(op.channels):
1632
- for out_h in range(op.out_h):
1633
- for out_w in range(op.out_w):
1634
- h_start = out_h * op.stride_h - op.pad_top
1635
- w_start = out_w * op.stride_w - op.pad_left
1636
- acc = 0.0
1637
- for kh in range(op.kernel_h):
1638
- for kw in range(op.kernel_w):
1639
- in_h = h_start + kh
1640
- in_w = w_start + kw
1641
- if (
1642
- 0 <= in_h < op.in_h
1643
- and 0 <= in_w < op.in_w
1644
- ):
1645
- value = data[(n, c, in_h, in_w)]
1646
- acc += abs(value) ** op.p
1647
- output[(n, c, out_h, out_w)] = acc ** (1.0 / op.p)
1648
- evaluator.values[op.output] = output
1649
-
1650
-
1651
- @register_evaluator("QuantizeLinear")
1652
- def _eval_quantize_linear(evaluator: Evaluator, node: Node) -> None:
1653
- spec = resolve_quantize_spec(evaluator.graph, node)
1654
- data = evaluator.values[node.inputs[0]]
1655
- scale = evaluator.values[node.inputs[1]]
1656
- zero_point_name = optional_name(node.inputs, 2)
1657
- if zero_point_name is None:
1658
- zero_point = 0
1659
- else:
1660
- zero_point = evaluator.values[zero_point_name]
1661
- if spec.axis is None:
1662
- scaled = data / scale + zero_point
1663
- else:
1664
- shape = [1] * data.ndim
1665
- shape[spec.axis] = scale.shape[0]
1666
- scaled = data / scale.reshape(shape) + np.asarray(zero_point).reshape(
1667
- shape
1668
- )
1669
- rounded = np.rint(scaled)
1670
- info = np.iinfo(spec.output_dtype.np_dtype)
1671
- clipped = np.clip(rounded, info.min, info.max)
1672
- evaluator.values[node.outputs[0]] = clipped.astype(
1673
- spec.output_dtype.np_dtype, copy=False
1674
- )
1675
-
1676
-
1677
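The QuantizeLinear evaluator reduces to divide by scale, add the zero point, round, clip to the target integer range, and cast. A per-tensor standalone sketch with invented scale and zero-point values:

import numpy as np

x = np.array([-1.5, -0.3, 0.0, 0.7, 2.4], dtype=np.float32)
scale = np.float32(0.02)
zero_point = np.int32(10)

scaled = x / scale + zero_point
q = np.clip(np.rint(scaled), -128, 127).astype(np.int8)
print(q)  # [-65  -5  10  45 127]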
- @register_evaluator("QLinearMatMul")
1678
- def _eval_qlinear_matmul(evaluator: Evaluator, node: Node) -> None:
1679
- op = lower_qlinear_matmul(evaluator.graph, node)
1680
- input0 = evaluator.values[op.input0]
1681
- input1 = evaluator.values[op.input1]
1682
- input0_scale = evaluator.values[op.input0_scale]
1683
- input1_scale = evaluator.values[op.input1_scale]
1684
- output_scale = evaluator.values[op.output_scale]
1685
- input0_zero_point = evaluator.values[op.input0_zero_point]
1686
- input1_zero_point = evaluator.values[op.input1_zero_point]
1687
- output_zero_point = evaluator.values[op.output_zero_point]
1688
-
1689
- def _scalar_value(array: np.ndarray) -> float:
1690
- return float(np.asarray(array).reshape(-1)[0])
1691
-
1692
- def _scalar_int(array: np.ndarray) -> int:
1693
- return int(np.asarray(array).reshape(-1)[0])
1694
-
1695
- input0_zero = _scalar_int(input0_zero_point)
1696
- input1_zero = _scalar_int(input1_zero_point)
1697
- output_zero = _scalar_int(output_zero_point)
1698
- scale = _scalar_value(input0_scale) * _scalar_value(
1699
- input1_scale
1700
- ) / _scalar_value(output_scale)
1701
- acc = _apply_matmul(
1702
- input0.astype(np.int32) - input0_zero,
1703
- input1.astype(np.int32) - input1_zero,
1704
- )
1705
- scaled = acc.astype(np.float64) * scale + output_zero
1706
- rounded = np.rint(scaled)
1707
- info = np.iinfo(op.dtype.np_dtype)
1708
- clipped = np.clip(rounded, info.min, info.max)
1709
- evaluator.values[op.output] = clipped.astype(op.dtype.np_dtype)
1710
-
1711
- @register_evaluator("InstanceNormalization")
1712
- def _eval_instance_normalization(evaluator: Evaluator, node: Node) -> None:
1713
- op = lower_instance_normalization(evaluator.graph, node)
1714
- data = evaluator.values[op.input0]
1715
- axes = tuple(range(2, data.ndim))
1716
- mean = np.mean(data, axis=axes, keepdims=True)
1717
- var = np.mean((data - mean) ** 2, axis=axes, keepdims=True)
1718
- scale = evaluator.values[op.scale].reshape(
1719
- (1, op.channels) + (1,) * (data.ndim - 2)
1720
- )
1721
- bias = evaluator.values[op.bias].reshape(
1722
- (1, op.channels) + (1,) * (data.ndim - 2)
1723
- )
1724
- evaluator.values[op.output] = (
1725
- (data - mean) / np.sqrt(var + op.epsilon) * scale + bias
1726
- )
1727
-
1728
-
1729
- @register_evaluator("GroupNormalization")
1730
- def _eval_group_normalization(evaluator: Evaluator, node: Node) -> None:
1731
- op = lower_group_normalization(evaluator.graph, node)
1732
- data = evaluator.values[op.input0]
1733
- batch = data.shape[0]
1734
- spatial_shape = data.shape[2:]
1735
- grouped = data.reshape(
1736
- (batch, op.num_groups, op.group_size) + spatial_shape
1737
- )
1738
- axes = tuple(range(2, grouped.ndim))
1739
- mean = np.mean(grouped, axis=axes, keepdims=True)
1740
- var = np.mean((grouped - mean) ** 2, axis=axes, keepdims=True)
1741
- normalized = (grouped - mean) / np.sqrt(var + op.epsilon)
1742
- normalized = normalized.reshape(data.shape)
1743
- scale = evaluator.values[op.scale].reshape(
1744
- (1, op.channels) + (1,) * (data.ndim - 2)
1745
- )
1746
- bias = evaluator.values[op.bias].reshape(
1747
- (1, op.channels) + (1,) * (data.ndim - 2)
1748
- )
1749
- evaluator.values[op.output] = normalized * scale + bias
1750
-
1751
-
1752
- @register_evaluator("LayerNormalization")
1753
- def _eval_layer_normalization(evaluator: Evaluator, node: Node) -> None:
1754
- op = lower_layer_normalization(evaluator.graph, node)
1755
- data = evaluator.values[op.input0]
1756
- axes = tuple(range(op.axis, data.ndim))
1757
- mean = np.mean(data, axis=axes, keepdims=True)
1758
- var = np.mean((data - mean) ** 2, axis=axes, keepdims=True)
1759
- inv_std = 1.0 / np.sqrt(var + op.epsilon)
1760
- normalized = (data - mean) * inv_std
1761
- scale = evaluator.values[op.scale].reshape(
1762
- (1,) * op.axis + evaluator.values[op.scale].shape
1763
- )
1764
- normalized = normalized * scale
1765
- if op.bias is not None:
1766
- bias = evaluator.values[op.bias].reshape(
1767
- (1,) * op.axis + evaluator.values[op.bias].shape
1768
- )
1769
- normalized = normalized + bias
1770
- evaluator.values[op.output] = normalized
1771
- if op.mean_output is not None:
1772
- evaluator.values[op.mean_output] = mean
1773
- if op.invstd_output is not None:
1774
- evaluator.values[op.invstd_output] = inv_std
1775
-
1776
-
1777
- @register_evaluator("MeanVarianceNormalization")
1778
- def _eval_mean_variance_normalization(
1779
- evaluator: Evaluator, node: Node
1780
- ) -> None:
1781
- op = lower_mean_variance_normalization(evaluator.graph, node)
1782
- data = evaluator.values[op.input0]
1783
- mean = np.mean(data, axis=op.axes, keepdims=True)
1784
- variance = np.mean((data - mean) ** 2, axis=op.axes, keepdims=True)
1785
- evaluator.values[op.output] = (data - mean) / np.sqrt(
1786
- variance + op.epsilon
1787
- )
1788
-
1789
-
1790
- @register_evaluator("RMSNormalization")
1791
- def _eval_rms_normalization(evaluator: Evaluator, node: Node) -> None:
1792
- op = lower_rms_normalization(evaluator.graph, node)
1793
- data = evaluator.values[op.input0]
1794
- axes = tuple(range(op.axis, data.ndim))
1795
- mean_square = np.mean(data * data, axis=axes, keepdims=True)
1796
- rms = np.sqrt(mean_square + op.epsilon)
1797
- normalized = data / rms
1798
- scale = evaluator.values[op.scale].reshape(
1799
- (1,) * op.axis + evaluator.values[op.scale].shape
1800
- )
1801
- evaluator.values[op.output] = normalized * scale
1802
-
1803
-
1804
- @register_evaluator("LRN")
1805
- def _eval_lrn(evaluator: Evaluator, node: Node) -> None:
1806
- op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1807
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1808
- if op_dtype != output_dtype:
1809
- raise UnsupportedOpError(
1810
- f"{node.op_type} expects matching input/output dtypes, "
1811
- f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
1812
- )
1813
- if not op_dtype.is_float:
1814
- raise UnsupportedOpError(
1815
- "LRN supports float16, float, and double inputs only"
1816
- )
1817
- spec = resolve_lrn_spec(evaluator.graph, node)
1818
- data = evaluator.values[node.inputs[0]]
1819
- evaluator.values[node.outputs[0]] = _apply_lrn(spec, data)
1820
-
1821
-
1822
- @register_evaluator("AveragePool")
1823
- def _eval_average_pool(evaluator: Evaluator, node: Node) -> None:
1824
- op = lower_average_pool(evaluator.graph, node)
1825
- data = evaluator.values[node.inputs[0]]
1826
- evaluator.values[node.outputs[0]] = _apply_average_pool(op, data)
1827
-
1828
-
1829
- @register_evaluator("GlobalAveragePool")
1830
- def _eval_global_average_pool(evaluator: Evaluator, node: Node) -> None:
1831
- op = lower_global_average_pool(evaluator.graph, node)
1832
- data = evaluator.values[node.inputs[0]]
1833
- evaluator.values[node.outputs[0]] = _apply_average_pool(op, data)
1834
-
1835
-
1836
- @register_evaluator("MaxPool")
1837
- def _eval_maxpool(evaluator: Evaluator, node: Node) -> None:
1838
- op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
1839
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
1840
- if op_dtype != output_dtype:
1841
- raise UnsupportedOpError(
1842
- f"{node.op_type} expects matching input/output dtypes, "
1843
- f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
1844
- )
1845
- indices_output = node.outputs[1] if len(node.outputs) > 1 else None
1846
- if indices_output is not None:
1847
- indices_dtype = value_dtype(evaluator.graph, indices_output, node)
1848
- if indices_dtype != ScalarType.I64:
1849
- raise UnsupportedOpError("MaxPool indices output must be int64")
1850
- if op_dtype == ScalarType.BOOL:
1851
- raise UnsupportedOpError("MaxPool supports numeric inputs only")
1852
- spec = resolve_maxpool_spec(evaluator.graph, node)
1853
- data = evaluator.values[node.inputs[0]]
1854
- if indices_output is None:
1855
- evaluator.values[node.outputs[0]] = _apply_maxpool(spec, data)
1856
- else:
1857
- values, indices = _apply_maxpool(spec, data, return_indices=True)
1858
- evaluator.values[node.outputs[0]] = values
1859
- evaluator.values[indices_output] = indices
1860
-
1861
-
1862
- @register_evaluator("GlobalMaxPool")
1863
- def _eval_global_max_pool(evaluator: Evaluator, node: Node) -> None:
1864
- op = lower_global_max_pool(evaluator.graph, node)
1865
- value = evaluator.values[node.inputs[0]]
1866
- if not op.axes:
1867
- evaluator.values[node.outputs[0]] = value.copy()
1868
- return
1869
- evaluator.values[node.outputs[0]] = np.max(
1870
- value, axis=op.axes, keepdims=op.keepdims
1871
- )
1872
-
1873
-
1874
- @register_evaluator("Softmax")
1875
- def _eval_softmax(evaluator: Evaluator, node: Node) -> None:
1876
- op = lower_softmax(evaluator.graph, node)
1877
- value = evaluator.values[node.inputs[0]]
1878
- evaluator.values[node.outputs[0]] = _apply_softmax(value, op.axis)
1879
-
1880
-
1881
- @register_evaluator("LogSoftmax")
1882
- def _eval_logsoftmax(evaluator: Evaluator, node: Node) -> None:
1883
- op = lower_logsoftmax(evaluator.graph, node)
1884
- value = evaluator.values[node.inputs[0]]
1885
- evaluator.values[node.outputs[0]] = _apply_logsoftmax(value, op.axis)
1886
-
1887
-
1888
- @register_evaluator("Hardmax")
1889
- def _eval_hardmax(evaluator: Evaluator, node: Node) -> None:
1890
- op = lower_hardmax(evaluator.graph, node)
1891
- value = evaluator.values[node.inputs[0]]
1892
- max_values = np.max(value, axis=op.axis, keepdims=True)
1893
- is_max = value == max_values
1894
- max_index = np.argmax(is_max, axis=op.axis)
1895
- output = np.zeros_like(value)
1896
- ones = np.array(1.0, dtype=value.dtype)
1897
- np.put_along_axis(output, np.expand_dims(max_index, axis=op.axis), ones, axis=op.axis)
1898
- evaluator.values[node.outputs[0]] = output
1899
-
1900
-
1901
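The Hardmax evaluator builds its one-hot output with np.put_along_axis after taking the first index that attains the row maximum. The same trick on a small example (standalone sketch):

import numpy as np

x = np.array([[1.0, 3.0, 3.0], [0.5, 0.2, 0.1]])
axis = 1
idx = np.argmax(x, axis=axis)            # first max wins on ties -> [1, 0]
out = np.zeros_like(x)
np.put_along_axis(out, np.expand_dims(idx, axis), 1.0, axis=axis)
print(out)  # [[0. 1. 0.] [1. 0. 0.]]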
- @register_evaluator("NegativeLogLikelihoodLoss")
1902
- def _eval_negative_log_likelihood_loss(
1903
- evaluator: Evaluator, node: Node
1904
- ) -> None:
1905
- op = lower_negative_log_likelihood_loss(evaluator.graph, node)
1906
- input_value = evaluator.values[op.input0]
1907
- target_value = evaluator.values[op.target]
1908
- weight_value = evaluator.values[op.weight] if op.weight is not None else None
1909
- evaluator.values[op.output] = _apply_negative_log_likelihood_loss(
1910
- input_value,
1911
- target_value,
1912
- weight_value,
1913
- reduction=op.reduction,
1914
- ignore_index=op.ignore_index,
1915
- )
1916
-
1917
-
1918
- @register_evaluator("SoftmaxCrossEntropyLoss")
1919
- def _eval_softmax_cross_entropy_loss(
1920
- evaluator: Evaluator, node: Node
1921
- ) -> None:
1922
- op = lower_softmax_cross_entropy_loss(evaluator.graph, node)
1923
- input_value = evaluator.values[op.input0]
1924
- target_value = evaluator.values[op.target]
1925
- weight_value = evaluator.values[op.weight] if op.weight is not None else None
1926
- loss, log_prob = _apply_softmax_cross_entropy_loss(
1927
- input_value,
1928
- target_value,
1929
- weight_value,
1930
- reduction=op.reduction,
1931
- ignore_index=op.ignore_index,
1932
- return_log_prob=op.log_prob is not None,
1933
- )
1934
- evaluator.values[op.output] = loss
1935
- if op.log_prob is not None and log_prob is not None:
1936
- evaluator.values[op.log_prob] = log_prob
1937
-
1938
-
1939
- @register_evaluator("Dropout")
1940
- def _eval_dropout(evaluator: Evaluator, node: Node) -> None:
1941
- op = lower_dropout(evaluator.graph, node)
1942
- evaluator.values[op.output] = evaluator.values[op.input0].copy()
1943
-
1944
-
1945
- @register_evaluator("Concat")
1946
- def _eval_concat(evaluator: Evaluator, node: Node) -> None:
1947
- op = lower_concat(evaluator.graph, node)
1948
- tensors = [evaluator.values[name] for name in node.inputs]
1949
- evaluator.values[op.output] = np.concatenate(tensors, axis=op.axis)
1950
-
1951
-
1952
- @register_evaluator("Transpose")
1953
- def _eval_transpose(evaluator: Evaluator, node: Node) -> None:
1954
- op = lower_transpose(evaluator.graph, node)
1955
- evaluator.values[op.output] = np.transpose(
1956
- evaluator.values[op.input0], axes=tuple(op.perm)
1957
- )
1958
-
1959
-
1960
- @register_evaluator("Unsqueeze")
1961
- def _eval_unsqueeze(evaluator: Evaluator, node: Node) -> None:
1962
- op = lower_unsqueeze(evaluator.graph, node)
1963
- evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1964
- op.output_shape
1965
- )
1966
-
1967
-
1968
- @register_evaluator("Squeeze")
1969
- def _eval_squeeze(evaluator: Evaluator, node: Node) -> None:
1970
- op = lower_squeeze(evaluator.graph, node)
1971
- evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1972
- op.output_shape
1973
- )
1974
-
1975
-
1976
- @register_evaluator("Reshape")
1977
- def _eval_reshape(evaluator: Evaluator, node: Node) -> None:
1978
- op = lower_reshape(evaluator.graph, node)
1979
- evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1980
- op.output_shape
1981
- )
1982
-
1983
-
1984
- @register_evaluator("Flatten")
1985
- def _eval_flatten(evaluator: Evaluator, node: Node) -> None:
1986
- op = lower_flatten(evaluator.graph, node)
1987
- evaluator.values[op.output] = evaluator.values[op.input0].reshape(
1988
- op.output_shape
1989
- )
1990
-
1991
-
1992
- @register_evaluator("ConstantOfShape")
1993
- def _eval_constant_of_shape(evaluator: Evaluator, node: Node) -> None:
1994
- op = lower_constant_of_shape(evaluator.graph, node)
1995
- evaluator.values[op.output] = np.full(
1996
- op.shape, op.value, dtype=op.dtype.np_dtype
1997
- )
1998
-
1999
-
2000
- @register_evaluator("Shape")
2001
- def _eval_shape(evaluator: Evaluator, node: Node) -> None:
2002
- op = lower_shape(evaluator.graph, node)
2003
- evaluator.values[op.output] = np.array(op.values, dtype=np.int64)
2004
-
2005
-
2006
- @register_evaluator("Size")
2007
- def _eval_size(evaluator: Evaluator, node: Node) -> None:
2008
- op = lower_size(evaluator.graph, node)
2009
- evaluator.values[op.output] = np.array(op.value, dtype=np.int64)
2010
-
2011
-
2012
- @register_evaluator("NonZero")
2013
- def _eval_nonzero(evaluator: Evaluator, node: Node) -> None:
2014
- op = lower_nonzero(evaluator.graph, node)
2015
- values = evaluator.values[op.input0]
2016
- indices = np.nonzero(values)
2017
- evaluator.values[op.output] = np.stack(indices, axis=0).astype(
2018
- np.int64, copy=False
2019
- )
2020
-
2021
-
2022
- @register_evaluator("Expand")
2023
- def _eval_expand(evaluator: Evaluator, node: Node) -> None:
2024
- op = lower_expand(evaluator.graph, node)
2025
- value = evaluator.values[op.input0]
2026
- op_ctx = OpContext(GraphContext(evaluator.graph))
2027
- op.validate(op_ctx)
2028
- op.infer_types(op_ctx)
2029
- op.infer_shapes(op_ctx)
2030
- output_shape = op_ctx.shape(op.output)
2031
- evaluator.values[op.output] = np.broadcast_to(
2032
- value, output_shape
2033
- ).copy()
2034
-
2035
-
2036
- @register_evaluator("Range")
2037
- def _eval_range(evaluator: Evaluator, node: Node) -> None:
2038
- op = lower_range(evaluator.graph, node)
2039
- start_value = evaluator.values[op.start].reshape(-1)[0]
2040
- delta_value = evaluator.values[op.delta].reshape(-1)[0]
2041
- indices = np.arange(op.length, dtype=op.dtype.np_dtype)
2042
- output = start_value + indices * delta_value
2043
- evaluator.values[op.output] = output
2044
-
2045
-
2046
- @register_evaluator("OneHot")
2047
- def _eval_onehot(evaluator: Evaluator, node: Node) -> None:
2048
- op = lower_onehot(evaluator.graph, node)
2049
- indices = evaluator.values[op.indices].astype(np.int64, copy=False)
2050
- depth_values = evaluator.values[op.depth].reshape(-1)
2051
- if depth_values.size != 1:
2052
- raise UnsupportedOpError("OneHot depth input must be a scalar")
2053
- depth_value = int(depth_values[0])
2054
- if depth_value < 0:
2055
- raise UnsupportedOpError("OneHot depth must be non-negative")
2056
- values = evaluator.values[op.values].reshape(-1)
2057
- if values.size != 2:
2058
- raise UnsupportedOpError("OneHot values input must have 2 elements")
2059
- off_value, on_value = values[0], values[1]
2060
- if depth_value == 0:
2061
- evaluator.values[op.output] = np.full(
2062
- op.output_shape, off_value, dtype=values.dtype
2063
- )
2064
- return
2065
- axis = op.axis
2066
- rank = indices.ndim
2067
- if axis < 0:
2068
- axis += rank + 1
2069
- depth_range = np.arange(depth_value, dtype=np.int64)
2070
- new_shape = (1,) * axis + (depth_value,) + (1,) * (rank - axis)
2071
- targets = depth_range.reshape(new_shape)
2072
- adjusted = np.mod(indices, depth_value) if depth_value > 0 else indices
2073
- values_reshaped = np.reshape(
2074
- adjusted, indices.shape[:axis] + (1,) + indices.shape[axis:]
2075
- )
2076
- valid_mask = (indices >= -depth_value) & (indices < depth_value)
2077
- valid_mask = np.reshape(
2078
- valid_mask, indices.shape[:axis] + (1,) + indices.shape[axis:]
2079
- )
2080
- one_hot = (targets == values_reshaped) & valid_mask
2081
- output = np.where(one_hot, on_value, off_value).astype(values.dtype)
2082
- evaluator.values[op.output] = output
2083
-
2084
-
2085
- @register_evaluator("Split")
2086
- def _eval_split(evaluator: Evaluator, node: Node) -> None:
2087
- op = lower_split(evaluator.graph, node)
2088
- data = evaluator.values[op.input0]
2089
- split_points = np.cumsum(op.split_sizes)[:-1]
2090
- outputs = np.split(data, split_points, axis=op.axis)
2091
- for output_name, output_value in zip(op.outputs, outputs):
2092
- evaluator.values[output_name] = output_value
2093
-
2094
-
2095
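np.split expects cut points rather than chunk sizes, which is why the Split evaluator converts split_sizes with np.cumsum first. A standalone sketch:

import numpy as np

data = np.arange(10)
split_sizes = [3, 3, 4]
points = np.cumsum(split_sizes)[:-1]      # [3, 6]
parts = np.split(data, points, axis=0)
print([p.tolist() for p in parts])        # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]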
- @register_evaluator("ReduceMean")
2096
- @register_evaluator("ReduceSum")
2097
- @register_evaluator("ReduceMax")
2098
- @register_evaluator("ReduceMin")
2099
- @register_evaluator("ReduceProd")
2100
- @register_evaluator("ReduceL1")
2101
- @register_evaluator("ReduceL2")
2102
- @register_evaluator("ReduceLogSum")
2103
- @register_evaluator("ReduceLogSumExp")
2104
- @register_evaluator("ReduceSumSquare")
2105
- def _eval_reduce(evaluator: Evaluator, node: Node) -> None:
2106
- if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
2107
- raise UnsupportedOpError(
2108
- f"{node.op_type} must have 1 or 2 inputs and 1 output"
2109
- )
2110
- op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
2111
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
2112
- if op_dtype != output_dtype:
2113
- raise UnsupportedOpError(
2114
- f"{node.op_type} expects matching input/output dtypes, "
2115
- f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
2116
- )
2117
- if (
2118
- node.op_type in REDUCE_OUTPUTS_FLOAT_ONLY
2119
- and not op_dtype.is_float
2120
- ):
2121
- raise UnsupportedOpError(
2122
- f"{node.op_type} supports float16, float, and double inputs only"
2123
- )
2124
- value = evaluator.values[node.inputs[0]]
2125
- input_shape = value.shape
2126
- if len(node.inputs) > 1 and node.inputs[1]:
2127
- axes_value = evaluator.values[node.inputs[1]]
2128
- if axes_value.dtype.type not in {np.int32, np.int64}:
2129
- raise UnsupportedOpError(
2130
- f"{node.op_type} axes input must be int64 or int32"
2131
- )
2132
- axes = tuple(int(axis) for axis in axes_value.ravel())
2133
- noop_with_empty_axes = bool(int(node.attrs.get("noop_with_empty_axes", 0)))
2134
- if not axes:
2135
- if noop_with_empty_axes:
2136
- evaluator.values[node.outputs[0]] = value.copy()
2137
- return
2138
- axes = tuple(range(len(input_shape)))
2139
- axes = normalize_reduce_axes(axes, input_shape, node)
2140
- else:
2141
- axes_spec, noop = resolve_reduce_axes(evaluator.graph, node, input_shape)
2142
- if noop:
2143
- evaluator.values[node.outputs[0]] = value.copy()
2144
- return
2145
- if axes_spec is None or axes_spec.axes is None:
2146
- raise UnsupportedOpError(
2147
- f"{node.op_type} axes input must be constant for evaluator"
2148
- )
2149
- axes = axes_spec.axes
2150
- keepdims = bool(int(node.attrs.get("keepdims", 1)))
2151
- reduce_kind = REDUCE_KIND_BY_OP[node.op_type]
2152
- if reduce_kind == "sum":
2153
- result = np.sum(value, axis=axes, keepdims=keepdims)
2154
- elif reduce_kind == "mean":
2155
- result = np.mean(value, axis=axes, keepdims=keepdims)
2156
- elif reduce_kind == "max":
2157
- result = np.max(value, axis=axes, keepdims=keepdims)
2158
- elif reduce_kind == "min":
2159
- result = np.min(value, axis=axes, keepdims=keepdims)
2160
- elif reduce_kind == "prod":
2161
- result = np.prod(value, axis=axes, keepdims=keepdims)
2162
- elif reduce_kind == "l1":
2163
- result = np.sum(np.abs(value), axis=axes, keepdims=keepdims)
2164
- elif reduce_kind == "l2":
2165
- result = np.sqrt(np.sum(value * value, axis=axes, keepdims=keepdims))
2166
- elif reduce_kind == "logsum":
2167
- result = np.log(np.sum(value, axis=axes, keepdims=keepdims))
2168
- elif reduce_kind == "logsumexp":
2169
- result = np.log(np.sum(np.exp(value), axis=axes, keepdims=keepdims))
2170
- elif reduce_kind == "sumsquare":
2171
- result = np.sum(value * value, axis=axes, keepdims=keepdims)
2172
- else:
2173
- raise UnsupportedOpError(f"Unsupported reduce kind {reduce_kind}")
2174
- evaluator.values[node.outputs[0]] = result
2175
-
2176
-
2177
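The composite reductions handled above (ReduceL1/L2, ReduceLogSum, ReduceLogSumExp, ReduceSumSquare) are all short numpy compositions. Two of them on a small tensor (standalone sketch):

import numpy as np

x = np.array([[3.0, 4.0], [1.0, 2.0]])

reduce_l2 = np.sqrt(np.sum(x * x, axis=1, keepdims=True))
reduce_logsumexp = np.log(np.sum(np.exp(x), axis=1, keepdims=True))

print(reduce_l2.ravel())          # approximately [5.0, 2.236]
print(reduce_logsumexp.ravel())   # approximately [4.313, 2.313]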
- @register_evaluator("ArgMax")
2178
- @register_evaluator("ArgMin")
2179
- def _eval_arg_reduce(evaluator: Evaluator, node: Node) -> None:
2180
- op = lower_arg_reduce(evaluator.graph, node)
2181
- value = evaluator.values[op.input0]
2182
- if op.select_last_index:
2183
- flipped = np.flip(value, axis=op.axis)
2184
- if op.reduce_kind == "max":
2185
- indices = np.argmax(flipped, axis=op.axis)
2186
- elif op.reduce_kind == "min":
2187
- indices = np.argmin(flipped, axis=op.axis)
2188
- else:
2189
- raise UnsupportedOpError(
2190
- f"Unsupported arg reduce kind {op.reduce_kind}"
2191
- )
2192
- indices = value.shape[op.axis] - 1 - indices
2193
- else:
2194
- if op.reduce_kind == "max":
2195
- indices = np.argmax(value, axis=op.axis)
2196
- elif op.reduce_kind == "min":
2197
- indices = np.argmin(value, axis=op.axis)
2198
- else:
2199
- raise UnsupportedOpError(
2200
- f"Unsupported arg reduce kind {op.reduce_kind}"
2201
- )
2202
- if op.keepdims:
2203
- indices = np.expand_dims(indices, axis=op.axis)
2204
- evaluator.values[op.output] = indices.astype(op.output_dtype.np_dtype)
2205
-
2206
-
2207
- @register_evaluator("TopK")
2208
- def _eval_topk(evaluator: Evaluator, node: Node) -> None:
2209
- op = lower_topk(evaluator.graph, node)
2210
- value = evaluator.values[op.input0]
2211
- moved = np.moveaxis(value, op.axis, -1)
2212
- axis_dim = moved.shape[-1]
2213
- flat = moved.reshape(-1, axis_dim)
2214
- values_out = np.empty((flat.shape[0], op.k), dtype=value.dtype)
2215
- indices_out = np.empty((flat.shape[0], op.k), dtype=np.int64)
2216
- for row_index in range(flat.shape[0]):
2217
- row = flat[row_index]
2218
- order = sorted(
2219
- range(axis_dim),
2220
- key=lambda idx: (
2221
- -row[idx].item() if op.largest else row[idx].item(),
2222
- idx,
2223
- ),
2224
- )
2225
- topk = order[: op.k]
2226
- indices_out[row_index] = topk
2227
- values_out[row_index] = row[topk]
2228
- values_out = values_out.reshape(moved.shape[:-1] + (op.k,))
2229
- indices_out = indices_out.reshape(moved.shape[:-1] + (op.k,))
2230
- values_out = np.moveaxis(values_out, -1, op.axis)
2231
- indices_out = np.moveaxis(indices_out, -1, op.axis)
2232
- evaluator.values[op.output_values] = values_out.astype(
2233
- op.output_values_dtype.np_dtype
2234
- )
2235
- evaluator.values[op.output_indices] = indices_out.astype(
2236
- op.output_indices_dtype.np_dtype
2237
- )
2238
-
2239
-
2240
- def _eval_binary_unary(evaluator: Evaluator, node: Node) -> None:
2241
- if node.op_type == "BitShift":
2242
- if len(node.inputs) != 2 or len(node.outputs) != 1:
2243
- raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
2244
- direction_attr = node.attrs.get("direction", "LEFT")
2245
- if isinstance(direction_attr, bytes):
2246
- direction = direction_attr.decode()
2247
- else:
2248
- direction = str(direction_attr)
2249
- if direction not in {"LEFT", "RIGHT"}:
2250
- raise UnsupportedOpError(
2251
- "BitShift direction must be LEFT or RIGHT"
2252
- )
2253
- op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
2254
- if not op_dtype.is_integer:
2255
- raise UnsupportedOpError("BitShift expects integer inputs")
2256
- function = (
2257
- ScalarFunction.BITWISE_LEFT_SHIFT
2258
- if direction == "LEFT"
2259
- else ScalarFunction.BITWISE_RIGHT_SHIFT
2260
- )
2261
- op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
2262
- if op_spec is None:
2263
- raise UnsupportedOpError("Unsupported op BitShift")
2264
- left = evaluator.values[node.inputs[0]]
2265
- right = evaluator.values[node.inputs[1]]
2266
- evaluator.values[node.outputs[0]] = apply_binary_op(
2267
- op_spec, left, right
2268
- )
2269
- return
2270
- if node.op_type == "Mod":
2271
- fmod = int(node.attrs.get("fmod", 0))
2272
- if fmod not in {0, 1}:
2273
- raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
2274
- function = (
2275
- ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
2276
- )
2277
- else:
2278
- try:
2279
- function = ScalarFunction.from_onnx_op(node.op_type)
2280
- except ScalarFunctionError as exc:
2281
- raise UnsupportedOpError(
2282
- f"Unsupported op {node.op_type}"
2283
- ) from exc
2284
- validate_unary_attrs(node.op_type, node.attrs)
2285
- if function in COMPARE_FUNCTIONS:
2286
- input_dtype = node_dtype(evaluator.graph, node, *node.inputs)
2287
- output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
2288
- if output_dtype != ScalarType.BOOL:
2289
- raise UnsupportedOpError(
2290
- f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
2291
- )
2292
- op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
2293
- if op_spec is None:
2294
- raise UnsupportedOpError(f"Unsupported op {node.op_type}")
2295
- if len(node.inputs) != 2 or len(node.outputs) != 1:
2296
- raise UnsupportedOpError(
2297
- f"{node.op_type} must have 2 inputs and 1 output"
2298
- )
2299
- left = evaluator.values[node.inputs[0]]
2300
- right = evaluator.values[node.inputs[1]]
2301
- evaluator.values[node.outputs[0]] = apply_binary_op(
2302
- op_spec, left, right
2303
- )
2304
- return
2305
- op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
2306
- op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
2307
- unary_symbol = unary_op_symbol(function, dtype=op_dtype)
2308
- if op_spec is None and unary_symbol is None:
2309
- raise UnsupportedOpError(f"Unsupported op {node.op_type}")
2310
- if op_spec is not None:
2311
- if len(node.inputs) != 2 or len(node.outputs) != 1:
2312
- raise UnsupportedOpError(
2313
- f"{node.op_type} must have 2 inputs and 1 output"
2314
- )
2315
- left = evaluator.values[node.inputs[0]]
2316
- right = evaluator.values[node.inputs[1]]
2317
- evaluator.values[node.outputs[0]] = apply_binary_op(
2318
- op_spec, left, right
2319
- )
2320
- return
2321
- if len(node.inputs) != 1 or len(node.outputs) != 1:
2322
- raise UnsupportedOpError(
2323
- f"{node.op_type} must have 1 input and 1 output"
2324
- )
2325
- value = evaluator.values[node.inputs[0]]
2326
- evaluator.values[node.outputs[0]] = apply_unary_op(
2327
- function, value, dtype=op_dtype
2328
- )
2329
-
2330
-
2331
-def _apply_matmul(left: np.ndarray, right: np.ndarray) -> np.ndarray:
-    if left.ndim < 1 or right.ndim < 1:
-        raise UnsupportedOpError(
-            "MatMul inputs must be at least 1D, "
-            f"got {left.shape} x {right.shape}"
-        )
-    left_dim = left.shape[-1]
-    right_dim = right.shape[0] if right.ndim == 1 else right.shape[-2]
-    if left_dim != right_dim:
-        raise ShapeInferenceError(
-            "MatMul inner dimensions must match, "
-            f"got {left_dim} and {right_dim}"
-        )
-    left_batch = left.shape[:-2] if left.ndim > 1 else ()
-    right_batch = right.shape[:-2] if right.ndim > 1 else ()
-    if not _matmul_batch_broadcastable(left_batch, right_batch):
-        raise ShapeInferenceError(
-            "MatMul batch dimensions must be broadcastable, "
-            f"got {left_batch} x {right_batch}"
-        )
-    return np.matmul(left, right)
-
-
-def _matmul_batch_broadcastable(
-    left: tuple[int, ...], right: tuple[int, ...]
-) -> bool:
-    max_rank = max(len(left), len(right))
-    left_padded = (1,) * (max_rank - len(left)) + left
-    right_padded = (1,) * (max_rank - len(right)) + right
-    for left_dim, right_dim in zip(left_padded, right_padded):
-        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
-            return False
-    return True
-
-
-def _apply_softmax(values: np.ndarray, axis: int) -> np.ndarray:
-    max_values = np.max(values, axis=axis, keepdims=True)
-    exp_values = np.exp(values - max_values)
-    sum_values = np.sum(exp_values, axis=axis, keepdims=True)
-    return exp_values / sum_values
-
-
-def _apply_logsoftmax(values: np.ndarray, axis: int) -> np.ndarray:
-    max_values = np.max(values, axis=axis, keepdims=True)
-    shifted = values - max_values
-    logsum = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
-    return shifted - logsum
-
-
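Both softmax helpers above subtract the per-axis maximum before exponentiating; without that shift the naive formula overflows for large logits. A standalone check:

import numpy as np

x = np.array([[1000.0, 1001.0, 1002.0]])

shifted = x - np.max(x, axis=-1, keepdims=True)
stable = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
print(stable)                     # [[0.09003057 0.24472847 0.66524096]]

with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)
print(naive)                      # [[nan nan nan]] -- inf / inf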
- def _apply_negative_log_likelihood_loss(
2381
- values: np.ndarray,
2382
- target: np.ndarray,
2383
- weight: np.ndarray | None,
2384
- *,
2385
- reduction: str,
2386
- ignore_index: int,
2387
- ) -> np.ndarray:
2388
- input_shape = values.shape
2389
- if len(input_shape) < 2:
2390
- raise UnsupportedOpError(
2391
- "NegativeLogLikelihoodLoss input must be at least 2D"
2392
- )
2393
- target_shape = target.shape
2394
- if input_shape[0] != target_shape[0]:
2395
- raise ShapeInferenceError(
2396
- "NegativeLogLikelihoodLoss target batch dimension must match input"
2397
- )
2398
- if input_shape[2:] != target_shape[1:]:
2399
- raise ShapeInferenceError(
2400
- "NegativeLogLikelihoodLoss target spatial dimensions must match input"
2401
- )
2402
- n = input_shape[0]
2403
- c = input_shape[1]
2404
- if weight is not None:
2405
- gather_weight = np.take(weight, target.astype(np.int32), mode="clip")
2406
- if ignore_index is not None:
2407
- gather_weight = np.where(target == ignore_index, 0, gather_weight).astype(
2408
- dtype=values.dtype
2409
- )
2410
- elif ignore_index != -1:
2411
- gather_weight = np.where(target == ignore_index, 0, 1).astype(
2412
- dtype=values.dtype
2413
- )
2414
- else:
2415
- gather_weight = None
2416
- if len(input_shape) != 3:
2417
- values = values.reshape((n, c, -1))
2418
- target = target.reshape((n, -1))
2419
- d = values.shape[2]
2420
- loss = np.zeros((n, d), dtype=values.dtype)
2421
- for i in range(n):
2422
- for d_index in range(d):
2423
- if target[i][d_index] != ignore_index:
2424
- loss[i][d_index] = -values[i][target[i][d_index]][d_index]
2425
- if len(input_shape) != 3:
2426
- loss = loss.reshape(target_shape)
2427
- if gather_weight is not None:
2428
- loss = gather_weight * loss
2429
- if reduction == "mean":
2430
- weight_sum = gather_weight.sum()
2431
- if weight_sum == 0:
2432
- return np.array(0, dtype=values.dtype)
2433
- loss = loss.sum() / weight_sum
2434
- return loss.astype(values.dtype)
2435
- if reduction == "mean":
2436
- loss = np.mean(loss)
2437
- elif reduction == "sum":
2438
- loss = np.sum(loss)
2439
- return loss.astype(values.dtype)
2440
-
2441
-
2442
- def _apply_softmax_cross_entropy_loss(
2443
- values: np.ndarray,
2444
- target: np.ndarray,
2445
- weight: np.ndarray | None,
2446
- *,
2447
- reduction: str,
2448
- ignore_index: int | None,
2449
- return_log_prob: bool,
2450
- ) -> tuple[np.ndarray, np.ndarray | None]:
2451
- input_shape = values.shape
2452
- if len(input_shape) < 2:
2453
- raise UnsupportedOpError(
2454
- "SoftmaxCrossEntropyLoss input must be at least 2D"
2455
- )
2456
- target_shape = target.shape
2457
- if input_shape[0] != target_shape[0]:
2458
- raise ShapeInferenceError(
2459
- "SoftmaxCrossEntropyLoss target batch dimension must match input"
2460
- )
2461
- if input_shape[2:] != target_shape[1:]:
2462
- raise ShapeInferenceError(
2463
- "SoftmaxCrossEntropyLoss target spatial dimensions must match input"
2464
- )
2465
- log_prob = _apply_logsoftmax(values, axis=1)
2466
- log_prob_output = log_prob if return_log_prob else None
2467
- if weight is not None:
2468
- gather_weight = np.take(weight, target.astype(np.int32), mode="clip")
2469
- if ignore_index is not None:
2470
- gather_weight = np.where(target == ignore_index, 0, gather_weight).astype(
2471
- dtype=values.dtype
2472
- )
2473
- elif ignore_index is not None:
2474
- gather_weight = np.where(target == ignore_index, 0, 1).astype(
2475
- dtype=values.dtype
2476
- )
2477
- else:
2478
- gather_weight = None
2479
- n = input_shape[0]
2480
- c = input_shape[1]
2481
- if len(input_shape) != 3:
2482
- log_prob = log_prob.reshape((n, c, -1))
2483
- target = target.reshape((n, -1))
2484
- d = log_prob.shape[2]
2485
- loss = np.zeros((n, d), dtype=values.dtype)
2486
- for i in range(n):
2487
- for d_index in range(d):
2488
- if ignore_index is None or target[i][d_index] != ignore_index:
2489
- loss[i][d_index] = -log_prob[i][target[i][d_index]][d_index]
2490
- if len(input_shape) != 3:
2491
- loss = loss.reshape(target_shape)
2492
- if gather_weight is not None:
2493
- loss = gather_weight * loss
2494
- if reduction == "mean":
2495
- loss = loss.sum() / gather_weight.sum()
2496
- loss = loss.astype(values.dtype)
2497
- if return_log_prob:
2498
- return loss, log_prob.astype(values.dtype)
2499
- return loss, None
2500
- if reduction == "mean":
2501
- loss = np.mean(loss)
2502
- elif reduction == "sum":
2503
- loss = np.sum(loss)
2504
- loss = loss.astype(values.dtype)
2505
- if return_log_prob and log_prob_output is not None:
2506
- return loss, log_prob_output.astype(values.dtype)
2507
- return loss, None
2508
-
2509
-
2510
- def _apply_attention(
2511
- spec,
2512
- query: np.ndarray,
2513
- key: np.ndarray,
2514
- value: np.ndarray,
2515
- attn_mask: np.ndarray | None,
2516
- past_key: np.ndarray | None,
2517
- past_value: np.ndarray | None,
2518
- nonpad_kv_seqlen: np.ndarray | None,
2519
- ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None]:
2520
- if spec.q_rank == 3:
2521
- query_4d = query.reshape(
2522
- spec.batch, spec.q_seq, spec.q_heads, spec.qk_head_size
2523
- ).transpose(0, 2, 1, 3)
2524
- key_4d = key.reshape(
2525
- spec.batch, spec.kv_seq, spec.kv_heads, spec.qk_head_size
2526
- ).transpose(0, 2, 1, 3)
2527
- value_4d = value.reshape(
2528
- spec.batch, spec.kv_seq, spec.kv_heads, spec.v_head_size
2529
- ).transpose(0, 2, 1, 3)
2530
- else:
2531
- query_4d = query
2532
- key_4d = key
2533
- value_4d = value
2534
- if past_key is not None and past_value is not None:
2535
- key_total = np.concatenate([past_key, key_4d], axis=2)
2536
- value_total = np.concatenate([past_value, value_4d], axis=2)
2537
- else:
2538
- key_total = key_4d
2539
- value_total = value_4d
2540
- if spec.head_group_size > 1:
2541
- key_total_expanded = np.repeat(key_total, spec.head_group_size, axis=1)
2542
- value_total_expanded = np.repeat(
2543
- value_total, spec.head_group_size, axis=1
2544
- )
2545
- else:
2546
- key_total_expanded = key_total
2547
- value_total_expanded = value_total
2548
- k_transpose = np.transpose(key_total_expanded, (0, 1, 3, 2))
2549
- scores = np.matmul(query_4d, k_transpose) * spec.scale
2550
- bias = np.zeros_like(scores)
2551
- if spec.has_attn_mask and attn_mask is not None:
2552
- if spec.mask_is_bool:
2553
- bias_mask = np.where(attn_mask, 0.0, -np.inf)
2554
- else:
2555
- bias_mask = attn_mask.astype(scores.dtype)
2556
- if spec.mask_rank == 2:
2557
- bias_mask = bias_mask[None, None, ...]
2558
- elif spec.mask_rank == 3:
2559
- bias_mask = bias_mask[:, None, ...]
2560
- bias_mask = np.broadcast_to(
2561
- bias_mask, (spec.batch, spec.q_heads, spec.q_seq, spec.mask_kv_seq)
2562
- )
2563
- if spec.mask_kv_seq < spec.total_seq:
2564
- pad_width = spec.total_seq - spec.mask_kv_seq
2565
- bias_mask = np.pad(
2566
- bias_mask,
2567
- ((0, 0), (0, 0), (0, 0), (0, pad_width)),
2568
- constant_values=-np.inf,
2569
- )
2570
- bias = bias + bias_mask
2571
- if spec.has_nonpad and nonpad_kv_seqlen is not None:
2572
- kv_range = np.arange(spec.total_seq)[None, None, None, :]
2573
- valid = kv_range < nonpad_kv_seqlen[:, None, None, None]
2574
- bias = bias + np.where(valid, 0.0, -np.inf)
2575
- if spec.is_causal:
2576
- kv_range = np.arange(spec.total_seq)[None, :]
2577
- q_range = np.arange(spec.q_seq)[:, None] + spec.past_seq
2578
- causal_mask = kv_range > q_range
2579
- bias = bias + np.where(causal_mask, -np.inf, 0.0)[None, None, :, :]
2580
- scores_with_bias = scores + bias
2581
- if spec.softcap != 0.0:
2582
- scores_softcap = spec.softcap * np.tanh(scores_with_bias / spec.softcap)
2583
- else:
2584
- scores_softcap = scores_with_bias
2585
- max_scores = np.max(scores_softcap, axis=-1, keepdims=True)
2586
- weights = np.exp(scores_softcap - max_scores)
2587
- weights /= np.sum(weights, axis=-1, keepdims=True)
2588
- output = np.matmul(weights, value_total_expanded)
2589
- if spec.q_rank == 3:
2590
- output = output.transpose(0, 2, 1, 3).reshape(
2591
- spec.batch, spec.q_seq, spec.q_heads * spec.v_head_size
2592
- )
2593
- qk_output = None
2594
- if spec.qk_matmul_output_mode == 0:
2595
- qk_output = scores
2596
- elif spec.qk_matmul_output_mode == 1:
2597
- qk_output = scores_with_bias
2598
- elif spec.qk_matmul_output_mode == 2:
2599
- qk_output = scores_softcap
2600
- else:
2601
- qk_output = weights
2602
- return output, key_total, value_total, qk_output
2603
-
2604
-
2605
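_apply_attention above covers masks, past key/value concatenation, grouped-query head expansion, softcap, and the optional QK output; its core is ordinary scaled dot-product attention. A minimal standalone sketch without any of those extras, on (batch, heads, seq, head_size) tensors:

import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 1, 4, 8))
k = rng.standard_normal((1, 1, 4, 8))
v = rng.standard_normal((1, 1, 4, 8))

scale = 1.0 / np.sqrt(q.shape[-1])
scores = np.matmul(q, np.transpose(k, (0, 1, 3, 2))) * scale
scores -= np.max(scores, axis=-1, keepdims=True)   # stable softmax
weights = np.exp(scores)
weights /= np.sum(weights, axis=-1, keepdims=True)
out = np.matmul(weights, v)
print(out.shape)  # (1, 1, 4, 8)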
- def _apply_conv(
2606
- spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None
2607
- ) -> np.ndarray:
2608
- output = np.zeros(
2609
- (spec.batch, spec.out_channels, *spec.out_spatial),
2610
- dtype=data.dtype,
2611
- )
2612
- pad_begin = spec.pads[: spec.spatial_rank]
2613
- group_in_channels = spec.in_channels // spec.group
2614
- group_out_channels = spec.out_channels // spec.group
2615
- for n in range(spec.batch):
2616
- for g in range(spec.group):
2617
- oc_base = g * group_out_channels
2618
- ic_base = g * group_in_channels
2619
- for oc in range(group_out_channels):
2620
- oc_global = oc_base + oc
2621
- base = bias[oc_global] if bias is not None else 0.0
2622
- for out_index in np.ndindex(*spec.out_spatial):
2623
- acc = base
2624
- for ic in range(group_in_channels):
2625
- ic_global = ic_base + ic
2626
- for kernel_index in np.ndindex(*spec.kernel_shape):
2627
- in_index = []
2628
- valid = True
2629
- for (
2630
- out_dim,
2631
- kernel_dim,
2632
- stride,
2633
- dilation,
2634
- pad,
2635
- in_size,
2636
- ) in zip(
2637
- out_index,
2638
- kernel_index,
2639
- spec.strides,
2640
- spec.dilations,
2641
- pad_begin,
2642
- spec.in_spatial,
2643
- ):
2644
- in_dim = out_dim * stride + kernel_dim * dilation - pad
2645
- if in_dim < 0 or in_dim >= in_size:
2646
- valid = False
2647
- break
2648
- in_index.append(in_dim)
2649
- if valid:
2650
- acc += data[(n, ic_global, *in_index)] * weights[
2651
- (oc_global, ic, *kernel_index)
2652
- ]
2653
- output[(n, oc_global, *out_index)] = acc
2654
- return output
2655
-
2656
-
2657
- def _apply_conv_transpose(
2658
- spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None
2659
- ) -> np.ndarray:
2660
- output = np.zeros(
2661
- (spec.batch, spec.out_channels, *spec.out_spatial), dtype=data.dtype
2662
- )
2663
- if bias is not None:
2664
- output += bias.reshape((1, spec.out_channels) + (1,) * spec.spatial_rank)
2665
- pad_begin = spec.pads[: spec.spatial_rank]
2666
- group_in_channels = spec.in_channels // spec.group
2667
- group_out_channels = spec.out_channels // spec.group
2668
- for n in range(spec.batch):
2669
- for g in range(spec.group):
2670
- oc_base = g * group_out_channels
2671
- ic_base = g * group_in_channels
2672
- for ic in range(group_in_channels):
2673
- ic_global = ic_base + ic
2674
- for in_index in np.ndindex(*spec.in_spatial):
2675
- value = data[(n, ic_global, *in_index)]
2676
- for oc in range(group_out_channels):
2677
- oc_global = oc_base + oc
2678
- for kernel_index in np.ndindex(*spec.kernel_shape):
2679
- out_index = []
2680
- valid = True
2681
- for (
2682
- in_dim,
2683
- kernel_dim,
2684
- stride,
2685
- dilation,
2686
- pad,
2687
- out_size,
2688
- ) in zip(
2689
- in_index,
2690
- kernel_index,
2691
- spec.strides,
2692
- spec.dilations,
2693
- pad_begin,
2694
- spec.out_spatial,
2695
- ):
2696
- out_dim = (
2697
- in_dim * stride + kernel_dim * dilation - pad
2698
- )
2699
- if out_dim < 0 or out_dim >= out_size:
2700
- valid = False
2701
- break
2702
- out_index.append(out_dim)
2703
- if valid:
2704
- output[(n, oc_global, *out_index)] += (
2705
- value * weights[(ic_global, oc, *kernel_index)]
2706
- )
2707
- return output
2708
-
2709
-
2710
- def _apply_lrn(spec, data: np.ndarray) -> np.ndarray:
2711
- output = np.empty_like(data)
2712
- spatial_shape = spec.shape[2:]
2713
- spatial_indices = [()]
2714
- if spatial_shape:
2715
- spatial_indices = list(np.ndindex(*spatial_shape))
2716
- for n in range(spec.shape[0]):
2717
- for c in range(spec.channels):
2718
- start = max(0, c - spec.half)
2719
- end = min(spec.channels - 1, c + spec.half)
2720
- for index in spatial_indices:
2721
- sum_val = 0.0
2722
- for i in range(start, end + 1):
2723
- value = data[(n, i, *index)]
2724
- sum_val += value * value
2725
- scale = spec.bias + (spec.alpha / spec.size) * sum_val
2726
- output[(n, c, *index)] = data[(n, c, *index)] / math.pow(
2727
- scale, spec.beta
2728
- )
2729
- return output
2730
-
2731
-
2732
- def _apply_average_pool(op, data: np.ndarray) -> np.ndarray:
2733
- output = np.zeros((op.batch, op.channels, op.out_h, op.out_w), dtype=data.dtype)
2734
- for n in range(op.batch):
2735
- for c in range(op.channels):
2736
- for oh in range(op.out_h):
2737
- for ow in range(op.out_w):
2738
- acc = 0.0
2739
- count = 0
2740
- for kh in range(op.kernel_h):
2741
- ih = oh * op.stride_h + kh - op.pad_top
2742
- if ih < 0 or ih >= op.in_h:
2743
- if op.count_include_pad:
2744
- count += op.kernel_w
2745
- else:
2746
- for kw in range(op.kernel_w):
2747
- iw = ow * op.stride_w + kw - op.pad_left
2748
- if iw < 0 or iw >= op.in_w:
2749
- if op.count_include_pad:
2750
- count += 1
2751
- else:
2752
- acc += data[n, c, ih, iw]
2753
- count += 1
2754
- output[n, c, oh, ow] = 0.0 if count == 0 else acc / float(count)
2755
- return output
2756
-
2757
-
2758
-def _maxpool_min_value(dtype: np.dtype) -> float | int:
-    if np.issubdtype(dtype, np.floating):
-        return -np.inf
-    if np.issubdtype(dtype, np.integer):
-        return np.iinfo(dtype).min
-    raise UnsupportedOpError("MaxPool supports numeric inputs only")
-
-
- def _apply_maxpool(
-     spec, data: np.ndarray, *, return_indices: bool = False
- ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
-     min_value = _maxpool_min_value(data.dtype)
-     output = np.full(
-         (spec.batch, spec.channels, *spec.out_spatial),
-         min_value,
-         dtype=data.dtype,
-     )
-     indices = (
-         np.zeros((spec.batch, spec.channels, *spec.out_spatial), dtype=np.int64)
-         if return_indices
-         else None
-     )
-     pad_begin = spec.pads[: spec.spatial_rank]
-     for n in range(spec.batch):
-         for c in range(spec.channels):
-             for out_index in np.ndindex(*spec.out_spatial):
-                 max_value = min_value
-                 max_index = 0
-                 has_value = False
-                 for kernel_index in np.ndindex(*spec.kernel_shape):
-                     in_index = []
-                     valid = True
-                     for out_dim, kernel_dim, stride, dilation, pad in zip(
-                         out_index,
-                         kernel_index,
-                         spec.strides,
-                         spec.dilations,
-                         pad_begin,
-                     ):
-                         idx = out_dim * stride + kernel_dim * dilation - pad
-                         if idx < 0 or idx >= spec.in_spatial[len(in_index)]:
-                             valid = False
-                             break
-                         in_index.append(idx)
-                     if valid:
-                         value = data[(n, c, *in_index)]
-                         if value > max_value or not has_value:
-                             max_value = value
-                             has_value = True
-                             if return_indices:
-                                 linear_index = n * spec.channels + c
-                                 if spec.storage_order == 0:
-                                     for idx, size in zip(
-                                         in_index, spec.in_spatial
-                                     ):
-                                         linear_index = linear_index * size + idx
-                                 else:
-                                     spatial_index = 0
-                                     spatial_stride = 1
-                                     for idx, size in zip(
-                                         in_index, spec.in_spatial
-                                     ):
-                                         spatial_index += idx * spatial_stride
-                                         spatial_stride *= size
-                                     linear_index = (
-                                         linear_index * spatial_stride + spatial_index
-                                     )
-                                 max_index = linear_index
-                 output[(n, c, *out_index)] = max_value
-                 if return_indices and indices is not None:
-                     indices[(n, c, *out_index)] = max_index
-     if return_indices:
-         if indices is None:
-             raise RuntimeError("MaxPool indices were not computed")
-         return output, indices
-     return output
-
-
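# The second MaxPool output above flattens the argmax position over the full
# input tensor. For storage_order == 0 this matches NumPy's row-major
# ravel_multi_index over (n, c, *in_index); the shapes below are hypothetical.
import numpy as np

n, c, ih, iw = 0, 1, 2, 3
batch, channels, in_h, in_w = 1, 2, 4, 5
linear = np.ravel_multi_index((n, c, ih, iw), (batch, channels, in_h, in_w))
# linear == ((n * channels + c) * in_h + ih) * in_w + iw == 33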
- def _apply_lstm(
-     spec,
-     x: np.ndarray,
-     w: np.ndarray,
-     r: np.ndarray,
-     b: np.ndarray | None,
-     sequence_lens: np.ndarray | None,
-     initial_h: np.ndarray | None,
-     initial_c: np.ndarray | None,
-     p: np.ndarray | None,
- ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
-     if spec.layout == 1:
-         x = np.swapaxes(x, 0, 1)
-     seq_length = spec.seq_length
-     batch_size = spec.batch_size
-     hidden_size = spec.hidden_size
-     num_directions = spec.num_directions
-     if sequence_lens is None:
-         sequence_lens = np.full((batch_size,), seq_length, dtype=np.int64)
-     else:
-         sequence_lens = sequence_lens.astype(np.int64, copy=False)
-     if b is None:
-         b = np.zeros((num_directions, 8 * hidden_size), dtype=x.dtype)
-     if p is None:
-         p = np.zeros((num_directions, 3 * hidden_size), dtype=x.dtype)
-     if initial_h is None:
-         initial_h = np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
-     if initial_c is None:
-         initial_c = np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
-     if spec.layout == 1:
-         initial_h = np.swapaxes(initial_h, 0, 1)
-         initial_c = np.swapaxes(initial_c, 0, 1)
-     output_y = None
-     if spec.output_y is not None:
-         output_y = np.zeros(
-             (seq_length, num_directions, batch_size, hidden_size), dtype=x.dtype
-         )
-     output_y_h = (
-         np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
-         if spec.output_y_h is not None
-         else None
-     )
-     output_y_c = (
-         np.zeros((num_directions, batch_size, hidden_size), dtype=x.dtype)
-         if spec.output_y_c is not None
-         else None
-     )
-     directions = (
-         ("forward", "reverse")
-         if spec.direction == "bidirectional"
-         else (spec.direction,)
-     )
-     for dir_index, dir_kind in enumerate(directions):
-         w_dir = w[dir_index]
-         r_dir = r[dir_index]
-         b_dir = b[dir_index]
-         bias = b_dir[: 4 * hidden_size] + b_dir[4 * hidden_size :]
-         p_dir = p[dir_index]
-         p_i = p_dir[:hidden_size]
-         p_o = p_dir[hidden_size : 2 * hidden_size]
-         p_f = p_dir[2 * hidden_size :]
-         h_prev = initial_h[dir_index].copy()
-         c_prev = initial_c[dir_index].copy()
-         act_offset = dir_index * 3
-         act_f = spec.activation_kinds[act_offset]
-         act_g = spec.activation_kinds[act_offset + 1]
-         act_h = spec.activation_kinds[act_offset + 2]
-         alpha_f = spec.activation_alphas[act_offset]
-         alpha_g = spec.activation_alphas[act_offset + 1]
-         alpha_h = spec.activation_alphas[act_offset + 2]
-         beta_f = spec.activation_betas[act_offset]
-         beta_g = spec.activation_betas[act_offset + 1]
-         beta_h = spec.activation_betas[act_offset + 2]
-         for step in range(seq_length):
-             if dir_kind == "forward":
-                 x_t = x[step]
-             else:
-                 t_indices = sequence_lens - 1 - step
-                 t_indices = np.clip(t_indices, 0, seq_length - 1)
-                 x_t = x[t_indices, np.arange(batch_size)]
-             gates = x_t @ w_dir.T + h_prev @ r_dir.T + bias
-             if spec.clip is not None and spec.clip > 0:
-                 gates = np.clip(gates, -spec.clip, spec.clip)
-             i, o, f, c = np.split(gates, 4, axis=1)
-             i = _apply_lstm_activation(act_f, i + p_i * c_prev, alpha_f, beta_f)
-             if spec.input_forget:
-                 f = 1 - i
-             else:
-                 f = _apply_lstm_activation(
-                     act_f, f + p_f * c_prev, alpha_f, beta_f
-                 )
-             c_tilde = _apply_lstm_activation(act_g, c, alpha_g, beta_g)
-             c_new = f * c_prev + i * c_tilde
-             o = _apply_lstm_activation(act_f, o + p_o * c_new, alpha_f, beta_f)
-             h_new = o * _apply_lstm_activation(act_h, c_new, alpha_h, beta_h)
-             active_mask = step < sequence_lens
-             if not np.all(active_mask):
-                 h_new = np.where(active_mask[:, None], h_new, h_prev)
-                 c_new = np.where(active_mask[:, None], c_new, c_prev)
-                 if output_y is not None:
-                     output_y[step, dir_index] = np.where(
-                         active_mask[:, None], h_new, 0
-                     )
-             else:
-                 if output_y is not None:
-                     output_y[step, dir_index] = h_new
-             h_prev = h_new
-             c_prev = c_new
-         if output_y_h is not None:
-             output_y_h[dir_index] = h_prev
-         if output_y_c is not None:
-             output_y_c[dir_index] = c_prev
-     if spec.layout == 1:
-         if output_y is not None:
-             output_y = np.transpose(output_y, (2, 0, 1, 3))
-         if output_y_h is not None:
-             output_y_h = np.swapaxes(output_y_h, 0, 1)
-         if output_y_c is not None:
-             output_y_c = np.swapaxes(output_y_c, 0, 1)
-     return output_y, output_y_h, output_y_c
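# A one-step, single-direction sketch of the gate math used in _apply_lstm
# above, with sigmoid/tanh standing in for the f/g/h activations (the
# reference dispatches through _apply_lstm_activation and also handles
# peepholes, clipping, and sequence masking). W and R are packed in ONNX
# i, o, f, c order, matching the np.split above; all sizes are hypothetical.
import numpy as np

def sigmoid(v: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-v))

hidden, inputs, batch = 2, 3, 1
rng = np.random.default_rng(0)
x_t = rng.standard_normal((batch, inputs))
w = rng.standard_normal((4 * hidden, inputs))   # [Wi; Wo; Wf; Wc]
r = rng.standard_normal((4 * hidden, hidden))   # [Ri; Ro; Rf; Rc]
h_prev = np.zeros((batch, hidden))
c_prev = np.zeros((batch, hidden))

gates = x_t @ w.T + h_prev @ r.T
i, o, f, c = np.split(gates, 4, axis=1)
c_new = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(c)
h_new = sigmoid(o) * np.tanh(c_new)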