emx-onnx-cgen 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic. Click here for more details.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
|
@@ -0,0 +1,2206 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Mapping
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from shared.scalar_types import ScalarType
|
|
9
|
+
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
10
|
+
from ..ir.model import Graph, Node
|
|
11
|
+
from ..lowering.attention import resolve_attention_spec
|
|
12
|
+
from ..lowering.average_pool import lower_average_pool, lower_global_average_pool
|
|
13
|
+
from ..lowering.batch_normalization import lower_batch_normalization
|
|
14
|
+
from ..lowering.concat import lower_concat
|
|
15
|
+
from ..lowering.constant_of_shape import lower_constant_of_shape
|
|
16
|
+
from ..lowering.conv import resolve_conv_spec
|
|
17
|
+
from ..lowering.dropout import lower_dropout
|
|
18
|
+
from ..lowering.cumsum import lower_cumsum
|
|
19
|
+
from ..lowering.flatten import lower_flatten
|
|
20
|
+
from ..lowering.gemm import resolve_gemm_spec
|
|
21
|
+
from ..lowering.logsoftmax import lower_logsoftmax
|
|
22
|
+
from ..lowering.lp_normalization import lower_lp_normalization
|
|
23
|
+
from ..lowering.grid_sample import lower_grid_sample
|
|
24
|
+
from ..lowering.instance_normalization import lower_instance_normalization
|
|
25
|
+
from ..lowering.group_normalization import lower_group_normalization
|
|
26
|
+
from ..lowering.layer_normalization import lower_layer_normalization
|
|
27
|
+
from ..lowering.mean_variance_normalization import (
|
|
28
|
+
lower_mean_variance_normalization,
|
|
29
|
+
)
|
|
30
|
+
from ..lowering.negative_log_likelihood_loss import (
|
|
31
|
+
lower_negative_log_likelihood_loss,
|
|
32
|
+
)
|
|
33
|
+
from ..lowering.pad import lower_pad
|
|
34
|
+
from ..lowering.expand import lower_expand
|
|
35
|
+
from ..lowering.range import lower_range
|
|
36
|
+
from ..lowering.split import lower_split
|
|
37
|
+
from ..lowering.softmax_cross_entropy_loss import (
|
|
38
|
+
lower_softmax_cross_entropy_loss,
|
|
39
|
+
)
|
|
40
|
+
from ..lowering.arg_reduce import lower_arg_reduce
|
|
41
|
+
from ..lowering.lstm import ACTIVATION_KIND_BY_NAME, resolve_lstm_spec
|
|
42
|
+
from ..lowering.lrn import resolve_lrn_spec
|
|
43
|
+
from ..lowering.matmul import lower_matmul
|
|
44
|
+
from ..lowering.maxpool import resolve_maxpool_spec
|
|
45
|
+
from ..lowering.reduce import (
|
|
46
|
+
REDUCE_KIND_BY_OP,
|
|
47
|
+
REDUCE_OUTPUTS_FLOAT_ONLY,
|
|
48
|
+
normalize_reduce_axes,
|
|
49
|
+
resolve_reduce_axes,
|
|
50
|
+
)
|
|
51
|
+
from ..lowering.reshape import lower_reshape
|
|
52
|
+
from ..lowering.slice import _normalize_slices
|
|
53
|
+
from ..lowering.shape import lower_shape
|
|
54
|
+
from ..lowering.size import lower_size
|
|
55
|
+
from ..lowering.softmax import lower_softmax
|
|
56
|
+
from ..lowering.rms_normalization import lower_rms_normalization
|
|
57
|
+
from ..lowering.squeeze import lower_squeeze
|
|
58
|
+
from ..lowering.transpose import lower_transpose
|
|
59
|
+
from ..lowering.unsqueeze import lower_unsqueeze
|
|
60
|
+
from ..lowering.where import lower_where
|
|
61
|
+
from ..lowering.variadic import BINARY_ONLY_OPS, VARIADIC_OP_FUNCTIONS
|
|
62
|
+
from ..lowering.registry import resolve_dispatch
|
|
63
|
+
from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
|
|
64
|
+
from ..ops import (
|
|
65
|
+
BINARY_OP_TYPES,
|
|
66
|
+
COMPARE_FUNCTIONS,
|
|
67
|
+
UNARY_OP_TYPES,
|
|
68
|
+
apply_binary_op,
|
|
69
|
+
apply_unary_op,
|
|
70
|
+
binary_op_symbol,
|
|
71
|
+
unary_op_symbol,
|
|
72
|
+
validate_unary_attrs,
|
|
73
|
+
)
|
|
74
|
+
from shared.scalar_functions import ScalarFunction, ScalarFunctionError
|
|
75
|
+
from ..validation import normalize_axis
|
|
76
|
+
|
|
77
|
+
# Signature of a per-op evaluator: it receives the running Evaluator and the
# node to execute and returns nothing (handlers write their results into
# evaluator.values).
Handler = Callable[["Evaluator", Node], None]

# op_type -> handler table, filled in by register_evaluator.
_EVAL_REGISTRY: dict[str, Handler] = dict()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def register_evaluator(op_type: str) -> Callable[[Handler], Handler]:
    """Build a decorator that files a handler under ``op_type``.

    The handler is stored in ``_EVAL_REGISTRY`` (replacing any earlier entry
    for the same op type) and handed back unchanged, so the decorated name
    still refers to the original function.
    """

    def _register(handler: Handler) -> Handler:
        _EVAL_REGISTRY[op_type] = handler
        return handler

    return _register
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class Evaluator:
    """NumPy interpreter that executes a ``Graph`` node by node.

    Tensors are kept in :attr:`values`, keyed by value name; each node's
    handler reads its inputs from and writes its outputs to that table.
    """

    def __init__(self, graph: Graph) -> None:
        self._graph = graph
        self._values: dict[str, np.ndarray] = {}

    @property
    def graph(self) -> Graph:
        """The graph being evaluated."""
        return self._graph

    @property
    def values(self) -> dict[str, np.ndarray]:
        """Name -> array table of the current (or most recent) run."""
        return self._values

    def run(self, feeds: Mapping[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Run every node in graph order and return the graph outputs.

        The value table is seeded from the initializers first; ``feeds`` are
        merged in afterwards, so a feed wins over an initializer of the same
        name.
        """
        table: dict[str, np.ndarray] = dict(
            (init.name, init.data) for init in self._graph.initializers
        )
        table.update(feeds)
        self._values = table
        for current in self._graph.nodes:
            self._dispatch(current)
        wanted = [out.name for out in self._graph.outputs]
        return {name: self._values[name] for name in wanted}

    def _dispatch(self, node: Node) -> None:
        """Resolve the handler for ``node.op_type`` and invoke it on ``node``."""
        handler = resolve_dispatch(
            node.op_type,
            _EVAL_REGISTRY,
            binary_types=BINARY_OP_TYPES,
            unary_types=UNARY_OP_TYPES,
            # Generic element-wise ops share one handler for both arities.
            binary_fallback=lambda: _eval_binary_unary,
            unary_fallback=lambda: _eval_binary_unary,
        )
        handler(self, node)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@register_evaluator("MatMul")
def _eval_matmul(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``MatMul`` on the first two inputs via ``_apply_matmul``."""
    # Return value unused; presumably called to validate the node (TODO confirm).
    lower_matmul(evaluator.graph, node)
    table = evaluator.values
    lhs = table[node.inputs[0]]
    rhs = table[node.inputs[1]]
    table[node.outputs[0]] = _apply_matmul(lhs, rhs)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@register_evaluator("Clip")
def _eval_clip(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``Clip`` with optional ``min``/``max`` inputs.

    A missing bound defaults to -inf/+inf for floating-point inputs and to
    the integer dtype's min/max otherwise.

    Raises:
        UnsupportedOpError: if the data input or the single output is missing.
    """
    if not node.inputs or len(node.outputs) != 1:
        raise UnsupportedOpError("Clip must have 1 output")
    data_name = node.inputs[0]
    if not data_name:
        raise UnsupportedOpError("Clip input must be provided")
    data = evaluator.values[data_name]
    lower_name = optional_name(node.inputs, 1)
    upper_name = optional_name(node.inputs, 2)
    dtype = value_dtype(evaluator.graph, data_name, node)

    def bound(name, *, lower: bool):
        # Explicit bound tensors win; otherwise fall back to the dtype extremes.
        if name is not None:
            return evaluator.values[name]
        if dtype.is_float:
            return -np.inf if lower else np.inf
        limits = np.iinfo(dtype.np_dtype)
        return limits.min if lower else limits.max

    evaluator.values[node.outputs[0]] = np.clip(
        data, bound(lower_name, lower=True), bound(upper_name, lower=False)
    )
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _exclusive_cumsum(data: np.ndarray, axis: int) -> np.ndarray:
|
|
167
|
+
result = np.zeros_like(data)
|
|
168
|
+
if data.shape[axis] == 0:
|
|
169
|
+
return result
|
|
170
|
+
cumsum = np.cumsum(data, axis=axis, dtype=data.dtype)
|
|
171
|
+
src_slice = [slice(None)] * data.ndim
|
|
172
|
+
dst_slice = [slice(None)] * data.ndim
|
|
173
|
+
src_slice[axis] = slice(None, -1)
|
|
174
|
+
dst_slice[axis] = slice(1, None)
|
|
175
|
+
result[tuple(dst_slice)] = cumsum[tuple(src_slice)]
|
|
176
|
+
return result
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@register_evaluator("CumSum")
def _eval_cumsum(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``CumSum`` with ``exclusive`` and ``reverse`` support.

    Reverse mode flips the data along the axis, accumulates, and flips back.

    Raises:
        UnsupportedOpError: if a runtime axis tensor holds more than one value.
    """
    op = lower_cumsum(evaluator.graph, node)
    data = evaluator.values[op.input0]
    axis = op.axis
    if axis is None:
        # Axis supplied as a tensor at runtime; it must contain one element.
        raw_axis = (
            evaluator.values[op.axis_input].astype(np.int64, copy=False).reshape(-1)
        )
        if raw_axis.size != 1:
            raise UnsupportedOpError("CumSum axis input must be scalar")
        axis = normalize_axis(int(raw_axis[0]), op.input_shape, node)
    if op.reverse:
        data = np.flip(data, axis=axis)
    if op.exclusive:
        summed = _exclusive_cumsum(data, axis)
    else:
        summed = np.cumsum(data, axis=axis, dtype=data.dtype)
    if op.reverse:
        summed = np.flip(summed, axis=axis)
    evaluator.values[op.output] = summed
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@register_evaluator("Pad")
def _eval_pad(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``Pad`` with ``np.pad``.

    Pad amounts come from the first applicable source:

    * an ``axes`` input: ``pads`` (runtime input, else ``op.pads_values``)
      holds ``[begins..., ends...]`` for just the listed axes; negative axes
      count from the end;
    * a runtime ``pads`` input: remapped through ``op.pads_axis_map`` when
      present (``None`` entries leave that axis unpadded), otherwise read as
      ``[begin per axis..., end per axis...]``;
    * the static ``op.pads_begin`` / ``op.pads_end``.

    The fill value is only passed to ``np.pad`` in ``"constant"`` mode.
    """

    def as_int64_vector(array: np.ndarray) -> np.ndarray:
        # Runtime index tensors are coerced to int64 and flattened.
        return array.astype(np.int64, copy=False).reshape(-1)

    def to_pad_width(begins, ends, rank: int) -> tuple[tuple[int, int], ...]:
        # np.pad wants one (before, after) pair of Python ints per axis.
        return tuple((int(begins[axis]), int(ends[axis])) for axis in range(rank))

    op = lower_pad(evaluator.graph, node)
    data = evaluator.values[op.input0]
    if op.value_input is None:
        fill = np.array(op.value, dtype=op.dtype.np_dtype).item()
    else:
        fill_tensor = evaluator.values[op.value_input]
        fill = np.array(fill_tensor, dtype=op.dtype.np_dtype).reshape(-1)[0].item()

    rank = len(op.input_shape)
    if op.axes_input is not None:
        axes = as_int64_vector(evaluator.values[op.axes_input])
        if op.pads_input is None:
            pads = np.array(op.pads_values, dtype=np.int64).reshape(-1)
        else:
            pads = as_int64_vector(evaluator.values[op.pads_input])
        count = len(axes)
        begins = np.zeros(rank, dtype=np.int64)
        ends = np.zeros(rank, dtype=np.int64)
        for slot, raw_axis in enumerate(axes):
            axis = int(raw_axis)
            if axis < 0:
                axis += rank
            begins[axis] = int(pads[slot])
            ends[axis] = int(pads[slot + count])
        pad_width = to_pad_width(begins, ends, rank)
    elif op.pads_input is not None:
        pads = as_int64_vector(evaluator.values[op.pads_input])
        axis_map = op.pads_axis_map
        if axis_map is None:
            pad_width = to_pad_width(pads[:rank], pads[rank : rank * 2], rank)
        else:
            count = sum(1 for slot in axis_map if slot is not None)
            begins = np.zeros(rank, dtype=np.int64)
            ends = np.zeros(rank, dtype=np.int64)
            for axis, slot in enumerate(axis_map):
                if slot is not None:
                    begins[axis] = int(pads[slot])
                    ends[axis] = int(pads[slot + count])
            pad_width = to_pad_width(begins, ends, rank)
    else:
        pad_width = tuple(zip(op.pads_begin or (), op.pads_end or ()))

    extra = {"constant_values": fill} if op.mode == "constant" else {}
    evaluator.values[op.output] = np.pad(data, pad_width, mode=op.mode, **extra)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@register_evaluator("Celu")
def _eval_celu(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``Celu``: ``x`` where ``x > 0``, else ``alpha * (exp(x / alpha) - 1)``.

    Raises:
        UnsupportedOpError: on a wrong input/output count or non-float input.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Celu must have 1 input and 1 output")
    dtype = value_dtype(evaluator.graph, node.inputs[0], node)
    if not dtype.is_float:
        raise UnsupportedOpError("Celu only supports floating-point inputs")
    alpha = float(node.attrs.get("alpha", 1.0))
    x = evaluator.values[node.inputs[0]]
    # np.where evaluates both branches for every element, so exp() also runs on
    # large positive inputs whose result is discarded. Silence that overflow
    # rather than emitting a RuntimeWarning (or raising under
    # np.seterr(over="raise")); selected values are unchanged.
    with np.errstate(over="ignore"):
        negative_branch = alpha * (np.exp(x / alpha) - 1.0)
    evaluator.values[node.outputs[0]] = np.where(x > 0, x, negative_branch)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
@register_evaluator("Swish")
def _eval_swish(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ``Swish``: ``x / (1 + exp(-alpha * x))`` (i.e. ``x * sigmoid(alpha * x)``).

    Raises:
        UnsupportedOpError: on a wrong input/output count or non-float input.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Swish must have 1 input and 1 output")
    dtype = value_dtype(evaluator.graph, node.inputs[0], node)
    if not dtype.is_float:
        raise UnsupportedOpError("Swish only supports floating-point inputs")
    alpha = float(node.attrs.get("alpha", 1.0))
    x = evaluator.values[node.inputs[0]]
    # When alpha * x is very negative, exp() overflows to inf and x / inf
    # yields the correct limit (0); only the spurious overflow warning is
    # suppressed.
    with np.errstate(over="ignore"):
        result = x / (1.0 + np.exp(-alpha * x))
    evaluator.values[node.outputs[0]] = result
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _grid_sample_denormalize(
|
|
302
|
+
value: float, length: int, *, align_corners: bool
|
|
303
|
+
) -> float:
|
|
304
|
+
if align_corners:
|
|
305
|
+
return (value + 1.0) * (length - 1) / 2.0
|
|
306
|
+
return ((value + 1.0) * length - 1.0) / 2.0
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _grid_sample_reflect(value: float, x_min: float, x_max: float) -> float:
|
|
310
|
+
rng = x_max - x_min
|
|
311
|
+
if rng == 0:
|
|
312
|
+
return x_min
|
|
313
|
+
if value < x_min:
|
|
314
|
+
dx = x_min - value
|
|
315
|
+
n = int(dx / rng)
|
|
316
|
+
r = dx - n * rng
|
|
317
|
+
return x_min + r if n % 2 == 0 else x_max - r
|
|
318
|
+
if value > x_max:
|
|
319
|
+
dx = value - x_max
|
|
320
|
+
n = int(dx / rng)
|
|
321
|
+
r = dx - n * rng
|
|
322
|
+
return x_max - r if n % 2 == 0 else x_min + r
|
|
323
|
+
return value
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _grid_sample_border(
|
|
327
|
+
dims: tuple[int, ...], *, align_corners: bool
|
|
328
|
+
) -> tuple[list[float], list[float]]:
|
|
329
|
+
min_vals: list[float] = []
|
|
330
|
+
max_vals: list[float] = []
|
|
331
|
+
for dim in dims:
|
|
332
|
+
if align_corners:
|
|
333
|
+
min_vals.append(0.0)
|
|
334
|
+
max_vals.append(dim - 1.0)
|
|
335
|
+
else:
|
|
336
|
+
min_vals.append(-0.5)
|
|
337
|
+
max_vals.append(dim - 0.5)
|
|
338
|
+
return min_vals, max_vals
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def _grid_sample_pixel_at(
|
|
342
|
+
data: np.ndarray,
|
|
343
|
+
indices: list[int],
|
|
344
|
+
border_min: list[float],
|
|
345
|
+
border_max: list[float],
|
|
346
|
+
padding_mode: str,
|
|
347
|
+
) -> float:
|
|
348
|
+
if padding_mode == "zeros":
|
|
349
|
+
for idx, dim in zip(indices, data.shape):
|
|
350
|
+
if idx < 0 or idx >= dim:
|
|
351
|
+
return data.dtype.type(0)
|
|
352
|
+
return data[tuple(indices)]
|
|
353
|
+
if padding_mode == "border":
|
|
354
|
+
clamped = [
|
|
355
|
+
0 if idx < 0 else dim - 1 if idx >= dim else idx
|
|
356
|
+
for idx, dim in zip(indices, data.shape)
|
|
357
|
+
]
|
|
358
|
+
return data[tuple(clamped)]
|
|
359
|
+
reflected = [
|
|
360
|
+
int(_grid_sample_reflect(idx, border_min[i], border_max[i]))
|
|
361
|
+
for i, idx in enumerate(indices)
|
|
362
|
+
]
|
|
363
|
+
return data[tuple(reflected)]
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _grid_sample_linear_1d(
    data: np.ndarray,
    coord: float,
    border_min: float,
    border_max: float,
    padding_mode: str,
) -> float:
    """Linearly interpolate 1-D ``data`` at fractional position ``coord``.

    The two neighbours ``floor(coord)`` and ``floor(coord) + 1`` are fetched
    through ``_grid_sample_pixel_at`` so padding applies to them.
    """
    left = int(np.floor(coord))
    frac = coord - left
    neighbours = [
        _grid_sample_pixel_at(
            data, [left + step], [border_min], [border_max], padding_mode
        )
        for step in (0, 1)
    ]
    return (1.0 - frac) * neighbours[0] + frac * neighbours[1]
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _grid_sample_cubic_coeffs(x: float) -> np.ndarray:
|
|
385
|
+
alpha = -0.75
|
|
386
|
+
abs_x = abs(x)
|
|
387
|
+
coeffs = np.empty((4,), dtype=np.float64)
|
|
388
|
+
coeffs[0] = (
|
|
389
|
+
(alpha * (abs_x + 1.0) - 5.0 * alpha) * (abs_x + 1.0) + 8.0 * alpha
|
|
390
|
+
) * (abs_x + 1.0) - 4.0 * alpha
|
|
391
|
+
coeffs[1] = ((alpha + 2.0) * abs_x - (alpha + 3.0)) * abs_x * abs_x + 1.0
|
|
392
|
+
inv_x = 1.0 - abs_x
|
|
393
|
+
coeffs[2] = ((alpha + 2.0) * inv_x - (alpha + 3.0)) * inv_x * inv_x + 1.0
|
|
394
|
+
span = 2.0 - abs_x
|
|
395
|
+
coeffs[3] = (
|
|
396
|
+
(alpha * span - 5.0 * alpha) * span + 8.0 * alpha
|
|
397
|
+
) * span - 4.0 * alpha
|
|
398
|
+
return coeffs
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _grid_sample_cubic_1d(
    data: np.ndarray,
    coord: float,
    border_min: float,
    border_max: float,
    padding_mode: str,
) -> float:
    """Cubic-interpolate 1-D ``data`` at fractional position ``coord``.

    Samples the four pixels ``floor(coord) - 1 .. floor(coord) + 2`` (with
    padding applied) and combines them with the cubic-convolution weights.
    """
    anchor = int(np.floor(coord))
    weights = _grid_sample_cubic_coeffs(coord - anchor)
    samples = np.array(
        [
            _grid_sample_pixel_at(
                data, [anchor + step], [border_min], [border_max], padding_mode
            )
            for step in (-1, 0, 1, 2)
        ],
        dtype=np.float64,
    )
    return float(weights @ samples)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _grid_sample_linear_nd(
    data: np.ndarray,
    coords: np.ndarray,
    border_min: list[float],
    border_max: list[float],
    padding_mode: str,
) -> float:
    """Multilinear interpolation of ``data`` at ``coords`` (one per axis).

    For more than one axis, every slice ``data[i]`` is first interpolated
    over the remaining axes at ``coords[1:]``; the resulting 1-D float64
    profile is then interpolated along axis 0 at ``coords[0]``.
    """
    profile = data
    if data.ndim != 1:
        profile = np.array(
            [
                _grid_sample_linear_nd(
                    row, coords[1:], border_min[1:], border_max[1:], padding_mode
                )
                for row in data
            ],
            dtype=np.float64,
        )
    return _grid_sample_linear_1d(
        profile, float(coords[0]), border_min[0], border_max[0], padding_mode
    )
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def _grid_sample_cubic_nd(
    data: np.ndarray,
    coords: np.ndarray,
    border_min: list[float],
    border_max: list[float],
    padding_mode: str,
) -> float:
    """Separable cubic interpolation of ``data`` at ``coords`` (one per axis).

    For more than one axis, every slice ``data[i]`` is first interpolated
    over the remaining axes at ``coords[1:]``; the resulting 1-D float64
    profile is then cubic-interpolated along axis 0 at ``coords[0]``.
    """
    profile = data
    if data.ndim != 1:
        profile = np.array(
            [
                _grid_sample_cubic_nd(
                    row, coords[1:], border_min[1:], border_max[1:], padding_mode
                )
                for row in data
            ],
            dtype=np.float64,
        )
    return _grid_sample_cubic_1d(
        profile, float(coords[0]), border_min[0], border_max[0], padding_mode
    )
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
@register_evaluator("GridSample")
def _eval_grid_sample(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``GridSample`` one output element at a time.

    For every batch ``n``, channel ``c`` and output spatial position, the grid
    vector is reversed (so component ``i`` lines up with input spatial axis
    ``i``) and converted to pixel units. For padding modes other than
    ``"zeros"``, out-of-range coordinates are clamped (``"border"``) or
    reflected. The value is then sampled with nearest, linear or cubic
    interpolation according to ``op.mode``.
    """
    op = lower_grid_sample(evaluator.graph, node)
    source = evaluator.values[op.input0]
    grid = evaluator.values[op.grid]
    result = np.empty(op.output_shape, dtype=source.dtype)
    if result.size == 0:
        evaluator.values[op.output] = result
        return

    extents = op.input_spatial
    lows, highs = _grid_sample_border(extents, align_corners=op.align_corners)

    def out_of_range(axis: int, position) -> bool:
        return position < lows[axis] or position > highs[axis]

    def fold_indices(indices: np.ndarray) -> None:
        # Nearest mode: integer indices are clamped, or reflected then truncated.
        for axis, extent in enumerate(extents):
            if not out_of_range(axis, indices[axis]):
                continue
            if op.padding_mode == "border":
                indices[axis] = min(max(indices[axis], 0), extent - 1)
            else:
                indices[axis] = int(
                    _grid_sample_reflect(indices[axis], lows[axis], highs[axis])
                )

    def fold_coords(coords: np.ndarray) -> None:
        # Linear/cubic modes: float coordinates are clamped or reflected.
        for axis, extent in enumerate(extents):
            if not out_of_range(axis, coords[axis]):
                continue
            if op.padding_mode == "border":
                coords[axis] = min(max(coords[axis], 0.0), extent - 1.0)
            else:
                coords[axis] = _grid_sample_reflect(
                    coords[axis], lows[axis], highs[axis]
                )

    sampler = (
        _grid_sample_linear_nd if op.mode == "linear" else _grid_sample_cubic_nd
    )
    batch_count, channel_count = op.output_shape[0], op.output_shape[1]
    for n in range(batch_count):
        grid_batch = grid[n]
        for c in range(channel_count):
            plane = source[n, c]
            for out_idx in np.ndindex(*op.output_spatial):
                coords = np.array(grid_batch[out_idx][::-1], dtype=np.float64)
                for axis, extent in enumerate(extents):
                    coords[axis] = _grid_sample_denormalize(
                        float(coords[axis]), extent, align_corners=op.align_corners
                    )
                if op.mode == "nearest":
                    indices = np.rint(coords).astype(int)
                    if op.padding_mode != "zeros":
                        fold_indices(indices)
                    value = _grid_sample_pixel_at(
                        plane, indices.tolist(), lows, highs, op.padding_mode
                    )
                else:
                    if op.padding_mode != "zeros":
                        fold_coords(coords)
                    value = sampler(plane, coords, lows, highs, op.padding_mode)
                result[(n, c, *out_idx)] = value
    evaluator.values[op.output] = result
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
# Pairwise NumPy combiners used to fold variadic operands left to right.
# There is no MEAN entry: _eval_variadic sums with ADD and then divides.
_VARIADIC_COMBINE_FUNCS: dict[
    ScalarFunction, Callable[[np.ndarray, np.ndarray], np.ndarray]
] = dict(
    [
        # arithmetic / ordering
        (ScalarFunction.ADD, np.add),
        (ScalarFunction.MAXIMUM, np.maximum),
        (ScalarFunction.MINIMUM, np.minimum),
        # boolean logic
        (ScalarFunction.LOGICAL_AND, np.logical_and),
        (ScalarFunction.LOGICAL_OR, np.logical_or),
        (ScalarFunction.LOGICAL_XOR, np.logical_xor),
        # integer bit operations
        (ScalarFunction.BITWISE_AND, np.bitwise_and),
        (ScalarFunction.BITWISE_OR, np.bitwise_or),
        (ScalarFunction.BITWISE_XOR, np.bitwise_xor),
    ]
)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def _validate_variadic_inputs(
    evaluator: Evaluator, node: Node, *, function: ScalarFunction
) -> tuple[ScalarType, tuple[int, ...]]:
    """Check arity, dtypes and shapes of a variadic node.

    Returns the shared operand dtype and the output shape. Every input must
    have exactly the output's shape (no broadcasting). Logical ops need bool
    operands, bitwise ops integer operands, and MEAN floating-point operands.

    Raises:
        UnsupportedOpError: on any violation.
    """
    op_type = node.op_type
    if len(node.outputs) != 1:
        raise UnsupportedOpError(f"{op_type} must have 1 output")
    input_count = len(node.inputs)
    if op_type in BINARY_ONLY_OPS:
        if input_count != 2:
            raise UnsupportedOpError(
                f"{op_type} must have exactly 2 inputs"
            )
    elif input_count < 2:
        raise UnsupportedOpError(
            f"{op_type} must have at least 2 inputs"
        )
    if any(not name for name in node.inputs):
        raise UnsupportedOpError(f"{op_type} input must be provided")

    graph = evaluator.graph
    output_name = node.outputs[0]
    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
    output_dtype = value_dtype(graph, output_name, node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            f"{op_type} expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    output_shape = value_shape(graph, output_name, node)
    if any(value_shape(graph, name, node) != output_shape for name in node.inputs):
        raise UnsupportedOpError(
            f"{op_type} expects identical input/output shapes"
        )

    logical_ops = (
        ScalarFunction.LOGICAL_AND,
        ScalarFunction.LOGICAL_OR,
        ScalarFunction.LOGICAL_XOR,
    )
    bitwise_ops = (
        ScalarFunction.BITWISE_AND,
        ScalarFunction.BITWISE_OR,
        ScalarFunction.BITWISE_XOR,
    )
    if function in logical_ops and op_dtype != ScalarType.BOOL:
        raise UnsupportedOpError(f"{op_type} expects bool inputs")
    if function in bitwise_ops and not op_dtype.is_integer:
        raise UnsupportedOpError(f"{op_type} expects integer inputs")
    if function == ScalarFunction.MEAN and not op_dtype.is_float:
        raise UnsupportedOpError(f"{op_type} expects floating-point inputs")
    return op_dtype, output_shape
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def _eval_variadic(evaluator: Evaluator, node: Node) -> None:
    """Evaluate a variadic element-wise op by folding operands left to right.

    MEAN is computed as an ADD fold divided by the number of operands.
    """
    function = VARIADIC_OP_FUNCTIONS[node.op_type]
    _validate_variadic_inputs(evaluator, node, function=function)
    operands = [evaluator.values[name] for name in node.inputs]
    is_mean = function == ScalarFunction.MEAN
    combine = _VARIADIC_COMBINE_FUNCS[ScalarFunction.ADD if is_mean else function]
    remaining = iter(operands)
    accumulated = next(remaining)
    for operand in remaining:
        accumulated = combine(accumulated, operand)
    if is_mean:
        accumulated = accumulated / len(operands)
    evaluator.values[node.outputs[0]] = accumulated
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
# Every op type listed in VARIADIC_OP_FUNCTIONS shares the generic
# _eval_variadic handler.
for _op_type in VARIADIC_OP_FUNCTIONS:
    register_evaluator(_op_type)(_eval_variadic)
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
@register_evaluator("Shrink")
def _eval_shrink(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``Shrink``.

    Output is ``x + bias`` where ``x < -lambd``, ``x - bias`` where
    ``x > lambd``, and 0 elsewhere, cast back to the input dtype.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Shrink must have 1 input and 1 output")
    bias = float(node.attrs.get("bias", 0.0))
    lambd = float(node.attrs.get("lambd", 0.5))
    x = evaluator.values[node.inputs[0]]
    below = x < -lambd
    raised = x + bias
    above = x > lambd
    lowered = x - bias
    shrunk = np.where(below, raised, np.where(above, lowered, 0.0))
    if shrunk.dtype == x.dtype:
        evaluator.values[node.outputs[0]] = shrunk
    else:
        evaluator.values[node.outputs[0]] = shrunk.astype(x.dtype)
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
@register_evaluator("IsInf")
def _eval_isinf(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``IsInf`` honoring ``detect_negative``/``detect_positive``.

    Both attributes default to 1 (flag both +inf and -inf). Disabling one
    restricts the mask to the other sign; disabling both yields all-False.

    Raises:
        UnsupportedOpError: on a wrong input/output count, non-float input,
            or non-bool output.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("IsInf must have 1 input and 1 output")
    input_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
    if not input_dtype.is_float:
        raise UnsupportedOpError("IsInf only supports floating-point inputs")
    output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
    if output_dtype != ScalarType.BOOL:
        raise UnsupportedOpError("IsInf output must be bool")
    x = evaluator.values[node.inputs[0]]
    detect_negative = bool(int(node.attrs.get("detect_negative", 1)))
    detect_positive = bool(int(node.attrs.get("detect_positive", 1)))
    if detect_negative and detect_positive:
        result = np.isinf(x)
    elif detect_negative:
        result = np.isneginf(x)
    elif detect_positive:
        result = np.isposinf(x)
    else:
        result = np.zeros(np.shape(x), dtype=bool)
    evaluator.values[node.outputs[0]] = result
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
@register_evaluator("IsNaN")
def _eval_isnan(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``IsNaN``: bool mask of NaN entries in a float tensor.

    Raises:
        UnsupportedOpError: on a wrong input/output count, non-float input,
            or non-bool output.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("IsNaN must have 1 input and 1 output")
    graph = evaluator.graph
    (source,) = node.inputs
    (target,) = node.outputs
    if not value_dtype(graph, source, node).is_float:
        raise UnsupportedOpError("IsNaN only supports floating-point inputs")
    if value_dtype(graph, target, node) != ScalarType.BOOL:
        raise UnsupportedOpError("IsNaN output must be bool")
    evaluator.values[target] = np.isnan(evaluator.values[source])
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
@register_evaluator("Gemm")
def _eval_gemm(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX ``Gemm``: ``Y = alpha * op(A) @ op(B) + beta * C``.

    ``op(X)`` is ``X.T`` when the resolved spec's ``trans_a``/``trans_b`` flag
    is set. ``alpha``/``beta`` are coerced to ``float`` for floating-point
    dtypes and to ``int`` otherwise; scaling is skipped when they equal 1.
    The optional bias ``C`` is looked up with ``optional_name`` (the same
    helper ``Clip`` uses for its optional bounds), so an omitted third input
    is skipped instead of being indexed directly. NOTE(review): this relies
    on ``optional_name`` returning None for an empty name - confirm.

    Raises:
        UnsupportedOpError: if the input and output dtypes differ.
    """
    op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
    output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            f"{node.op_type} expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    spec = resolve_gemm_spec(evaluator.graph, node, op_dtype)
    left = evaluator.values[node.inputs[0]]
    right = evaluator.values[node.inputs[1]]
    if spec.trans_a:
        left = left.T
    if spec.trans_b:
        right = right.T
    result = _apply_matmul(left, right)
    # Python scalars keep the array dtype under NumPy's scalar promotion.
    coerce = float if op_dtype.is_float else int
    alpha = coerce(spec.alpha)
    beta = coerce(spec.beta)
    if alpha != 1:
        result = result * alpha
    bias_name = optional_name(node.inputs, 2)
    if bias_name is not None:
        bias = evaluator.values[bias_name]
        if beta != 1:
            bias = bias * beta
        result = result + bias
    evaluator.values[node.outputs[0]] = result
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
@register_evaluator("Cast")
def _eval_cast(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Cast by converting the input to the declared output dtype.

    Uses ``astype(copy=False)``, so when the dtype already matches, the output
    may be the same array object as the input.

    Raises:
        UnsupportedOpError: on wrong arity.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Cast must have 1 input and 1 output")
    target = node.outputs[0]
    target_np_dtype = value_dtype(evaluator.graph, target, node).np_dtype
    source_array = evaluator.values[node.inputs[0]]
    evaluator.values[target] = source_array.astype(target_np_dtype, copy=False)
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
@register_evaluator("CastLike")
def _eval_castlike(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX CastLike: cast input 0 to the dtype of input 1.

    The graph's declared output dtype must agree with the "like" input's
    dtype. The conversion uses ``astype(copy=False)``.

    Raises:
        UnsupportedOpError: on wrong arity or a dtype mismatch between the
            output and the "like" input.
    """
    if len(node.inputs) != 2 or len(node.outputs) != 1:
        raise UnsupportedOpError("CastLike must have 2 inputs and 1 output")
    graph = evaluator.graph
    source_name, like_name = node.inputs[0], node.inputs[1]
    target = node.outputs[0]
    like_dtype = value_dtype(graph, like_name, node)
    output_dtype = value_dtype(graph, target, node)
    if like_dtype != output_dtype:
        raise UnsupportedOpError(
            "CastLike output dtype must match like input dtype, "
            f"got {output_dtype.onnx_name} and {like_dtype.onnx_name}"
        )
    source_array = evaluator.values[source_name]
    evaluator.values[target] = source_array.astype(
        output_dtype.np_dtype, copy=False
    )
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
@register_evaluator("Identity")
def _eval_identity(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Identity by storing an independent copy of the input.

    Raises:
        UnsupportedOpError: on wrong arity.
    """
    if not (len(node.inputs) == 1 and len(node.outputs) == 1):
        raise UnsupportedOpError("Identity must have 1 input and 1 output")
    # Copy so later writes to either tensor cannot affect the other.
    evaluator.values[node.outputs[0]] = np.copy(evaluator.values[node.inputs[0]])
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
@register_evaluator("EyeLike")
def _eval_eye_like(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX EyeLike: ones on the ``k``-th diagonal, zeros elsewhere.

    The output shape and dtype come from the graph's declared output. For
    rank > 2, the same 2-D pattern over the last two axes is repeated for
    every leading index. ``k > 0`` selects a diagonal above the main one and
    ``k < 0`` one below.

    Raises:
        UnsupportedOpError: on wrong arity or an output rank below 2.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("EyeLike must have 1 input and 1 output")
    graph = evaluator.graph
    target = node.outputs[0]
    shape = value_shape(graph, target, node)
    if len(shape) < 2:
        raise UnsupportedOpError("EyeLike expects input rank >= 2")
    np_dtype = value_dtype(graph, target, node).np_dtype
    diagonal = int(node.attrs.get("k", 0))
    # np.eye puts ones at [i, i + k]; a diagonal entirely outside the matrix
    # yields all zeros.
    pattern = np.eye(shape[-2], shape[-1], k=diagonal, dtype=np_dtype)
    # broadcast_to returns a read-only view; copy to get an owned array.
    evaluator.values[target] = np.broadcast_to(pattern, tuple(shape)).copy()
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
@register_evaluator("Tile")
def _eval_tile(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Tile: repeat the input ``repeats[i]`` times along axis ``i``.

    Raises:
        UnsupportedOpError: on wrong arity, or when the number of repeats
            differs from the input rank.
    """
    if len(node.inputs) != 2 or len(node.outputs) != 1:
        raise UnsupportedOpError("Tile must have 2 inputs and 1 output")
    data_name, repeats_name = node.inputs[0], node.inputs[1]
    data = evaluator.values[data_name]
    # Flatten to a 1-D int64 vector however the repeats tensor is stored.
    reps = np.array(evaluator.values[repeats_name], dtype=np.int64).ravel()
    if reps.size != data.ndim:
        raise UnsupportedOpError(
            "Tile repeats must have the same rank as input shape"
        )
    evaluator.values[node.outputs[0]] = np.tile(data, reps)
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
@register_evaluator("DepthToSpace")
def _eval_depth_to_space(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX DepthToSpace on a 4-D ``(N, C, H, W)`` tensor.

    Moves channel groups into ``blocksize x blocksize`` spatial blocks,
    producing ``(N, C / blocksize**2, H * blocksize, W * blocksize)``.
    ``mode`` ("DCR", the default, or "CRD") selects how the channel index is
    split into (block offset, output channel); it may arrive as ``bytes``.

    Raises:
        UnsupportedOpError: on wrong arity, non-4-D input, a non-positive
            blocksize, an unknown mode, or a channel count not divisible by
            ``blocksize**2``.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("DepthToSpace must have 1 input and 1 output")
    data = evaluator.values[node.inputs[0]]
    if data.ndim != 4:
        raise UnsupportedOpError("DepthToSpace only supports 4D inputs")
    blocksize = int(node.attrs.get("blocksize", 0))
    if blocksize <= 0:
        raise UnsupportedOpError(
            f"DepthToSpace blocksize must be > 0, got {blocksize}"
        )
    mode_attr = node.attrs.get("mode", "DCR")
    mode = mode_attr.decode() if isinstance(mode_attr, bytes) else str(mode_attr)
    if mode not in {"DCR", "CRD"}:
        raise UnsupportedOpError("DepthToSpace only supports mode DCR or CRD")
    b, c, h, w = data.shape
    block_area = blocksize * blocksize
    # Without this check the reshape below fails with an opaque numpy error.
    if c % block_area != 0:
        raise UnsupportedOpError(
            f"DepthToSpace input channels ({c}) must be divisible by "
            f"blocksize^2 ({block_area})"
        )
    out_c = c // block_area
    if mode == "DCR":
        # Channel index splits as (row offset, column offset, output channel).
        blocked = data.reshape((b, blocksize, blocksize, out_c, h, w))
        transposed = np.transpose(blocked, [0, 3, 4, 1, 5, 2])
    else:
        # Channel index splits as (output channel, row offset, column offset).
        blocked = data.reshape((b, out_c, blocksize, blocksize, h, w))
        transposed = np.transpose(blocked, [0, 1, 4, 2, 5, 3])
    evaluator.values[node.outputs[0]] = np.reshape(
        transposed, (b, out_c, h * blocksize, w * blocksize)
    )
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
@register_evaluator("SpaceToDepth")
def _eval_space_to_depth(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX SpaceToDepth on a 4-D ``(N, C, H, W)`` tensor.

    Folds each ``blocksize x blocksize`` spatial block into channels,
    producing ``(N, C * blocksize**2, H / blocksize, W / blocksize)``.

    Raises:
        UnsupportedOpError: on wrong arity, non-4-D input, a non-positive
            blocksize, or spatial dims not divisible by blocksize.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("SpaceToDepth must have 1 input and 1 output")
    data = evaluator.values[node.inputs[0]]
    if data.ndim != 4:
        raise UnsupportedOpError("SpaceToDepth only supports 4D inputs")
    blocksize = int(node.attrs.get("blocksize", 0))
    if blocksize <= 0:
        raise UnsupportedOpError(
            f"SpaceToDepth blocksize must be > 0, got {blocksize}"
        )
    b, c, h, w = data.shape
    # Without this check the reshape below fails with an opaque numpy error.
    if h % blocksize != 0 or w % blocksize != 0:
        raise UnsupportedOpError(
            f"SpaceToDepth spatial dims ({h}, {w}) must be divisible by "
            f"blocksize {blocksize}"
        )
    out_h, out_w = h // blocksize, w // blocksize
    blocked = np.reshape(data, (b, c, out_h, blocksize, out_w, blocksize))
    # Move the two intra-block offsets in front of the channel axis.
    transposed = np.transpose(blocked, [0, 3, 5, 1, 2, 4])
    evaluator.values[node.outputs[0]] = np.reshape(
        transposed, (b, c * blocksize * blocksize, out_h, out_w)
    )
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
@register_evaluator("Where")
def _eval_where(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Where: pick from X where the condition holds, else from Y.

    ``lower_where`` is called only for its checks; its result is unused.
    """
    lower_where(evaluator.graph, node)
    values = evaluator.values
    mask = values[node.inputs[0]]
    when_true = values[node.inputs[1]]
    when_false = values[node.inputs[2]]
    values[node.outputs[0]] = np.where(mask, when_true, when_false)
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
@register_evaluator("GatherElements")
def _eval_gather_elements(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX GatherElements via ``np.take_along_axis``.

    Raises:
        UnsupportedOpError: on wrong arity or non-int32/int64 indices.
    """
    if not (len(node.inputs) == 2 and len(node.outputs) == 1):
        raise UnsupportedOpError("GatherElements must have 2 inputs and 1 output")
    source = evaluator.values[node.inputs[0]]
    index_tensor = evaluator.values[node.inputs[1]]
    if index_tensor.dtype.type not in (np.int32, np.int64):
        raise UnsupportedOpError(
            f"GatherElements indices must be int32 or int64, got {index_tensor.dtype}"
        )
    raw_axis = int(node.attrs.get("axis", 0))
    gather_axis = normalize_axis(raw_axis, source.shape, node)
    gathered = np.take_along_axis(source, index_tensor, axis=gather_axis)
    evaluator.values[node.outputs[0]] = gathered
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
@register_evaluator("Gather")
def _eval_gather(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Gather via ``np.take`` along the (normalized) axis.

    Raises:
        UnsupportedOpError: on wrong arity or non-int32/int64 indices.
    """
    if not (len(node.inputs) == 2 and len(node.outputs) == 1):
        raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
    source = evaluator.values[node.inputs[0]]
    index_tensor = evaluator.values[node.inputs[1]]
    if index_tensor.dtype.type not in (np.int32, np.int64):
        raise UnsupportedOpError(
            f"Gather indices must be int32 or int64, got {index_tensor.dtype}"
        )
    gather_axis = normalize_axis(int(node.attrs.get("axis", 0)), source.shape, node)
    evaluator.values[node.outputs[0]] = np.take(source, index_tensor, axis=gather_axis)
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
@register_evaluator("Slice")
def _eval_slice(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Slice in both its attribute form and its input form.

    Attribute form (detected by a ``starts`` or ``ends`` attribute):
    starts/ends/axes come from attributes and there are no steps. Input form:
    they come from inputs 1-4. ``axes`` and ``steps`` are optional and are
    skipped when missing or named ``""``. Each provided tensor must be
    int32/int64. ``_normalize_slices`` turns these into per-axis start,
    step and output length, which are then applied as Python slices.

    Raises:
        UnsupportedOpError: if the input form has fewer than 3 inputs or an
            index input is not int32/int64.
    """
    input_value = evaluator.values[node.inputs[0]]
    if "starts" in node.attrs or "ends" in node.attrs:
        starts = [int(value) for value in node.attrs.get("starts", [])]
        ends = [int(value) for value in node.attrs.get("ends", [])]
        axes_attr = node.attrs.get("axes")
        axes = [int(value) for value in axes_attr] if axes_attr else None
        steps = None
    else:
        if len(node.inputs) < 3:
            raise UnsupportedOpError(
                f"{node.op_type} expects at least 3 inputs"
            )

        def _int_list(index, label):
            # Read input ``index`` as a flat list of Python ints.
            tensor = evaluator.values[node.inputs[index]]
            if tensor.dtype.type not in {np.int32, np.int64}:
                raise UnsupportedOpError(
                    f"{node.op_type} {label} input must be int64 or int32"
                )
            return [int(value) for value in tensor.reshape(-1)]

        def _optional_int_list(index, label):
            # None when the optional input is absent or given as "".
            if len(node.inputs) > index and node.inputs[index]:
                return _int_list(index, label)
            return None

        starts = _int_list(1, "starts")
        ends = _int_list(2, "ends")
        axes = _optional_int_list(3, "axes")
        steps = _optional_int_list(4, "steps")
    normalized_starts, normalized_steps, output_shape = _normalize_slices(
        input_value.shape, starts, ends, axes, steps, node
    )
    slices = []
    for start, step, size in zip(normalized_starts, normalized_steps, output_shape):
        if size == 0:
            # start == stop always yields an empty slice.
            slices.append(slice(start, start, step))
            continue
        stop = start + step * size
        # With a negative step that reaches index 0, ``stop`` drops below zero.
        # Python would wrap a negative stop from the end of the axis and return
        # an empty or truncated slice; None means "run past the beginning".
        # (Assumes _normalize_slices returns non-negative, in-range starts.)
        slices.append(slice(start, stop if stop >= 0 else None, step))
    evaluator.values[node.outputs[0]] = input_value[tuple(slices)]
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
@register_evaluator("Attention")
def _eval_attention(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Attention via ``_apply_attention``.

    Inputs 0-2 (Q, K, V) are required. Optional inputs 3-6 (attn_mask,
    past_key, past_value, nonpad) are passed as ``None`` when
    ``optional_name`` reports them absent. Output 0 is always written.
    Optional outputs 1-3 (present_key, present_value, qk matmul output) are
    written only when named.
    """
    graph = evaluator.graph
    values = evaluator.values
    q_name, k_name, v_name = node.inputs[0], node.inputs[1], node.inputs[2]
    y_name = node.outputs[0]
    op_dtype = node_dtype(graph, node, q_name, k_name, v_name, y_name)
    spec = resolve_attention_spec(graph, node, op_dtype)
    required = [values[q_name], values[k_name], values[v_name]]
    optional_inputs = []
    for position in (3, 4, 5, 6):
        name = optional_name(node.inputs, position)
        optional_inputs.append(values[name] if name else None)
    y, present_key, present_value, qk_output = _apply_attention(
        spec, *required, *optional_inputs
    )
    values[y_name] = y
    extra_outputs = ((1, present_key), (2, present_value), (3, qk_output))
    for position, tensor in extra_outputs:
        name = optional_name(node.outputs, position)
        if name is not None:
            values[name] = tensor
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
def _apply_lstm_activation(
    kind: int, value: np.ndarray, alpha: float, beta: float
) -> np.ndarray:
    """Apply the RNN-family activation whose code is ``kind`` to ``value``.

    ``kind`` is compared, in a fixed order, against the codes stored in
    ``ACTIVATION_KIND_BY_NAME``. ``alpha``/``beta`` are used only by the
    activations whose formulas below reference them.

    Raises:
        UnsupportedOpError: if ``kind`` matches none of the listed activations.
    """
    # Lambdas keep evaluation lazy: only the matching formula is computed.
    table = (
        ("Relu", lambda: np.maximum(value, 0)),
        ("Tanh", lambda: np.tanh(value)),
        ("Sigmoid", lambda: 1 / (1 + np.exp(-value))),
        ("Affine", lambda: alpha * value + beta),
        ("LeakyRelu", lambda: np.where(value < 0, alpha * value, value)),
        ("ThresholdedRelu", lambda: np.where(value > alpha, value, 0)),
        ("ScaledTanh", lambda: alpha * np.tanh(beta * value)),
        ("HardSigmoid", lambda: np.clip(alpha * value + beta, 0, 1)),
        ("Elu", lambda: np.where(value >= 0, value, alpha * (np.exp(value) - 1))),
        ("Softsign", lambda: value / (1 + np.abs(value))),
        ("Softplus", lambda: np.log1p(np.exp(value))),
    )
    for activation_name, compute in table:
        if kind == ACTIVATION_KIND_BY_NAME[activation_name]:
            return compute()
    raise UnsupportedOpError(f"Unsupported LSTM activation kind {kind}")
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
@register_evaluator("LSTM")
def _eval_lstm(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX LSTM via ``_apply_lstm``.

    Optional inputs whose spec name is ``None`` are passed as ``None``. Each
    of the three outputs (Y, Y_h, Y_c) is stored only when the spec names it.
    """
    spec = resolve_lstm_spec(evaluator.graph, node)
    values = evaluator.values

    def _maybe(name):
        # Look up an optional tensor; None when the spec omits it.
        return None if name is None else values[name]

    y, y_h, y_c = _apply_lstm(
        spec,
        values[spec.input_x],
        values[spec.input_w],
        values[spec.input_r],
        _maybe(spec.input_b),
        _maybe(spec.input_sequence_lens),
        _maybe(spec.input_initial_h),
        _maybe(spec.input_initial_c),
        _maybe(spec.input_p),
    )
    for name, tensor in (
        (spec.output_y, y),
        (spec.output_y_h, y_h),
        (spec.output_y_c, y_c),
    ):
        if name is not None:
            values[name] = tensor
|
|
1082
|
+
|
|
1083
|
+
|
|
1084
|
+
@register_evaluator("Conv")
def _eval_conv(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Conv for float inputs via ``_apply_conv``.

    The bias ``B`` is optional. It is passed as ``None`` both when the third
    input is not listed and when its name is empty (ONNX's marker for an
    omitted optional input).

    Raises:
        UnsupportedOpError: if input/output dtypes differ or the dtype is not
            floating point.
    """
    graph = evaluator.graph
    op_dtype = value_dtype(graph, node.inputs[0], node)
    output_dtype = value_dtype(graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            f"{node.op_type} expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "Conv supports float16, float, and double inputs only"
        )
    spec = resolve_conv_spec(graph, node)
    data = evaluator.values[node.inputs[0]]
    weights = evaluator.values[node.inputs[1]]
    bias_name = node.inputs[2] if len(node.inputs) > 2 else ""
    bias = evaluator.values[bias_name] if bias_name else None
    evaluator.values[node.outputs[0]] = _apply_conv(spec, data, weights, bias)
|
|
1102
|
+
|
|
1103
|
+
|
|
1104
|
+
@register_evaluator("BatchNormalization")
def _eval_batch_norm(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX BatchNormalization using the stored mean/variance.

    Computes ``(x - mean) / sqrt(var + eps) * scale + bias``. Each
    per-channel parameter is reshaped to ``(1, C, 1, ..., 1)`` so it
    broadcasts over axis 1.
    """
    op = lower_batch_normalization(evaluator.graph, node)
    values = evaluator.values
    x = values[op.input0]
    channel_shape = (1, op.channels) + (1,) * (x.ndim - 2)
    gamma, beta, running_mean, running_var = (
        values[name].reshape(channel_shape)
        for name in (op.scale, op.bias, op.mean, op.variance)
    )
    values[op.output] = (
        (x - running_mean) / np.sqrt(running_var + op.epsilon) * gamma + beta
    )
|
|
1123
|
+
|
|
1124
|
+
|
|
1125
|
+
@register_evaluator("LpNormalization")
def _eval_lp_normalization(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX LpNormalization along ``op.axis``.

    Uses the L1 norm when ``op.p == 1`` and the L2 norm otherwise. A zero
    norm is not guarded against.
    """
    op = lower_lp_normalization(evaluator.graph, node)
    x = evaluator.values[op.input0]
    norm = (
        np.sum(np.abs(x), axis=op.axis, keepdims=True)
        if op.p == 1
        else np.sqrt(np.sum(x * x, axis=op.axis, keepdims=True))
    )
    evaluator.values[op.output] = x / norm
|
|
1134
|
+
|
|
1135
|
+
|
|
1136
|
+
@register_evaluator("InstanceNormalization")
def _eval_instance_normalization(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX InstanceNormalization.

    Mean and (biased) variance are taken over every axis from 2 onward,
    separately for each (batch, channel) pair. The per-channel scale and
    bias are then applied.
    """
    op = lower_instance_normalization(evaluator.graph, node)
    values = evaluator.values
    x = values[op.input0]
    spatial_axes = tuple(range(2, x.ndim))
    mu = np.mean(x, axis=spatial_axes, keepdims=True)
    centered = x - mu
    sigma_sq = np.mean(centered ** 2, axis=spatial_axes, keepdims=True)
    channel_shape = (1, op.channels) + (1,) * (x.ndim - 2)
    gamma = values[op.scale].reshape(channel_shape)
    beta = values[op.bias].reshape(channel_shape)
    values[op.output] = centered / np.sqrt(sigma_sq + op.epsilon) * gamma + beta
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
@register_evaluator("GroupNormalization")
def _eval_group_normalization(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX GroupNormalization.

    Channels are split into ``op.num_groups`` groups of ``op.group_size``.
    Each group is normalized with its own mean and (biased) variance over
    its channels and spatial axes. Per-channel scale and bias are then
    applied.
    """
    op = lower_group_normalization(evaluator.graph, node)
    values = evaluator.values
    x = values[op.input0]
    grouped_shape = (x.shape[0], op.num_groups, op.group_size) + x.shape[2:]
    grouped = x.reshape(grouped_shape)
    within_group = tuple(range(2, grouped.ndim))
    mu = np.mean(grouped, axis=within_group, keepdims=True)
    centered = grouped - mu
    sigma_sq = np.mean(centered ** 2, axis=within_group, keepdims=True)
    normalized = (centered / np.sqrt(sigma_sq + op.epsilon)).reshape(x.shape)
    channel_shape = (1, op.channels) + (1,) * (x.ndim - 2)
    gamma = values[op.scale].reshape(channel_shape)
    beta = values[op.bias].reshape(channel_shape)
    values[op.output] = normalized * gamma + beta
|
|
1175
|
+
|
|
1176
|
+
|
|
1177
|
+
@register_evaluator("LayerNormalization")
def _eval_layer_normalization(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX LayerNormalization over the axes ``op.axis .. ndim-1``.

    Writes the normalized, scaled (and, if present, biased) output. When the
    optional Mean / InvStdDev outputs are requested, it also writes them,
    with keepdims shape.
    """
    op = lower_layer_normalization(evaluator.graph, node)
    values = evaluator.values
    x = values[op.input0]
    reduce_axes = tuple(range(op.axis, x.ndim))
    mu = np.mean(x, axis=reduce_axes, keepdims=True)
    sigma_sq = np.mean((x - mu) ** 2, axis=reduce_axes, keepdims=True)
    inv_sigma = 1.0 / np.sqrt(sigma_sq + op.epsilon)
    # Leading singleton axes so scale/bias line up with the normalized axes.
    leading = (1,) * op.axis
    gamma = values[op.scale]
    y = (x - mu) * inv_sigma * gamma.reshape(leading + gamma.shape)
    if op.bias is not None:
        beta = values[op.bias]
        y = y + beta.reshape(leading + beta.shape)
    values[op.output] = y
    if op.mean_output is not None:
        values[op.mean_output] = mu
    if op.invstd_output is not None:
        values[op.invstd_output] = inv_sigma
|
|
1200
|
+
|
|
1201
|
+
|
|
1202
|
+
@register_evaluator("MeanVarianceNormalization")
def _eval_mean_variance_normalization(
    evaluator: Evaluator, node: Node
) -> None:
    """Evaluate ONNX MeanVarianceNormalization over ``op.axes``.

    Computes ``(x - mean) / sqrt(var + eps)`` with the biased variance.
    """
    op = lower_mean_variance_normalization(evaluator.graph, node)
    x = evaluator.values[op.input0]
    centered = x - np.mean(x, axis=op.axes, keepdims=True)
    sigma_sq = np.mean(centered ** 2, axis=op.axes, keepdims=True)
    evaluator.values[op.output] = centered / np.sqrt(sigma_sq + op.epsilon)
|
|
1213
|
+
|
|
1214
|
+
|
|
1215
|
+
@register_evaluator("RMSNormalization")
def _eval_rms_normalization(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX RMSNormalization over the axes ``op.axis .. ndim-1``.

    Computes ``x / sqrt(mean(x^2) + eps) * scale``, with no mean subtraction.
    """
    op = lower_rms_normalization(evaluator.graph, node)
    values = evaluator.values
    x = values[op.input0]
    reduce_axes = tuple(range(op.axis, x.ndim))
    root_mean_square = np.sqrt(
        np.mean(x * x, axis=reduce_axes, keepdims=True) + op.epsilon
    )
    gamma = values[op.scale]
    gamma = gamma.reshape((1,) * op.axis + gamma.shape)
    values[op.output] = (x / root_mean_square) * gamma
|
|
1227
|
+
|
|
1228
|
+
|
|
1229
|
+
@register_evaluator("LRN")
def _eval_lrn(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX LRN (local response normalization) via ``_apply_lrn``.

    Raises:
        UnsupportedOpError: if input/output dtypes differ or the dtype is not
            floating point.
    """
    graph = evaluator.graph
    in_dtype = value_dtype(graph, node.inputs[0], node)
    out_dtype = value_dtype(graph, node.outputs[0], node)
    if in_dtype != out_dtype:
        raise UnsupportedOpError(
            f"{node.op_type} expects matching input/output dtypes, "
            f"got {in_dtype.onnx_name} and {out_dtype.onnx_name}"
        )
    if not in_dtype.is_float:
        raise UnsupportedOpError(
            "LRN supports float16, float, and double inputs only"
        )
    lrn_spec = resolve_lrn_spec(graph, node)
    evaluator.values[node.outputs[0]] = _apply_lrn(
        lrn_spec, evaluator.values[node.inputs[0]]
    )
|
|
1245
|
+
|
|
1246
|
+
|
|
1247
|
+
@register_evaluator("AveragePool")
def _eval_average_pool(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX AveragePool using the lowered pooling op."""
    pool_op = lower_average_pool(evaluator.graph, node)
    pooled = _apply_average_pool(pool_op, evaluator.values[node.inputs[0]])
    evaluator.values[node.outputs[0]] = pooled
|
|
1252
|
+
|
|
1253
|
+
|
|
1254
|
+
@register_evaluator("GlobalAveragePool")
def _eval_global_average_pool(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX GlobalAveragePool through the shared average-pool path."""
    pool_op = lower_global_average_pool(evaluator.graph, node)
    pooled = _apply_average_pool(pool_op, evaluator.values[node.inputs[0]])
    evaluator.values[node.outputs[0]] = pooled
|
|
1259
|
+
|
|
1260
|
+
|
|
1261
|
+
@register_evaluator("MaxPool")
def _eval_maxpool(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX MaxPool, optionally producing the ``Indices`` output.

    The optional second output is treated as absent both when it is not
    listed and when its name is empty (ONNX's marker for an unused optional
    output).

    Raises:
        UnsupportedOpError: if input/output dtypes differ, the indices output
            is not int64, or the input is bool.
    """
    graph = evaluator.graph
    op_dtype = value_dtype(graph, node.inputs[0], node)
    output_dtype = value_dtype(graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            f"{node.op_type} expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    indices_output = (
        node.outputs[1] if len(node.outputs) > 1 and node.outputs[1] else None
    )
    if indices_output is not None:
        indices_dtype = value_dtype(graph, indices_output, node)
        if indices_dtype != ScalarType.I64:
            raise UnsupportedOpError("MaxPool indices output must be int64")
    if op_dtype == ScalarType.BOOL:
        raise UnsupportedOpError("MaxPool supports numeric inputs only")
    spec = resolve_maxpool_spec(graph, node)
    data = evaluator.values[node.inputs[0]]
    if indices_output is None:
        evaluator.values[node.outputs[0]] = _apply_maxpool(spec, data)
    else:
        values, indices = _apply_maxpool(spec, data, return_indices=True)
        evaluator.values[node.outputs[0]] = values
        evaluator.values[indices_output] = indices
|
|
1285
|
+
|
|
1286
|
+
|
|
1287
|
+
@register_evaluator("Softmax")
def _eval_softmax(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Softmax along the axis resolved by the lowering."""
    softmax_axis = lower_softmax(evaluator.graph, node).axis
    logits = evaluator.values[node.inputs[0]]
    evaluator.values[node.outputs[0]] = _apply_softmax(logits, softmax_axis)
|
|
1292
|
+
|
|
1293
|
+
|
|
1294
|
+
@register_evaluator("LogSoftmax")
def _eval_logsoftmax(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX LogSoftmax along the axis resolved by the lowering."""
    softmax_axis = lower_logsoftmax(evaluator.graph, node).axis
    logits = evaluator.values[node.inputs[0]]
    evaluator.values[node.outputs[0]] = _apply_logsoftmax(logits, softmax_axis)
|
|
1299
|
+
|
|
1300
|
+
|
|
1301
|
+
@register_evaluator("NegativeLogLikelihoodLoss")
def _eval_negative_log_likelihood_loss(
    evaluator: Evaluator, node: Node
) -> None:
    """Evaluate ONNX NegativeLogLikelihoodLoss.

    The optional class weights are passed as ``None`` when absent. The
    reduction and ignore_index settings come from the lowered op.
    """
    op = lower_negative_log_likelihood_loss(evaluator.graph, node)
    values = evaluator.values
    log_probs = values[op.input0]
    labels = values[op.target]
    class_weights = None if op.weight is None else values[op.weight]
    values[op.output] = _apply_negative_log_likelihood_loss(
        log_probs,
        labels,
        class_weights,
        reduction=op.reduction,
        ignore_index=op.ignore_index,
    )
|
|
1316
|
+
|
|
1317
|
+
|
|
1318
|
+
@register_evaluator("SoftmaxCrossEntropyLoss")
def _eval_softmax_cross_entropy_loss(
    evaluator: Evaluator, node: Node
) -> None:
    """Evaluate ONNX SoftmaxCrossEntropyLoss.

    The log-probabilities are requested from the helper only when the node
    declares that optional output. They are stored only if the helper
    returned them.
    """
    op = lower_softmax_cross_entropy_loss(evaluator.graph, node)
    values = evaluator.values
    scores = values[op.input0]
    labels = values[op.target]
    class_weights = None if op.weight is None else values[op.weight]
    wants_log_prob = op.log_prob is not None
    loss, log_prob = _apply_softmax_cross_entropy_loss(
        scores,
        labels,
        class_weights,
        reduction=op.reduction,
        ignore_index=op.ignore_index,
        return_log_prob=wants_log_prob,
    )
    values[op.output] = loss
    if wants_log_prob and log_prob is not None:
        values[op.log_prob] = log_prob
|
|
1337
|
+
|
|
1338
|
+
|
|
1339
|
+
@register_evaluator("Dropout")
def _eval_dropout(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Dropout as a pass-through copy of the data input.

    Only the primary output is written here.
    """
    op = lower_dropout(evaluator.graph, node)
    source = evaluator.values[op.input0]
    evaluator.values[op.output] = source.copy()
|
|
1343
|
+
|
|
1344
|
+
|
|
1345
|
+
@register_evaluator("Concat")
def _eval_concat(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Concat: join all inputs, in order, along ``op.axis``."""
    op = lower_concat(evaluator.graph, node)
    pieces = tuple(evaluator.values[name] for name in node.inputs)
    evaluator.values[op.output] = np.concatenate(pieces, axis=op.axis)
|
|
1350
|
+
|
|
1351
|
+
|
|
1352
|
+
@register_evaluator("Transpose")
def _eval_transpose(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Transpose with the permutation resolved by the lowering."""
    op = lower_transpose(evaluator.graph, node)
    source = evaluator.values[op.input0]
    evaluator.values[op.output] = source.transpose(tuple(op.perm))
|
|
1358
|
+
|
|
1359
|
+
|
|
1360
|
+
@register_evaluator("Unsqueeze")
def _eval_unsqueeze(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Unsqueeze as a reshape to the lowered output shape."""
    op = lower_unsqueeze(evaluator.graph, node)
    evaluator.values[op.output] = np.reshape(
        evaluator.values[op.input0], op.output_shape
    )
|
|
1366
|
+
|
|
1367
|
+
|
|
1368
|
+
@register_evaluator("Squeeze")
def _eval_squeeze(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Squeeze as a reshape to the lowered output shape."""
    op = lower_squeeze(evaluator.graph, node)
    evaluator.values[op.output] = np.reshape(
        evaluator.values[op.input0], op.output_shape
    )
|
|
1374
|
+
|
|
1375
|
+
|
|
1376
|
+
@register_evaluator("Reshape")
def _eval_reshape(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Reshape using the output shape resolved by the lowering."""
    op = lower_reshape(evaluator.graph, node)
    evaluator.values[op.output] = np.reshape(
        evaluator.values[op.input0], op.output_shape
    )
|
|
1382
|
+
|
|
1383
|
+
|
|
1384
|
+
@register_evaluator("Flatten")
def _eval_flatten(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ONNX Flatten as a reshape to the lowered 2-D output shape."""
    op = lower_flatten(evaluator.graph, node)
    evaluator.values[op.output] = np.reshape(
        evaluator.values[op.input0], op.output_shape
    )
|
|
1390
|
+
|
|
1391
|
+
|
|
1392
|
+
@register_evaluator("ConstantOfShape")
def _eval_constant_of_shape(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ConstantOfShape: a tensor of the lowered shape filled with one value."""
    lowered = lower_constant_of_shape(evaluator.graph, node)
    fill_dtype = lowered.dtype.np_dtype
    evaluator.values[lowered.output] = np.full(
        lowered.shape, fill_value=lowered.value, dtype=fill_dtype
    )
|
|
1398
|
+
|
|
1399
|
+
|
|
1400
|
+
@register_evaluator("Shape")
def _eval_shape(evaluator: Evaluator, node: Node) -> None:
    """Evaluate Shape by materialising the lowered dimension values as int64."""
    lowered = lower_shape(evaluator.graph, node)
    dims = lowered.values
    evaluator.values[lowered.output] = np.array(dims, dtype=np.int64)
|
|
1404
|
+
|
|
1405
|
+
|
|
1406
|
+
@register_evaluator("Size")
def _eval_size(evaluator: Evaluator, node: Node) -> None:
    """Evaluate Size by storing the lowered element count as an int64 array."""
    lowered = lower_size(evaluator.graph, node)
    element_count = lowered.value
    evaluator.values[lowered.output] = np.array(element_count, dtype=np.int64)
|
|
1410
|
+
|
|
1411
|
+
|
|
1412
|
+
@register_evaluator("Expand")
def _eval_expand(evaluator: Evaluator, node: Node) -> None:
    """Evaluate Expand by broadcasting the input to the lowered output shape.

    ``np.broadcast_to`` yields a read-only view; the copy turns it into a
    writable array that owns its data.
    """
    lowered = lower_expand(evaluator.graph, node)
    source = evaluator.values[lowered.input0]
    broadcast_view = np.broadcast_to(source, lowered.output_shape)
    evaluator.values[lowered.output] = broadcast_view.copy()
|
|
1419
|
+
|
|
1420
|
+
|
|
1421
|
+
@register_evaluator("Range")
def _eval_range(evaluator: Evaluator, node: Node) -> None:
    """Evaluate Range as ``start + i * delta`` for ``i`` in ``[0, length)``.

    ``start`` and ``delta`` are the first elements of their stored tensors;
    ``length`` and the element dtype come from lowering.
    """
    lowered = lower_range(evaluator.graph, node)
    start = evaluator.values[lowered.start].ravel()[0]
    step = evaluator.values[lowered.delta].ravel()[0]
    positions = np.arange(lowered.length, dtype=lowered.dtype.np_dtype)
    evaluator.values[lowered.output] = start + positions * step
|
|
1429
|
+
|
|
1430
|
+
|
|
1431
|
+
@register_evaluator("Split")
def _eval_split(evaluator: Evaluator, node: Node) -> None:
    """Evaluate Split by cutting the input into the lowered piece sizes."""
    lowered = lower_split(evaluator.graph, node)
    source = evaluator.values[lowered.input0]
    # np.split takes interior cut positions: the running totals of the piece
    # sizes, excluding the last total.
    boundaries = np.cumsum(lowered.split_sizes)[:-1]
    pieces = np.split(source, boundaries, axis=lowered.axis)
    for name, piece in zip(lowered.outputs, pieces):
        evaluator.values[name] = piece
|
|
1439
|
+
|
|
1440
|
+
|
|
1441
|
+
def _reduce_log_sum_exp(
    value: np.ndarray, axes: tuple[int, ...], keepdims: bool
) -> np.ndarray:
    """Compute ``log(sum(exp(value)))`` over ``axes`` without overflowing ``exp``.

    The per-slice maximum is subtracted before exponentiating and added back
    after the log. Slices whose maximum is not finite (all ``-inf``, or
    containing ``inf``/``nan``) use a zero shift, so they reproduce the naive
    formula instead of producing ``nan`` from ``-inf - -inf``.
    """
    if value.size == 0:
        # np.max has no identity for empty reductions; the naive form is safe.
        return np.log(np.sum(np.exp(value), axis=axes, keepdims=keepdims))
    peak = np.max(value, axis=axes, keepdims=True)
    shift = np.where(np.isfinite(peak), peak, np.zeros_like(peak))
    total = np.sum(np.exp(value - shift), axis=axes, keepdims=True)
    result = np.log(total) + shift
    if not keepdims:
        result = np.squeeze(result, axis=axes)
    return result


@register_evaluator("ReduceMean")
@register_evaluator("ReduceSum")
@register_evaluator("ReduceMax")
@register_evaluator("ReduceMin")
@register_evaluator("ReduceProd")
@register_evaluator("ReduceL1")
@register_evaluator("ReduceL2")
@register_evaluator("ReduceLogSum")
@register_evaluator("ReduceLogSumExp")
@register_evaluator("ReduceSumSquare")
def _eval_reduce(evaluator: Evaluator, node: Node) -> None:
    """Evaluate a Reduce* node on concrete numpy values.

    Axes come from the optional second input (read from ``evaluator.values``)
    or, without it, from ``resolve_reduce_axes``. With an axes input, an
    empty axes list reduces every dimension unless ``noop_with_empty_axes`` is
    set, in which case the input is copied through unchanged. The result is
    cast back to the input dtype, which must equal the declared output dtype.

    Raises:
        UnsupportedOpError: on a wrong input/output count, mismatched
            input/output dtypes, a float-only reduction on a non-float input,
            a non-int32/int64 axes tensor, axes that are not known at
            evaluation time, or an unknown reduce kind.
    """
    if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
        raise UnsupportedOpError(
            f"{node.op_type} must have 1 or 2 inputs and 1 output"
        )
    op_dtype = value_dtype(evaluator.graph, node.inputs[0], node)
    output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            f"{node.op_type} expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    if (
        node.op_type in REDUCE_OUTPUTS_FLOAT_ONLY
        and not op_dtype.is_float
    ):
        raise UnsupportedOpError(
            f"{node.op_type} supports float16, float, and double inputs only"
        )
    value = evaluator.values[node.inputs[0]]
    input_shape = value.shape
    if len(node.inputs) > 1 and node.inputs[1]:
        # Runtime axes input (an empty name means the optional input is absent).
        axes_value = evaluator.values[node.inputs[1]]
        if axes_value.dtype.type not in {np.int32, np.int64}:
            raise UnsupportedOpError(
                f"{node.op_type} axes input must be int64 or int32"
            )
        axes = tuple(int(axis) for axis in axes_value.ravel())
        noop_with_empty_axes = bool(int(node.attrs.get("noop_with_empty_axes", 0)))
        if not axes:
            if noop_with_empty_axes:
                evaluator.values[node.outputs[0]] = value.copy()
                return
            axes = tuple(range(len(input_shape)))
        axes = normalize_reduce_axes(axes, input_shape, node)
    else:
        axes_spec, noop = resolve_reduce_axes(evaluator.graph, node, input_shape)
        if noop:
            evaluator.values[node.outputs[0]] = value.copy()
            return
        if axes_spec is None or axes_spec.axes is None:
            raise UnsupportedOpError(
                f"{node.op_type} axes input must be constant for evaluator"
            )
        axes = axes_spec.axes
    keepdims = bool(int(node.attrs.get("keepdims", 1)))
    reduce_kind = REDUCE_KIND_BY_OP[node.op_type]
    if reduce_kind == "sum":
        result = np.sum(value, axis=axes, keepdims=keepdims)
    elif reduce_kind == "mean":
        result = np.mean(value, axis=axes, keepdims=keepdims)
    elif reduce_kind == "max":
        result = np.max(value, axis=axes, keepdims=keepdims)
    elif reduce_kind == "min":
        result = np.min(value, axis=axes, keepdims=keepdims)
    elif reduce_kind == "prod":
        result = np.prod(value, axis=axes, keepdims=keepdims)
    elif reduce_kind == "l1":
        result = np.sum(np.abs(value), axis=axes, keepdims=keepdims)
    elif reduce_kind == "l2":
        result = np.sqrt(np.sum(value * value, axis=axes, keepdims=keepdims))
    elif reduce_kind == "logsum":
        result = np.log(np.sum(value, axis=axes, keepdims=keepdims))
    elif reduce_kind == "logsumexp":
        result = _reduce_log_sum_exp(value, axes, keepdims)
    elif reduce_kind == "sumsquare":
        result = np.sum(value * value, axis=axes, keepdims=keepdims)
    else:
        raise UnsupportedOpError(f"Unsupported reduce kind {reduce_kind}")
    # numpy may widen the result (np.mean of integers gives float64; np.sum
    # and np.prod of small integers give the platform integer). It also
    # returns a scalar for a full reduction without keepdims. The output
    # dtype equals the input dtype (checked above), so convert back.
    evaluator.values[node.outputs[0]] = np.asarray(result).astype(
        value.dtype, copy=False
    )
|
|
1521
|
+
|
|
1522
|
+
|
|
1523
|
+
@register_evaluator("ArgMax")
@register_evaluator("ArgMin")
def _eval_arg_reduce(evaluator: Evaluator, node: Node) -> None:
    """Evaluate ArgMax/ArgMin along the lowered axis.

    With ``select_last_index`` the axis is searched in reverse, so ties
    resolve to the last occurrence. The found position is then mapped back
    to the original orientation. ``keepdims`` reinserts the reduced axis with
    length 1, and the indices are cast to the lowered output dtype.
    """
    lowered = lower_arg_reduce(evaluator.graph, node)
    data = evaluator.values[lowered.input0]
    axis = lowered.axis
    searched = np.flip(data, axis=axis) if lowered.select_last_index else data
    if lowered.reduce_kind == "max":
        positions = np.argmax(searched, axis=axis)
    elif lowered.reduce_kind == "min":
        positions = np.argmin(searched, axis=axis)
    else:
        raise UnsupportedOpError(
            f"Unsupported arg reduce kind {lowered.reduce_kind}"
        )
    if lowered.select_last_index:
        positions = data.shape[axis] - 1 - positions
    if lowered.keepdims:
        positions = np.expand_dims(positions, axis=axis)
    evaluator.values[lowered.output] = positions.astype(
        lowered.output_dtype.np_dtype
    )
|
|
1551
|
+
|
|
1552
|
+
|
|
1553
|
+
def _eval_binary_unary(evaluator: Evaluator, node: Node) -> None:
    """Evaluate an elementwise ONNX node through the shared scalar-op tables.

    Paths:

    * ``BitShift`` is handled first. Its ``direction`` attribute (str or
      bytes, default ``"LEFT"``) selects a left or right shift, and the dtype
      from ``node_dtype`` must be an integer type.
    * ``Mod`` maps ``fmod=1`` to ``ScalarFunction.FMOD`` and ``fmod=0`` to
      ``ScalarFunction.REMAINDER``. Every other op type is resolved with
      ``ScalarFunction.from_onnx_op``.
    * Functions in ``COMPARE_FUNCTIONS`` take their dtype from the inputs
      only and require a bool output.
    * Everything else uses the dtype from ``node_dtype`` over inputs and
      outputs. It runs as a binary op when ``binary_op_symbol`` returns a
      spec, otherwise as a unary op.

    Results are written to ``evaluator.values[node.outputs[0]]``.

    Raises:
        UnsupportedOpError: for a wrong input/output count, an invalid
            ``direction`` or ``fmod``, non-integer BitShift operands, a
            non-bool comparison output, an op type unknown to
            ``ScalarFunction``, or a function with neither a binary nor a
            unary implementation for the dtype.
    """
    if node.op_type == "BitShift":
        if len(node.inputs) != 2 or len(node.outputs) != 1:
            raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
        direction_attr = node.attrs.get("direction", "LEFT")
        # ONNX string attributes may arrive as raw bytes.
        if isinstance(direction_attr, bytes):
            direction = direction_attr.decode()
        else:
            direction = str(direction_attr)
        if direction not in {"LEFT", "RIGHT"}:
            raise UnsupportedOpError(
                "BitShift direction must be LEFT or RIGHT"
            )
        op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
        if not op_dtype.is_integer:
            raise UnsupportedOpError("BitShift expects integer inputs")
        function = (
            ScalarFunction.BITWISE_LEFT_SHIFT
            if direction == "LEFT"
            else ScalarFunction.BITWISE_RIGHT_SHIFT
        )
        op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
        if op_spec is None:
            raise UnsupportedOpError("Unsupported op BitShift")
        left = evaluator.values[node.inputs[0]]
        right = evaluator.values[node.inputs[1]]
        evaluator.values[node.outputs[0]] = apply_binary_op(
            op_spec, left, right
        )
        return
    if node.op_type == "Mod":
        # fmod=1 selects C-style fmod; fmod=0 selects the REMAINDER function.
        fmod = int(node.attrs.get("fmod", 0))
        if fmod not in {0, 1}:
            raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
        function = (
            ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
        )
    else:
        try:
            function = ScalarFunction.from_onnx_op(node.op_type)
        except ScalarFunctionError as exc:
            raise UnsupportedOpError(
                f"Unsupported op {node.op_type}"
            ) from exc
    validate_unary_attrs(node.op_type, node.attrs)
    if function in COMPARE_FUNCTIONS:
        # Comparisons: the operand dtype comes from the inputs only, because
        # the output is bool by definition.
        input_dtype = node_dtype(evaluator.graph, node, *node.inputs)
        output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
        if output_dtype != ScalarType.BOOL:
            raise UnsupportedOpError(
                f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
            )
        op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
        if op_spec is None:
            raise UnsupportedOpError(f"Unsupported op {node.op_type}")
        if len(node.inputs) != 2 or len(node.outputs) != 1:
            raise UnsupportedOpError(
                f"{node.op_type} must have 2 inputs and 1 output"
            )
        left = evaluator.values[node.inputs[0]]
        right = evaluator.values[node.inputs[1]]
        evaluator.values[node.outputs[0]] = apply_binary_op(
            op_spec, left, right
        )
        return
    op_dtype = node_dtype(evaluator.graph, node, *node.inputs, *node.outputs)
    # Both lookups are attempted; a binary spec takes precedence when present.
    op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
    unary_symbol = unary_op_symbol(function, dtype=op_dtype)
    if op_spec is None and unary_symbol is None:
        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
    if op_spec is not None:
        if len(node.inputs) != 2 or len(node.outputs) != 1:
            raise UnsupportedOpError(
                f"{node.op_type} must have 2 inputs and 1 output"
            )
        left = evaluator.values[node.inputs[0]]
        right = evaluator.values[node.inputs[1]]
        evaluator.values[node.outputs[0]] = apply_binary_op(
            op_spec, left, right
        )
        return
    # Unary path: unary_symbol is only checked for presence; the
    # computation is delegated to apply_unary_op with the function itself.
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError(
            f"{node.op_type} must have 1 input and 1 output"
        )
    value = evaluator.values[node.inputs[0]]
    evaluator.values[node.outputs[0]] = apply_unary_op(
        function, value, dtype=op_dtype
    )
|
|
1642
|
+
|
|
1643
|
+
|
|
1644
|
+
def _apply_matmul(left: np.ndarray, right: np.ndarray) -> np.ndarray:
    """Validate MatMul operand shapes, then return ``np.matmul(left, right)``.

    The contracted dimension is the last axis of ``left`` and, for ``right``,
    axis 0 when it is 1-D or axis -2 otherwise. Leading (batch) dimensions
    beyond the last two must broadcast.

    Raises:
        UnsupportedOpError: if either operand is 0-D.
        ShapeInferenceError: if the inner or batch dimensions are incompatible.
    """
    if min(left.ndim, right.ndim) < 1:
        raise UnsupportedOpError(
            "MatMul inputs must be at least 1D, "
            f"got {left.shape} x {right.shape}"
        )
    left_dim = left.shape[-1]
    if right.ndim == 1:
        right_dim = right.shape[0]
    else:
        right_dim = right.shape[-2]
    if left_dim != right_dim:
        raise ShapeInferenceError(
            "MatMul inner dimensions must match, "
            f"got {left_dim} and {right_dim}"
        )
    # 1-D operands carry no batch dimensions.
    left_batch = () if left.ndim <= 1 else left.shape[:-2]
    right_batch = () if right.ndim <= 1 else right.shape[:-2]
    if _matmul_batch_broadcastable(left_batch, right_batch):
        return np.matmul(left, right)
    raise ShapeInferenceError(
        "MatMul batch dimensions must be broadcastable, "
        f"got {left_batch} x {right_batch}"
    )
|
|
1665
|
+
|
|
1666
|
+
|
|
1667
|
+
def _matmul_batch_broadcastable(
|
|
1668
|
+
left: tuple[int, ...], right: tuple[int, ...]
|
|
1669
|
+
) -> bool:
|
|
1670
|
+
max_rank = max(len(left), len(right))
|
|
1671
|
+
left_padded = (1,) * (max_rank - len(left)) + left
|
|
1672
|
+
right_padded = (1,) * (max_rank - len(right)) + right
|
|
1673
|
+
for left_dim, right_dim in zip(left_padded, right_padded):
|
|
1674
|
+
if left_dim == right_dim or left_dim == 1 or right_dim == 1:
|
|
1675
|
+
continue
|
|
1676
|
+
return False
|
|
1677
|
+
return True
|
|
1678
|
+
|
|
1679
|
+
|
|
1680
|
+
def _apply_softmax(values: np.ndarray, axis: int) -> np.ndarray:
|
|
1681
|
+
max_values = np.max(values, axis=axis, keepdims=True)
|
|
1682
|
+
exp_values = np.exp(values - max_values)
|
|
1683
|
+
sum_values = np.sum(exp_values, axis=axis, keepdims=True)
|
|
1684
|
+
return exp_values / sum_values
|
|
1685
|
+
|
|
1686
|
+
|
|
1687
|
+
def _apply_logsoftmax(values: np.ndarray, axis: int) -> np.ndarray:
|
|
1688
|
+
max_values = np.max(values, axis=axis, keepdims=True)
|
|
1689
|
+
shifted = values - max_values
|
|
1690
|
+
logsum = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
|
|
1691
|
+
return shifted - logsum
|
|
1692
|
+
|
|
1693
|
+
|
|
1694
|
+
def _apply_negative_log_likelihood_loss(
|
|
1695
|
+
values: np.ndarray,
|
|
1696
|
+
target: np.ndarray,
|
|
1697
|
+
weight: np.ndarray | None,
|
|
1698
|
+
*,
|
|
1699
|
+
reduction: str,
|
|
1700
|
+
ignore_index: int,
|
|
1701
|
+
) -> np.ndarray:
|
|
1702
|
+
input_shape = values.shape
|
|
1703
|
+
if len(input_shape) < 2:
|
|
1704
|
+
raise UnsupportedOpError(
|
|
1705
|
+
"NegativeLogLikelihoodLoss input must be at least 2D"
|
|
1706
|
+
)
|
|
1707
|
+
target_shape = target.shape
|
|
1708
|
+
if input_shape[0] != target_shape[0]:
|
|
1709
|
+
raise ShapeInferenceError(
|
|
1710
|
+
"NegativeLogLikelihoodLoss target batch dimension must match input"
|
|
1711
|
+
)
|
|
1712
|
+
if input_shape[2:] != target_shape[1:]:
|
|
1713
|
+
raise ShapeInferenceError(
|
|
1714
|
+
"NegativeLogLikelihoodLoss target spatial dimensions must match input"
|
|
1715
|
+
)
|
|
1716
|
+
n = input_shape[0]
|
|
1717
|
+
c = input_shape[1]
|
|
1718
|
+
if weight is not None:
|
|
1719
|
+
gather_weight = np.take(weight, target.astype(np.int32), mode="clip")
|
|
1720
|
+
if ignore_index is not None:
|
|
1721
|
+
gather_weight = np.where(target == ignore_index, 0, gather_weight).astype(
|
|
1722
|
+
dtype=values.dtype
|
|
1723
|
+
)
|
|
1724
|
+
elif ignore_index != -1:
|
|
1725
|
+
gather_weight = np.where(target == ignore_index, 0, 1).astype(
|
|
1726
|
+
dtype=values.dtype
|
|
1727
|
+
)
|
|
1728
|
+
else:
|
|
1729
|
+
gather_weight = None
|
|
1730
|
+
if len(input_shape) != 3:
|
|
1731
|
+
values = values.reshape((n, c, -1))
|
|
1732
|
+
target = target.reshape((n, -1))
|
|
1733
|
+
d = values.shape[2]
|
|
1734
|
+
loss = np.zeros((n, d), dtype=values.dtype)
|
|
1735
|
+
for i in range(n):
|
|
1736
|
+
for d_index in range(d):
|
|
1737
|
+
if target[i][d_index] != ignore_index:
|
|
1738
|
+
loss[i][d_index] = -values[i][target[i][d_index]][d_index]
|
|
1739
|
+
if len(input_shape) != 3:
|
|
1740
|
+
loss = loss.reshape(target_shape)
|
|
1741
|
+
if gather_weight is not None:
|
|
1742
|
+
loss = gather_weight * loss
|
|
1743
|
+
if reduction == "mean":
|
|
1744
|
+
weight_sum = gather_weight.sum()
|
|
1745
|
+
if weight_sum == 0:
|
|
1746
|
+
return np.array(0, dtype=values.dtype)
|
|
1747
|
+
loss = loss.sum() / weight_sum
|
|
1748
|
+
return loss.astype(values.dtype)
|
|
1749
|
+
if reduction == "mean":
|
|
1750
|
+
loss = np.mean(loss)
|
|
1751
|
+
elif reduction == "sum":
|
|
1752
|
+
loss = np.sum(loss)
|
|
1753
|
+
return loss.astype(values.dtype)
|
|
1754
|
+
|
|
1755
|
+
|
|
1756
|
+
def _apply_softmax_cross_entropy_loss(
    values: np.ndarray,
    target: np.ndarray,
    weight: np.ndarray | None,
    *,
    reduction: str,
    ignore_index: int | None,
    return_log_prob: bool,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Reference SoftmaxCrossEntropyLoss.

    Computes ``log_prob = log_softmax(values, axis=1)`` and
    ``loss = -log_prob[n, target, ...]``, applying optional class weights
    and ``ignore_index`` masking.

    Args:
        values: Input of rank >= 2; axis 0 is the batch and axis 1 the class
            axis, and any further axes are spatial.
        target: Class indices whose shape is ``values.shape`` without axis 1.
        weight: Optional per-class weights gathered by target class.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced per-element loss.
        ignore_index: Target value whose positions contribute zero loss and
            are excluded from a weighted/masked ``"mean"`` denominator.
        return_log_prob: Whether to also return ``log_prob``.

    Returns:
        ``(loss, log_prob)``, both cast to ``values.dtype``. ``log_prob`` has
        the same shape as ``values`` and is ``None`` unless
        ``return_log_prob`` is set.

    Raises:
        UnsupportedOpError: if ``values`` has rank < 2.
        ShapeInferenceError: if ``target`` does not match the batch/spatial
            dimensions of ``values``.
    """
    input_shape = values.shape
    if len(input_shape) < 2:
        raise UnsupportedOpError(
            "SoftmaxCrossEntropyLoss input must be at least 2D"
        )
    target_shape = target.shape
    if input_shape[0] != target_shape[0]:
        raise ShapeInferenceError(
            "SoftmaxCrossEntropyLoss target batch dimension must match input"
        )
    if input_shape[2:] != target_shape[1:]:
        raise ShapeInferenceError(
            "SoftmaxCrossEntropyLoss target spatial dimensions must match input"
        )
    log_prob = _apply_logsoftmax(values, axis=1)
    # Keep the original-shaped log_prob; `log_prob` itself is reshaped below.
    log_prob_output = log_prob if return_log_prob else None
    if weight is not None:
        # mode="clip" keeps out-of-range targets from raising; ignored
        # positions get weight 0 just below.
        gather_weight = np.take(weight, target.astype(np.int32), mode="clip")
        if ignore_index is not None:
            gather_weight = np.where(target == ignore_index, 0, gather_weight).astype(
                dtype=values.dtype
            )
    elif ignore_index is not None:
        gather_weight = np.where(target == ignore_index, 0, 1).astype(
            dtype=values.dtype
        )
    else:
        gather_weight = None
    n = input_shape[0]
    c = input_shape[1]
    # Collapse all spatial axes into one so the gather loop is 2-D.
    if len(input_shape) != 3:
        log_prob = log_prob.reshape((n, c, -1))
        target = target.reshape((n, -1))
    d = log_prob.shape[2]
    loss = np.zeros((n, d), dtype=values.dtype)
    for i in range(n):
        for d_index in range(d):
            if ignore_index is None or target[i][d_index] != ignore_index:
                loss[i][d_index] = -log_prob[i][target[i][d_index]][d_index]
    if len(input_shape) != 3:
        loss = loss.reshape(target_shape)
    if gather_weight is not None:
        loss = gather_weight * loss
        if reduction == "mean":
            loss = loss.sum() / gather_weight.sum()
            loss = loss.astype(values.dtype)
            if return_log_prob and log_prob_output is not None:
                # Return the unreshaped log_prob so its shape matches `values`.
                return loss, log_prob_output.astype(values.dtype)
            return loss, None
    if reduction == "mean":
        loss = np.mean(loss)
    elif reduction == "sum":
        loss = np.sum(loss)
    loss = loss.astype(values.dtype)
    if return_log_prob and log_prob_output is not None:
        return loss, log_prob_output.astype(values.dtype)
    return loss, None
|
|
1822
|
+
|
|
1823
|
+
|
|
1824
|
+
def _apply_attention(
    spec,
    query: np.ndarray,
    key: np.ndarray,
    value: np.ndarray,
    attn_mask: np.ndarray | None,
    past_key: np.ndarray | None,
    past_value: np.ndarray | None,
    nonpad_kv_seqlen: np.ndarray | None,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Reference scaled-dot-product attention driven by a lowered ``spec``.

    Pipeline, in order:

    1. For ``spec.q_rank == 3`` the query, key and value are reshaped to
       ``(batch, seq, heads, head_size)`` and transposed to
       ``(batch, heads, seq, head_size)``. Otherwise the inputs are used
       as-is (presumably already in that 4-D layout; TODO confirm).
    2. ``past_key``/``past_value`` (when both are given) are prepended along
       axis 2.
    3. KV heads are repeated ``spec.head_group_size`` times along axis 1.
    4. ``scores = (Q @ K^T) * spec.scale``.
    5. An additive bias is built from the attention mask, the
       ``nonpad_kv_seqlen`` padding mask and the causal mask.
    6. Optional softcap ``softcap * tanh(x / softcap)``.
    7. Max-shifted softmax over the last axis, then ``weights @ V``.

    Returns:
        ``(output, key_total, value_total, qk_output)``.

        * ``key_total``/``value_total`` are the concatenated (past + current)
          tensors before head repetition.
        * ``qk_output`` is selected by ``spec.qk_matmul_output_mode``: 0 gives
          the scaled scores, 1 the scores plus bias, 2 the softcapped scores,
          and any other value the softmax weights.

    NOTE(review): a query row whose entries are all ``-inf`` after masking
    gets ``max_scores = -inf``. ``scores_softcap - max_scores`` is then
    ``nan``, so that row's weights and output are ``nan``.
    """
    if spec.q_rank == 3:
        query_4d = query.reshape(
            spec.batch, spec.q_seq, spec.q_heads, spec.qk_head_size
        ).transpose(0, 2, 1, 3)
        key_4d = key.reshape(
            spec.batch, spec.kv_seq, spec.kv_heads, spec.qk_head_size
        ).transpose(0, 2, 1, 3)
        value_4d = value.reshape(
            spec.batch, spec.kv_seq, spec.kv_heads, spec.v_head_size
        ).transpose(0, 2, 1, 3)
    else:
        query_4d = query
        key_4d = key
        value_4d = value
    # Cached keys/values come first along the sequence axis (axis 2).
    if past_key is not None and past_value is not None:
        key_total = np.concatenate([past_key, key_4d], axis=2)
        value_total = np.concatenate([past_value, value_4d], axis=2)
    else:
        key_total = key_4d
        value_total = value_4d
    # Each KV head is repeated so the head axis matches the query heads
    # (presumably grouped-query attention; kv_heads * group == q_heads).
    if spec.head_group_size > 1:
        key_total_expanded = np.repeat(key_total, spec.head_group_size, axis=1)
        value_total_expanded = np.repeat(
            value_total, spec.head_group_size, axis=1
        )
    else:
        key_total_expanded = key_total
        value_total_expanded = value_total
    k_transpose = np.transpose(key_total_expanded, (0, 1, 3, 2))
    scores = np.matmul(query_4d, k_transpose) * spec.scale
    bias = np.zeros_like(scores)
    if spec.has_attn_mask and attn_mask is not None:
        # Bool masks mark allowed positions (True -> 0, False -> -inf);
        # non-bool masks are used directly as additive bias.
        if spec.mask_is_bool:
            bias_mask = np.where(attn_mask, 0.0, -np.inf)
        else:
            bias_mask = attn_mask.astype(scores.dtype)
        # Lift rank-2/3 masks to 4-D before broadcasting.
        if spec.mask_rank == 2:
            bias_mask = bias_mask[None, None, ...]
        elif spec.mask_rank == 3:
            bias_mask = bias_mask[:, None, ...]
        bias_mask = np.broadcast_to(
            bias_mask, (spec.batch, spec.q_heads, spec.q_seq, spec.mask_kv_seq)
        )
        # A mask shorter than the total KV length blocks the trailing columns.
        if spec.mask_kv_seq < spec.total_seq:
            pad_width = spec.total_seq - spec.mask_kv_seq
            bias_mask = np.pad(
                bias_mask,
                ((0, 0), (0, 0), (0, 0), (0, pad_width)),
                constant_values=-np.inf,
            )
        bias = bias + bias_mask
    if spec.has_nonpad and nonpad_kv_seqlen is not None:
        # Per batch item, KV positions at or beyond nonpad_kv_seqlen are masked.
        kv_range = np.arange(spec.total_seq)[None, None, None, :]
        valid = kv_range < nonpad_kv_seqlen[:, None, None, None]
        bias = bias + np.where(valid, 0.0, -np.inf)
    if spec.is_causal:
        # Query rows are offset by past_seq so cached positions stay visible.
        kv_range = np.arange(spec.total_seq)[None, :]
        q_range = np.arange(spec.q_seq)[:, None] + spec.past_seq
        causal_mask = kv_range > q_range
        bias = bias + np.where(causal_mask, -np.inf, 0.0)[None, None, :, :]
    scores_with_bias = scores + bias
    if spec.softcap != 0.0:
        scores_softcap = spec.softcap * np.tanh(scores_with_bias / spec.softcap)
    else:
        scores_softcap = scores_with_bias
    # Max-shifted softmax over the KV axis.
    max_scores = np.max(scores_softcap, axis=-1, keepdims=True)
    weights = np.exp(scores_softcap - max_scores)
    weights /= np.sum(weights, axis=-1, keepdims=True)
    output = np.matmul(weights, value_total_expanded)
    # Rank-3 callers get (batch, q_seq, q_heads * v_head_size) back.
    if spec.q_rank == 3:
        output = output.transpose(0, 2, 1, 3).reshape(
            spec.batch, spec.q_seq, spec.q_heads * spec.v_head_size
        )
    qk_output = None
    if spec.qk_matmul_output_mode == 0:
        qk_output = scores
    elif spec.qk_matmul_output_mode == 1:
        qk_output = scores_with_bias
    elif spec.qk_matmul_output_mode == 2:
        qk_output = scores_softcap
    else:
        qk_output = weights
    return output, key_total, value_total, qk_output
|
|
1917
|
+
|
|
1918
|
+
|
|
1919
|
+
def _apply_conv(spec, data: np.ndarray, weights: np.ndarray, bias: np.ndarray | None) -> np.ndarray:
|
|
1920
|
+
output = np.zeros(
|
|
1921
|
+
(spec.batch, spec.out_channels, *spec.out_spatial),
|
|
1922
|
+
dtype=data.dtype,
|
|
1923
|
+
)
|
|
1924
|
+
pad_begin = spec.pads[: spec.spatial_rank]
|
|
1925
|
+
group_in_channels = spec.in_channels // spec.group
|
|
1926
|
+
group_out_channels = spec.out_channels // spec.group
|
|
1927
|
+
for n in range(spec.batch):
|
|
1928
|
+
for g in range(spec.group):
|
|
1929
|
+
oc_base = g * group_out_channels
|
|
1930
|
+
ic_base = g * group_in_channels
|
|
1931
|
+
for oc in range(group_out_channels):
|
|
1932
|
+
oc_global = oc_base + oc
|
|
1933
|
+
base = bias[oc_global] if bias is not None else 0.0
|
|
1934
|
+
for out_index in np.ndindex(*spec.out_spatial):
|
|
1935
|
+
acc = base
|
|
1936
|
+
for ic in range(group_in_channels):
|
|
1937
|
+
ic_global = ic_base + ic
|
|
1938
|
+
for kernel_index in np.ndindex(*spec.kernel_shape):
|
|
1939
|
+
in_index = []
|
|
1940
|
+
valid = True
|
|
1941
|
+
for (
|
|
1942
|
+
out_dim,
|
|
1943
|
+
kernel_dim,
|
|
1944
|
+
stride,
|
|
1945
|
+
dilation,
|
|
1946
|
+
pad,
|
|
1947
|
+
in_size,
|
|
1948
|
+
) in zip(
|
|
1949
|
+
out_index,
|
|
1950
|
+
kernel_index,
|
|
1951
|
+
spec.strides,
|
|
1952
|
+
spec.dilations,
|
|
1953
|
+
pad_begin,
|
|
1954
|
+
spec.in_spatial,
|
|
1955
|
+
):
|
|
1956
|
+
in_dim = out_dim * stride + kernel_dim * dilation - pad
|
|
1957
|
+
if in_dim < 0 or in_dim >= in_size:
|
|
1958
|
+
valid = False
|
|
1959
|
+
break
|
|
1960
|
+
in_index.append(in_dim)
|
|
1961
|
+
if not valid:
|
|
1962
|
+
continue
|
|
1963
|
+
acc += data[(n, ic_global, *in_index)] * weights[
|
|
1964
|
+
(oc_global, ic, *kernel_index)
|
|
1965
|
+
]
|
|
1966
|
+
output[(n, oc_global, *out_index)] = acc
|
|
1967
|
+
return output
|
|
1968
|
+
|
|
1969
|
+
|
|
1970
|
+
def _apply_lrn(spec, data: np.ndarray) -> np.ndarray:
|
|
1971
|
+
output = np.empty_like(data)
|
|
1972
|
+
spatial_shape = spec.shape[2:]
|
|
1973
|
+
spatial_indices = [()]
|
|
1974
|
+
if spatial_shape:
|
|
1975
|
+
spatial_indices = list(np.ndindex(*spatial_shape))
|
|
1976
|
+
for n in range(spec.shape[0]):
|
|
1977
|
+
for c in range(spec.channels):
|
|
1978
|
+
start = max(0, c - spec.half)
|
|
1979
|
+
end = min(spec.channels - 1, c + spec.half)
|
|
1980
|
+
for index in spatial_indices:
|
|
1981
|
+
sum_val = 0.0
|
|
1982
|
+
for i in range(start, end + 1):
|
|
1983
|
+
value = data[(n, i, *index)]
|
|
1984
|
+
sum_val += value * value
|
|
1985
|
+
scale = spec.bias + (spec.alpha / spec.size) * sum_val
|
|
1986
|
+
output[(n, c, *index)] = data[(n, c, *index)] / math.pow(
|
|
1987
|
+
scale, spec.beta
|
|
1988
|
+
)
|
|
1989
|
+
return output
|
|
1990
|
+
|
|
1991
|
+
|
|
1992
|
+
def _apply_average_pool(op, data: np.ndarray) -> np.ndarray:
|
|
1993
|
+
output = np.zeros((op.batch, op.channels, op.out_h, op.out_w), dtype=data.dtype)
|
|
1994
|
+
for n in range(op.batch):
|
|
1995
|
+
for c in range(op.channels):
|
|
1996
|
+
for oh in range(op.out_h):
|
|
1997
|
+
for ow in range(op.out_w):
|
|
1998
|
+
acc = 0.0
|
|
1999
|
+
count = 0
|
|
2000
|
+
for kh in range(op.kernel_h):
|
|
2001
|
+
ih = oh * op.stride_h + kh - op.pad_top
|
|
2002
|
+
if ih < 0 or ih >= op.in_h:
|
|
2003
|
+
if op.count_include_pad:
|
|
2004
|
+
count += op.kernel_w
|
|
2005
|
+
continue
|
|
2006
|
+
for kw in range(op.kernel_w):
|
|
2007
|
+
iw = ow * op.stride_w + kw - op.pad_left
|
|
2008
|
+
if iw < 0 or iw >= op.in_w:
|
|
2009
|
+
if op.count_include_pad:
|
|
2010
|
+
count += 1
|
|
2011
|
+
continue
|
|
2012
|
+
acc += data[n, c, ih, iw]
|
|
2013
|
+
count += 1
|
|
2014
|
+
output[n, c, oh, ow] = 0.0 if count == 0 else acc / float(count)
|
|
2015
|
+
return output
|
|
2016
|
+
|
|
2017
|
+
|
|
2018
|
+
def _maxpool_min_value(dtype: np.dtype) -> float | int:
|
|
2019
|
+
if np.issubdtype(dtype, np.floating):
|
|
2020
|
+
return -np.inf
|
|
2021
|
+
if np.issubdtype(dtype, np.integer):
|
|
2022
|
+
return np.iinfo(dtype).min
|
|
2023
|
+
raise UnsupportedOpError("MaxPool supports numeric inputs only")
|
|
2024
|
+
|
|
2025
|
+
|
|
2026
|
+
def _apply_maxpool(
    spec, data: np.ndarray, *, return_indices: bool = False
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Reference N-d max pooling with optional argmax indices.

    Each window takes the maximum over in-bounds taps at
    ``out_index * stride + kernel_index * dilation - pad_begin``. A window
    with no in-bounds tap keeps the dtype's minimum value.

    With ``return_indices``, the flattened input offset of each maximum is
    also returned. The offset starts from ``n * channels + c``; spatial
    coordinates are then folded in row-major order when
    ``spec.storage_order == 0``, otherwise with the first spatial axis
    varying fastest.
    """
    floor_value = _maxpool_min_value(data.dtype)
    out_shape = (spec.batch, spec.channels, *spec.out_spatial)
    output = np.full(out_shape, floor_value, dtype=data.dtype)
    indices = np.zeros(out_shape, dtype=np.int64) if return_indices else None
    leading_pads = spec.pads[: spec.spatial_rank]

    def window_position(out_index, kernel_index):
        """Input coordinates of one kernel tap, or None if it falls in padding."""
        position = []
        for out_dim, kernel_dim, stride, dilation, pad in zip(
            out_index, kernel_index, spec.strides, spec.dilations, leading_pads
        ):
            coord = out_dim * stride + kernel_dim * dilation - pad
            if coord < 0 or coord >= spec.in_spatial[len(position)]:
                return None
            position.append(coord)
        return position

    def flat_offset(n, c, position):
        """Flattened input offset of ``position`` within image ``n``, channel ``c``."""
        offset = n * spec.channels + c
        if spec.storage_order == 0:
            for coord, size in zip(position, spec.in_spatial):
                offset = offset * size + coord
            return offset
        spatial_offset = 0
        span = 1
        for coord, size in zip(position, spec.in_spatial):
            spatial_offset += coord * span
            span *= size
        return offset * span + spatial_offset

    for n in range(spec.batch):
        for c in range(spec.channels):
            for out_index in np.ndindex(*spec.out_spatial):
                best = floor_value
                best_offset = 0
                seen = False
                for kernel_index in np.ndindex(*spec.kernel_shape):
                    position = window_position(out_index, kernel_index)
                    if position is None:
                        continue
                    candidate = data[(n, c, *position)]
                    # The first in-bounds tap always wins, then strictly greater ones.
                    if candidate > best or not seen:
                        best = candidate
                        seen = True
                        if return_indices:
                            best_offset = flat_offset(n, c, position)
                output[(n, c, *out_index)] = best
                if indices is not None:
                    indices[(n, c, *out_index)] = best_offset
    if not return_indices:
        return output
    if indices is None:
        raise RuntimeError("MaxPool indices were not computed")
    return output, indices
|
|
2089
|
+
|
|
2090
|
+
|
|
2091
|
+
def _apply_lstm(
    spec,
    x: np.ndarray,
    w: np.ndarray,
    r: np.ndarray,
    b: np.ndarray | None,
    sequence_lens: np.ndarray | None,
    initial_h: np.ndarray | None,
    initial_c: np.ndarray | None,
    p: np.ndarray | None,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Evaluate an ONNX ``LSTM`` node with NumPy.

    Gate rows of ``w``/``r`` are split in the order ``i, o, f, c``; ``b`` is
    the concatenation of the input and recurrence biases (summed here); ``p``
    holds the peephole weights in the order ``i, o, f``.  Activations are
    taken from ``spec.activation_kinds`` / ``activation_alphas`` /
    ``activation_betas`` in groups of three (f, g, h) per direction.

    Args:
        spec: Lowered LSTM description.  Fields read here: ``layout``,
            ``seq_length``, ``batch_size``, ``hidden_size``,
            ``num_directions``, ``direction``, ``clip``, ``input_forget``,
            the ``activation_*`` sequences, and ``output_y`` /
            ``output_y_h`` / ``output_y_c`` (only whether they are ``None``).
        x: Input sequence, time-major (``[seq, batch, input]``) unless
            ``spec.layout == 1``, in which case it is batch-major.
        w: Input weights, indexed per direction.
        r: Recurrence weights, indexed per direction.
        b: Biases (``8 * hidden_size`` per direction); zeros if ``None``.
        sequence_lens: Per-batch valid lengths; all ``seq_length`` if ``None``.
        initial_h: Initial hidden state; zeros if ``None``.  Batch-major
            (swapped here) when ``spec.layout == 1``.
        initial_c: Initial cell state; same conventions as ``initial_h``.
        p: Peephole weights (``3 * hidden_size`` per direction); zeros if
            ``None``.

    Returns:
        ``(Y, Y_h, Y_c)``.  Each entry is ``None`` when the matching
        ``spec.output_*`` is ``None``.  ``Y`` holds the hidden state written
        at the time step that produced it (also for the reverse direction);
        positions at or beyond a batch entry's ``sequence_lens`` stay zero.
        ``Y_h``/``Y_c`` hold the state after each direction's last valid
        step.  With ``spec.layout == 1`` the outputs are returned batch-major.
    """
    if spec.layout == 1:
        # Move user-provided batch-major tensors to the time-major layout used
        # below.  Defaults are created afterwards, already time-major, so they
        # must not be swapped here.
        x = np.swapaxes(x, 0, 1)
        if initial_h is not None:
            initial_h = np.swapaxes(initial_h, 0, 1)
        if initial_c is not None:
            initial_c = np.swapaxes(initial_c, 0, 1)
    seq_length = spec.seq_length
    batch_size = spec.batch_size
    hidden_size = spec.hidden_size
    num_directions = spec.num_directions
    dtype = x.dtype
    state_shape = (num_directions, batch_size, hidden_size)

    if sequence_lens is None:
        sequence_lens = np.full((batch_size,), seq_length, dtype=np.int64)
    else:
        sequence_lens = sequence_lens.astype(np.int64, copy=False)
    if b is None:
        b = np.zeros((num_directions, 8 * hidden_size), dtype=dtype)
    if p is None:
        p = np.zeros((num_directions, 3 * hidden_size), dtype=dtype)
    if initial_h is None:
        initial_h = np.zeros(state_shape, dtype=dtype)
    if initial_c is None:
        initial_c = np.zeros(state_shape, dtype=dtype)

    output_y = (
        np.zeros((seq_length, num_directions, batch_size, hidden_size), dtype=dtype)
        if spec.output_y is not None
        else None
    )
    output_y_h = (
        np.zeros(state_shape, dtype=dtype) if spec.output_y_h is not None else None
    )
    output_y_c = (
        np.zeros(state_shape, dtype=dtype) if spec.output_y_c is not None else None
    )

    clip_threshold = spec.clip if spec.clip is not None and spec.clip > 0 else None

    def _clip(values: np.ndarray) -> np.ndarray:
        # Clip is applied to the complete activation input, i.e. after the
        # peephole term has been added.
        if clip_threshold is None:
            return values
        return np.clip(values, -clip_threshold, clip_threshold)

    directions = (
        ("forward", "reverse")
        if spec.direction == "bidirectional"
        else (spec.direction,)
    )
    batch_indices = np.arange(batch_size)
    for dir_index, dir_kind in enumerate(directions):
        w_dir = w[dir_index]
        r_dir = r[dir_index]
        b_dir = b[dir_index]
        bias = b_dir[: 4 * hidden_size] + b_dir[4 * hidden_size :]
        p_dir = p[dir_index]
        p_i = p_dir[:hidden_size]
        p_o = p_dir[hidden_size : 2 * hidden_size]
        p_f = p_dir[2 * hidden_size :]
        h_prev = initial_h[dir_index].copy()
        c_prev = initial_c[dir_index].copy()
        act_offset = dir_index * 3
        act_f = spec.activation_kinds[act_offset]
        act_g = spec.activation_kinds[act_offset + 1]
        act_h = spec.activation_kinds[act_offset + 2]
        alpha_f = spec.activation_alphas[act_offset]
        alpha_g = spec.activation_alphas[act_offset + 1]
        alpha_h = spec.activation_alphas[act_offset + 2]
        beta_f = spec.activation_betas[act_offset]
        beta_g = spec.activation_betas[act_offset + 1]
        beta_h = spec.activation_betas[act_offset + 2]
        for step in range(seq_length):
            active = step < sequence_lens
            if dir_kind == "forward":
                t_indices = np.full((batch_size,), step, dtype=np.int64)
            else:
                # Each batch entry is consumed backwards starting from its own
                # last valid time step, not from the padded end of the tensor.
                t_indices = sequence_lens - 1 - step
            # Inactive entries can point outside [0, seq_length); clamp them so
            # the gather stays in bounds.  Their results are masked out below.
            t_safe = np.clip(t_indices, 0, seq_length - 1)
            x_t = x[t_safe, batch_indices]
            gates = x_t @ w_dir.T + h_prev @ r_dir.T + bias
            i, o, f, c = np.split(gates, 4, axis=1)
            i = _apply_lstm_activation(
                act_f, _clip(i + p_i * c_prev), alpha_f, beta_f
            )
            if spec.input_forget:
                f = 1 - i
            else:
                f = _apply_lstm_activation(
                    act_f, _clip(f + p_f * c_prev), alpha_f, beta_f
                )
            c_tilde = _apply_lstm_activation(act_g, _clip(c), alpha_g, beta_g)
            c_new = f * c_prev + i * c_tilde
            o = _apply_lstm_activation(
                act_f, _clip(o + p_o * c_new), alpha_f, beta_f
            )
            h_new = o * _apply_lstm_activation(act_h, c_new, alpha_h, beta_h)
            if not np.all(active):
                # Finished batch entries keep their last valid state.
                h_new = np.where(active[:, None], h_new, h_prev)
                c_new = np.where(active[:, None], c_new, c_prev)
            if output_y is not None:
                # Write at the time index actually consumed; positions never
                # consumed keep their zero initialisation.
                output_y[t_safe[active], dir_index, batch_indices[active]] = h_new[
                    active
                ]
            h_prev = h_new
            c_prev = c_new
        if output_y_h is not None:
            output_y_h[dir_index] = h_prev
        if output_y_c is not None:
            output_y_c[dir_index] = c_prev
    if spec.layout == 1:
        # Back to batch-major: Y -> [batch, seq, dirs, hidden] and
        # Y_h / Y_c -> [batch, dirs, hidden].
        if output_y is not None:
            output_y = np.transpose(output_y, (2, 0, 1, 3))
        if output_y_h is not None:
            output_y_h = np.swapaxes(output_y_h, 0, 1)
        if output_y_c is not None:
            output_y_c = np.swapaxes(output_y_c, 0, 1)
    return output_y, output_y_h, output_y_c
|