emx-onnx-cgen 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/lstm.py
@@ -0,0 +1,355 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Iterable, Sequence
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype, optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+ACTIVATION_KIND_BY_NAME = {
+    "Relu": 0,
+    "Tanh": 1,
+    "Sigmoid": 2,
+    "Affine": 3,
+    "LeakyRelu": 4,
+    "ThresholdedRelu": 5,
+    "ScaledTanh": 6,
+    "HardSigmoid": 7,
+    "Elu": 8,
+    "Softsign": 9,
+    "Softplus": 10,
+}
+
+DEFAULT_ACTIVATIONS = ("Sigmoid", "Tanh", "Tanh")
+
+DEFAULT_ALPHA_BY_NAME = {
+    "Affine": 1.0,
+    "LeakyRelu": 0.01,
+    "ThresholdedRelu": 1.0,
+    "ScaledTanh": 1.0,
+    "HardSigmoid": 0.2,
+    "Elu": 1.0,
+}
+
+DEFAULT_BETA_BY_NAME = {
+    "Affine": 0.0,
+    "ScaledTanh": 1.0,
+    "HardSigmoid": 0.5,
+}
+
+
+@dataclass(frozen=True)
+class LstmSpec:
+    input_x: str
+    input_w: str
+    input_r: str
+    input_b: str | None
+    input_sequence_lens: str | None
+    input_initial_h: str | None
+    input_initial_c: str | None
+    input_p: str | None
+    output_y: str | None
+    output_y_h: str | None
+    output_y_c: str | None
+    seq_length: int
+    batch_size: int
+    input_size: int
+    hidden_size: int
+    num_directions: int
+    direction: str
+    layout: int
+    input_forget: int
+    clip: float | None
+    activation_kinds: tuple[int, ...]
+    activation_alphas: tuple[float, ...]
+    activation_betas: tuple[float, ...]
+    dtype: ScalarType
+    sequence_lens_dtype: ScalarType | None
+
+
+def _normalize_activation_names(values: Iterable[object]) -> list[str]:
+    names: list[str] = []
+    for value in values:
+        if isinstance(value, bytes):
+            value = value.decode("utf-8")
+        if not isinstance(value, str):
+            raise UnsupportedOpError("LSTM activations must be strings")
+        names.append(value)
+    return names
+
+
+def _resolve_activation_params(
+    activations: Sequence[str],
+    activation_alpha: Sequence[float] | None,
+    activation_beta: Sequence[float] | None,
+) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
+    if activation_alpha is None:
+        activation_alpha = []
+    if activation_beta is None:
+        activation_beta = []
+    if activation_alpha and len(activation_alpha) != len(activations):
+        raise UnsupportedOpError("LSTM activation_alpha must match activations")
+    if activation_beta and len(activation_beta) != len(activations):
+        raise UnsupportedOpError("LSTM activation_beta must match activations")
+    activation_kinds: list[int] = []
+    alphas: list[float] = []
+    betas: list[float] = []
+    for idx, name in enumerate(activations):
+        kind = ACTIVATION_KIND_BY_NAME.get(name)
+        if kind is None:
+            raise UnsupportedOpError(f"Unsupported LSTM activation {name}")
+        activation_kinds.append(kind)
+        if activation_alpha:
+            alpha = float(activation_alpha[idx])
+        else:
+            alpha = DEFAULT_ALPHA_BY_NAME.get(name, 1.0)
+        if activation_beta:
+            beta = float(activation_beta[idx])
+        else:
+            beta = DEFAULT_BETA_BY_NAME.get(name, 0.0)
+        alphas.append(alpha)
+        betas.append(beta)
+    return tuple(activation_kinds), tuple(alphas), tuple(betas)
+
+
+def _resolve_activations(
+    direction: str, num_directions: int, attrs: dict[str, object]
+) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
+    activations_attr = attrs.get("activations")
+    if activations_attr is None:
+        activations = list(DEFAULT_ACTIVATIONS)
+    else:
+        activations = _normalize_activation_names(activations_attr)
+    if num_directions == 1:
+        if len(activations) != 3:
+            raise UnsupportedOpError("LSTM activations must have length 3")
+    else:
+        if len(activations) == 3:
+            activations = activations * 2
+        elif len(activations) != 6:
+            raise UnsupportedOpError("Bidirectional LSTM activations must be length 6")
+    activation_alpha = attrs.get("activation_alpha")
+    activation_beta = attrs.get("activation_beta")
+    return _resolve_activation_params(
+        activations,
+        activation_alpha,
+        activation_beta,
+    )
+
+
+def _expect_shape(
+    name: str, shape: tuple[int, ...], expected: tuple[int, ...]
+) -> None:
+    if shape != expected:
+        raise UnsupportedOpError(
+            f"LSTM input {name} must have shape {expected}, got {shape}"
+        )
+
+
+def _validate_direction(direction: str, num_directions: int) -> None:
+    if direction == "bidirectional" and num_directions != 2:
+        raise UnsupportedOpError(
+            "LSTM expects num_directions=2 for bidirectional models"
+        )
+    if direction in {"forward", "reverse"} and num_directions != 1:
+        raise UnsupportedOpError(
+            "LSTM expects num_directions=1 for forward/reverse models"
+        )
+    if direction not in {"forward", "reverse", "bidirectional"}:
+        raise UnsupportedOpError(f"Unsupported LSTM direction {direction}")
+
+
+def resolve_lstm_spec(graph: Graph, node: Node) -> LstmSpec:
+    if len(node.inputs) < 3 or len(node.inputs) > 8:
+        raise UnsupportedOpError("LSTM expects between 3 and 8 inputs")
+    if len(node.outputs) < 1 or len(node.outputs) > 3:
+        raise UnsupportedOpError("LSTM expects between 1 and 3 outputs")
+    input_x = node.inputs[0]
+    input_w = node.inputs[1]
+    input_r = node.inputs[2]
+    input_b = optional_name(node.inputs, 3)
+    input_sequence_lens = optional_name(node.inputs, 4)
+    input_initial_h = optional_name(node.inputs, 5)
+    input_initial_c = optional_name(node.inputs, 6)
+    input_p = optional_name(node.inputs, 7)
+    output_y = optional_name(node.outputs, 0)
+    output_y_h = optional_name(node.outputs, 1)
+    output_y_c = optional_name(node.outputs, 2)
+    if output_y is None and output_y_h is None and output_y_c is None:
+        raise UnsupportedOpError("LSTM expects at least one output")
+    op_dtype = node_dtype(
+        graph,
+        node,
+        input_x,
+        input_w,
+        input_r,
+        *(name for name in (input_b, input_initial_h, input_initial_c, input_p) if name),
+        *(name for name in (output_y, output_y_h, output_y_c) if name),
+    )
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "LSTM supports float16, float, and double inputs only"
+        )
+    x_shape = value_shape(graph, input_x, node)
+    if len(x_shape) != 3:
+        raise UnsupportedOpError("LSTM input X must be rank 3")
+    layout = int(node.attrs.get("layout", 0))
+    if layout not in {0, 1}:
+        raise UnsupportedOpError("LSTM layout must be 0 or 1")
+    if layout == 0:
+        seq_length, batch_size, input_size = x_shape
+    else:
+        batch_size, seq_length, input_size = x_shape
+    w_shape = value_shape(graph, input_w, node)
+    if len(w_shape) != 3:
+        raise UnsupportedOpError("LSTM input W must be rank 3")
+    num_directions = w_shape[0]
+    hidden_size_attr = node.attrs.get("hidden_size")
+    if hidden_size_attr is None:
+        if w_shape[1] % 4 != 0:
+            raise UnsupportedOpError("LSTM W shape is not divisible by 4")
+        hidden_size = w_shape[1] // 4
+    else:
+        hidden_size = int(hidden_size_attr)
+    _validate_direction(str(node.attrs.get("direction", "forward")), num_directions)
+    direction = str(node.attrs.get("direction", "forward"))
+    expected_w_shape = (num_directions, 4 * hidden_size, input_size)
+    _expect_shape(input_w, w_shape, expected_w_shape)
+    r_shape = value_shape(graph, input_r, node)
+    expected_r_shape = (num_directions, 4 * hidden_size, hidden_size)
+    _expect_shape(input_r, r_shape, expected_r_shape)
+    if input_b is not None:
+        b_shape = value_shape(graph, input_b, node)
+        _expect_shape(input_b, b_shape, (num_directions, 8 * hidden_size))
+    if input_sequence_lens is not None:
+        seq_dtype = value_dtype(graph, input_sequence_lens, node)
+        if seq_dtype not in {ScalarType.I32, ScalarType.I64}:
+            raise UnsupportedOpError("LSTM sequence_lens must be int32 or int64")
+        seq_shape = value_shape(graph, input_sequence_lens, node)
+        if seq_shape != (batch_size,):
+            raise UnsupportedOpError(
+                "LSTM sequence_lens must match batch size"
+            )
+    state_shape = (
+        (num_directions, batch_size, hidden_size)
+        if layout == 0
+        else (batch_size, num_directions, hidden_size)
+    )
+    if input_initial_h is not None:
+        _expect_shape(
+            input_initial_h,
+            value_shape(graph, input_initial_h, node),
+            state_shape,
+        )
+    if input_initial_c is not None:
+        _expect_shape(
+            input_initial_c,
+            value_shape(graph, input_initial_c, node),
+            state_shape,
+        )
+    if input_p is not None:
+        _expect_shape(
+            input_p,
+            value_shape(graph, input_p, node),
+            (num_directions, 3 * hidden_size),
+        )
+    if output_y is not None:
+        expected_y_shape = (
+            (seq_length, num_directions, batch_size, hidden_size)
+            if layout == 0
+            else (batch_size, seq_length, num_directions, hidden_size)
+        )
+        _expect_shape(output_y, value_shape(graph, output_y, node), expected_y_shape)
+    if output_y_h is not None:
+        _expect_shape(
+            output_y_h,
+            value_shape(graph, output_y_h, node),
+            state_shape,
+        )
+    if output_y_c is not None:
+        _expect_shape(
+            output_y_c,
+            value_shape(graph, output_y_c, node),
+            state_shape,
+        )
+    input_forget = int(node.attrs.get("input_forget", 0))
+    if input_forget not in {0, 1}:
+        raise UnsupportedOpError("LSTM input_forget must be 0 or 1")
+    clip = node.attrs.get("clip")
+    if clip is not None:
+        clip = float(clip)
+        if clip < 0:
+            raise UnsupportedOpError("LSTM clip must be non-negative")
+    activation_kinds, activation_alphas, activation_betas = _resolve_activations(
+        direction, num_directions, node.attrs
+    )
+    sequence_lens_dtype = (
+        value_dtype(graph, input_sequence_lens, node)
+        if input_sequence_lens is not None
+        else None
+    )
+    return LstmSpec(
+        input_x=input_x,
+        input_w=input_w,
+        input_r=input_r,
+        input_b=input_b,
+        input_sequence_lens=input_sequence_lens,
+        input_initial_h=input_initial_h,
+        input_initial_c=input_initial_c,
+        input_p=input_p,
+        output_y=output_y,
+        output_y_h=output_y_h,
+        output_y_c=output_y_c,
+        seq_length=seq_length,
+        batch_size=batch_size,
+        input_size=input_size,
+        hidden_size=hidden_size,
+        num_directions=num_directions,
+        direction=direction,
+        layout=layout,
+        input_forget=input_forget,
+        clip=clip,
+        activation_kinds=activation_kinds,
+        activation_alphas=activation_alphas,
+        activation_betas=activation_betas,
+        dtype=op_dtype,
+        sequence_lens_dtype=sequence_lens_dtype,
+    )
+
+
+@register_lowering("LSTM")
+def lower_lstm(graph: Graph, node: Node) -> "LstmOp":
+    from ..codegen.c_emitter import LstmOp
+
+    spec = resolve_lstm_spec(graph, node)
+    return LstmOp(
+        input_x=spec.input_x,
+        input_w=spec.input_w,
+        input_r=spec.input_r,
+        input_b=spec.input_b,
+        input_sequence_lens=spec.input_sequence_lens,
+        input_initial_h=spec.input_initial_h,
+        input_initial_c=spec.input_initial_c,
+        input_p=spec.input_p,
+        output_y=spec.output_y,
+        output_y_h=spec.output_y_h,
+        output_y_c=spec.output_y_c,
+        seq_length=spec.seq_length,
+        batch_size=spec.batch_size,
+        input_size=spec.input_size,
+        hidden_size=spec.hidden_size,
+        num_directions=spec.num_directions,
+        direction=spec.direction,
+        layout=spec.layout,
+        input_forget=spec.input_forget,
+        clip=spec.clip,
+        activation_kinds=spec.activation_kinds,
+        activation_alphas=spec.activation_alphas,
+        activation_betas=spec.activation_betas,
+        dtype=spec.dtype,
+        sequence_lens_dtype=spec.sequence_lens_dtype,
+    )
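For context, here is a minimal sketch (not part of this wheel) of an ONNX graph that satisfies the shape constraints resolve_lstm_spec checks above, built with the public onnx helper API; the tensor names and sizes are illustrative assumptions only.

    import onnx
    from onnx import TensorProto, helper

    # Illustrative sizes; per the ONNX LSTM spec validated above, W must be
    # (num_directions, 4*hidden_size, input_size) and R must be
    # (num_directions, 4*hidden_size, hidden_size).
    seq_length, batch_size, input_size, hidden_size, num_directions = 4, 2, 3, 5, 1

    lstm = helper.make_node(
        "LSTM",
        inputs=["X", "W", "R"],
        outputs=["Y", "Y_h", "Y_c"],
        hidden_size=hidden_size,
        direction="forward",
        activations=["Sigmoid", "Tanh", "Tanh"],  # same as DEFAULT_ACTIVATIONS
    )
    graph = helper.make_graph(
        [lstm],
        "lstm_example",
        inputs=[
            helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_length, batch_size, input_size]),
            helper.make_tensor_value_info("W", TensorProto.FLOAT, [num_directions, 4 * hidden_size, input_size]),
            helper.make_tensor_value_info("R", TensorProto.FLOAT, [num_directions, 4 * hidden_size, hidden_size]),
        ],
        outputs=[
            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [seq_length, num_directions, batch_size, hidden_size]),
            helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [num_directions, batch_size, hidden_size]),
            helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, [num_directions, batch_size, hidden_size]),
        ],
    )
    onnx.checker.check_model(helper.make_model(graph))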
emx_onnx_cgen/lowering/matmul.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ..codegen.c_emitter import MatMulOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class MatMulSpec:
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    batch_shape: tuple[int, ...]
+    input0_batch_shape: tuple[int, ...]
+    input1_batch_shape: tuple[int, ...]
+    m: int
+    n: int
+    k: int
+    left_vector: bool
+    right_vector: bool
+
+
+def resolve_matmul_spec(graph: Graph, node: Node) -> MatMulSpec:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("MatMul must have 2 inputs and 1 output")
+    input0_shape = _value_shape(graph, node.inputs[0], node)
+    input1_shape = _value_shape(graph, node.inputs[1], node)
+    if len(input0_shape) < 1 or len(input1_shape) < 1:
+        raise UnsupportedOpError(
+            "MatMul inputs must be at least 1D, "
+            f"got {input0_shape} x {input1_shape}"
+        )
+    left_vector = len(input0_shape) == 1
+    right_vector = len(input1_shape) == 1
+    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
+    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
+    m, k_left = input0_effective[-2], input0_effective[-1]
+    k_right, n = input1_effective[-2], input1_effective[-1]
+    if k_left != k_right:
+        raise ShapeInferenceError(
+            f"MatMul inner dimensions must match, got {k_left} and {k_right}"
+        )
+    batch_shape, input0_batch_shape, input1_batch_shape = (
+        _broadcast_batch_shapes(
+            input0_effective[:-2], input1_effective[:-2], node
+        )
+    )
+    if left_vector and right_vector:
+        output_shape = batch_shape
+    elif left_vector:
+        output_shape = batch_shape + (n,)
+    elif right_vector:
+        output_shape = batch_shape + (m,)
+    else:
+        output_shape = batch_shape + (m, n)
+    expected_output_shape = _value_shape(graph, node.outputs[0], node)
+    if expected_output_shape != output_shape:
+        raise ShapeInferenceError(
+            "MatMul output shape must be "
+            f"{output_shape}, got {expected_output_shape}"
+        )
+    return MatMulSpec(
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
+        output_shape=output_shape,
+        batch_shape=batch_shape,
+        input0_batch_shape=input0_batch_shape,
+        input1_batch_shape=input1_batch_shape,
+        m=m,
+        n=n,
+        k=k_left,
+        left_vector=left_vector,
+        right_vector=right_vector,
+    )
+
+
+def _broadcast_batch_shapes(
+    left: tuple[int, ...], right: tuple[int, ...], node: Node
+) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+    max_rank = max(len(left), len(right))
+    left_padded = (1,) * (max_rank - len(left)) + left
+    right_padded = (1,) * (max_rank - len(right)) + right
+    broadcast_shape = []
+    for left_dim, right_dim in zip(left_padded, right_padded):
+        if left_dim == right_dim or left_dim == 1 or right_dim == 1:
+            broadcast_shape.append(max(left_dim, right_dim))
+            continue
+        raise ShapeInferenceError(
+            "MatMul batch dimensions must be broadcastable, "
+            f"got {left} x {right}"
+        )
+    return tuple(broadcast_shape), left_padded, right_padded
+
+
+@register_lowering("MatMul")
+def lower_matmul(graph: Graph, node: Node) -> MatMulOp:
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    spec = resolve_matmul_spec(graph, node)
+    return MatMulOp(
+        input0=node.inputs[0],
+        input1=node.inputs[1],
+        output=node.outputs[0],
+        input0_shape=spec.input0_shape,
+        input1_shape=spec.input1_shape,
+        output_shape=spec.output_shape,
+        batch_shape=spec.batch_shape,
+        input0_batch_shape=spec.input0_batch_shape,
+        input1_batch_shape=spec.input1_batch_shape,
+        m=spec.m,
+        n=spec.n,
+        k=spec.k,
+        left_vector=spec.left_vector,
+        right_vector=spec.right_vector,
+        dtype=op_dtype,
+    )
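As a quick illustration (a standalone sketch, not code imported from this wheel), the batch-shape broadcasting rule applied by _broadcast_batch_shapes above pads the shorter batch shape with leading 1s and then merges the dimensions pairwise:

    # Hypothetical standalone re-statement of the rule used by resolve_matmul_spec.
    def broadcast_batch_shapes(left, right):
        max_rank = max(len(left), len(right))
        left = (1,) * (max_rank - len(left)) + tuple(left)
        right = (1,) * (max_rank - len(right)) + tuple(right)
        merged = []
        for l, r in zip(left, right):
            if l != r and l != 1 and r != 1:
                raise ValueError(f"not broadcastable: {left} x {right}")
            merged.append(max(l, r))
        return tuple(merged)

    # A (2, 1, 4, 3) @ (7, 3, 5) MatMul has batch shapes (2, 1) and (7,),
    # which broadcast to (2, 7); the full output shape is then (2, 7, 4, 5).
    assert broadcast_batch_shapes((2, 1), (7,)) == (2, 7)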
emx_onnx_cgen/lowering/maxpool.py
@@ -0,0 +1,195 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import MaxPoolOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class MaxPoolSpec:
+    batch: int
+    channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    ceil_mode: bool
+    storage_order: int
+
+
+def resolve_maxpool_spec(graph: Graph, node: Node) -> MaxPoolSpec:
+    if len(node.inputs) != 1 or len(node.outputs) not in {1, 2}:
+        raise UnsupportedOpError("MaxPool must have 1 input and 1 or 2 outputs")
+    supported_attrs = {
+        "auto_pad",
+        "ceil_mode",
+        "dilations",
+        "kernel_shape",
+        "pads",
+        "storage_order",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("MaxPool has unsupported attributes")
+    storage_order = int(node.attrs.get("storage_order", 0))
+    if storage_order not in (0, 1):
+        raise UnsupportedOpError("MaxPool supports storage_order=0 or 1 only")
+    kernel_shape = node.attrs.get("kernel_shape")
+    if kernel_shape is None:
+        raise UnsupportedOpError("MaxPool requires kernel_shape")
+    kernel_shape = tuple(int(value) for value in kernel_shape)
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("MaxPool expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("MaxPool supports 1D/2D/3D inputs only")
+    if len(kernel_shape) != spatial_rank:
+        raise ShapeInferenceError(
+            f"MaxPool kernel_shape must have {spatial_rank} dims, got {kernel_shape}"
+        )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("MaxPool stride rank mismatch")
+    dilations = tuple(
+        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("MaxPool dilation rank mismatch")
+    pads = tuple(
+        int(value)
+        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("MaxPool pads rank mismatch")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad in ("", "NOTSET"):
+        pad_begin = pads[:spatial_rank]
+        pad_end = pads[spatial_rank:]
+    elif auto_pad == "VALID":
+        pad_begin = (0,) * spatial_rank
+        pad_end = (0,) * spatial_rank
+    elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+        pad_begin = []
+        pad_end = []
+        for dim, stride, dilation, kernel in zip(
+            input_shape[2:], strides, dilations, kernel_shape
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = math.ceil(dim / stride)
+            pad_needed = max(
+                0, (out_dim - 1) * stride + effective_kernel - dim
+            )
+            if auto_pad == "SAME_UPPER":
+                pad_start = pad_needed // 2
+            else:
+                pad_start = (pad_needed + 1) // 2
+            pad_begin.append(pad_start)
+            pad_end.append(pad_needed - pad_start)
+        pad_begin = tuple(pad_begin)
+        pad_end = tuple(pad_end)
+    else:
+        raise UnsupportedOpError("MaxPool has unsupported auto_pad mode")
+    ceil_mode = int(node.attrs.get("ceil_mode", 0))
+    if ceil_mode not in (0, 1):
+        raise UnsupportedOpError("MaxPool supports ceil_mode=0 or 1 only")
+    batch, channels = input_shape[0], input_shape[1]
+    in_spatial = input_shape[2:]
+    out_spatial = []
+    for dim, stride, dilation, kernel, pad_start, pad_finish in zip(
+        in_spatial, strides, dilations, kernel_shape, pad_begin, pad_end
+    ):
+        effective_kernel = dilation * (kernel - 1) + 1
+        numerator = dim + pad_start + pad_finish - effective_kernel
+        if ceil_mode:
+            out_dim = (numerator + stride - 1) // stride + 1
+            if (out_dim - 1) * stride >= dim + pad_start:
+                out_dim -= 1
+        else:
+            out_dim = numerator // stride + 1
+        if out_dim < 0:
+            raise ShapeInferenceError(
+                "MaxPool output shape must be non-negative"
+            )
+        out_spatial.append(out_dim)
+    expected_output_shape = (batch, channels, *out_spatial)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "MaxPool output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    if len(node.outputs) == 2:
+        indices_shape = _value_shape(graph, node.outputs[1], node)
+        if indices_shape != expected_output_shape:
+            raise ShapeInferenceError(
+                "MaxPool indices output shape must be "
+                f"{expected_output_shape}, got {indices_shape}"
+            )
+        indices_dtype = _value_dtype(graph, node.outputs[1], node)
+        if indices_dtype != ScalarType.I64:
+            raise UnsupportedOpError("MaxPool indices output must be int64")
+    pads = (*pad_begin, *pad_end)
+    return MaxPoolSpec(
+        batch=batch,
+        channels=channels,
+        spatial_rank=spatial_rank,
+        in_spatial=in_spatial,
+        out_spatial=tuple(out_spatial),
+        kernel_shape=kernel_shape,
+        strides=strides,
+        pads=pads,
+        dilations=dilations,
+        ceil_mode=bool(ceil_mode),
+        storage_order=storage_order,
+    )
+
+
+@register_lowering("MaxPool")
+def lower_maxpool(graph: Graph, node: Node) -> MaxPoolOp:
+    if len(node.inputs) != 1 or len(node.outputs) not in {1, 2}:
+        raise UnsupportedOpError("MaxPool must have 1 input and 1 or 2 outputs")
+    op_dtype = _node_dtype(graph, node, node.inputs[0], node.outputs[0])
+    if op_dtype == ScalarType.BOOL:
+        raise UnsupportedOpError("MaxPool supports numeric inputs only")
+    spec = resolve_maxpool_spec(graph, node)
+    indices = node.outputs[1] if len(node.outputs) == 2 else None
+    indices_dtype = (
+        _value_dtype(graph, indices, node) if indices is not None else None
+    )
+    if indices_dtype is not None and indices_dtype != ScalarType.I64:
+        raise UnsupportedOpError("MaxPool indices output must be int64")
+    return MaxPoolOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        indices=indices,
+        batch=spec.batch,
+        channels=spec.channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        ceil_mode=spec.ceil_mode,
+        storage_order=spec.storage_order,
+        dtype=op_dtype,
+        indices_dtype=indices_dtype,
+    )
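To make the auto_pad arithmetic concrete, here is a worked sketch of the SAME_UPPER branch in resolve_maxpool_spec above (plain Python, illustrative numbers only, not code from this wheel):

    import math

    # One spatial dimension: dim=7, stride=2, kernel=3, dilation=1.
    dim, stride, dilation, kernel = 7, 2, 1, 3
    effective_kernel = dilation * (kernel - 1) + 1                        # 3
    out_dim = math.ceil(dim / stride)                                     # ceil(7 / 2) = 4
    pad_needed = max(0, (out_dim - 1) * stride + effective_kernel - dim)  # 6 + 3 - 7 = 2
    pad_start = pad_needed // 2      # SAME_UPPER puts any odd leftover pad at the end
    pad_end = pad_needed - pad_start
    assert (pad_start, pad_end) == (1, 1)

    # The floor-mode output size recomputed with those pads matches out_dim:
    numerator = dim + pad_start + pad_end - effective_kernel              # 6
    assert numerator // stride + 1 == out_dim                             # 6 // 2 + 1 = 4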