emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/gru.py
@@ -0,0 +1,323 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Iterable, Sequence
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype, optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+ACTIVATION_KIND_BY_NAME = {
+    "Relu": 0,
+    "Tanh": 1,
+    "Sigmoid": 2,
+    "Affine": 3,
+    "LeakyRelu": 4,
+    "ThresholdedRelu": 5,
+    "ScaledTanh": 6,
+    "HardSigmoid": 7,
+    "Elu": 8,
+    "Softsign": 9,
+    "Softplus": 10,
+}
+
+DEFAULT_ACTIVATIONS = ("Sigmoid", "Tanh")
+
+DEFAULT_ALPHA_BY_NAME = {
+    "Affine": 1.0,
+    "LeakyRelu": 0.01,
+    "ThresholdedRelu": 1.0,
+    "ScaledTanh": 1.0,
+    "HardSigmoid": 0.2,
+    "Elu": 1.0,
+}
+
+DEFAULT_BETA_BY_NAME = {
+    "Affine": 0.0,
+    "ScaledTanh": 1.0,
+    "HardSigmoid": 0.5,
+}
+
+
+@dataclass(frozen=True)
+class GruSpec:
+    input_x: str
+    input_w: str
+    input_r: str
+    input_b: str | None
+    input_sequence_lens: str | None
+    input_initial_h: str | None
+    output_y: str | None
+    output_y_h: str | None
+    seq_length: int
+    batch_size: int
+    input_size: int
+    hidden_size: int
+    num_directions: int
+    direction: str
+    layout: int
+    linear_before_reset: int
+    clip: float | None
+    activation_kinds: tuple[int, ...]
+    activation_alphas: tuple[float, ...]
+    activation_betas: tuple[float, ...]
+    dtype: ScalarType
+    sequence_lens_dtype: ScalarType | None
+
+
+def _normalize_activation_names(values: Iterable[object]) -> list[str]:
+    names: list[str] = []
+    for value in values:
+        if isinstance(value, bytes):
+            value = value.decode("utf-8")
+        if not isinstance(value, str):
+            raise UnsupportedOpError("GRU activations must be strings")
+        names.append(value)
+    return names
+
+
+def _resolve_activation_params(
+    activations: Sequence[str],
+    activation_alpha: Sequence[float] | None,
+    activation_beta: Sequence[float] | None,
+) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
+    if activation_alpha is None:
+        activation_alpha = []
+    if activation_beta is None:
+        activation_beta = []
+    if activation_alpha and len(activation_alpha) != len(activations):
+        raise UnsupportedOpError("GRU activation_alpha must match activations")
+    if activation_beta and len(activation_beta) != len(activations):
+        raise UnsupportedOpError("GRU activation_beta must match activations")
+    activation_kinds: list[int] = []
+    alphas: list[float] = []
+    betas: list[float] = []
+    for idx, name in enumerate(activations):
+        kind = ACTIVATION_KIND_BY_NAME.get(name)
+        if kind is None:
+            raise UnsupportedOpError(f"Unsupported GRU activation {name}")
+        activation_kinds.append(kind)
+        if activation_alpha:
+            alpha = float(activation_alpha[idx])
+        else:
+            alpha = DEFAULT_ALPHA_BY_NAME.get(name, 1.0)
+        if activation_beta:
+            beta = float(activation_beta[idx])
+        else:
+            beta = DEFAULT_BETA_BY_NAME.get(name, 0.0)
+        alphas.append(alpha)
+        betas.append(beta)
+    return tuple(activation_kinds), tuple(alphas), tuple(betas)
+
+
+def _resolve_activations(
+    num_directions: int, attrs: dict[str, object]
+) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
+    activations_attr = attrs.get("activations")
+    if activations_attr is None:
+        activations = list(DEFAULT_ACTIVATIONS)
+    else:
+        activations = _normalize_activation_names(activations_attr)
+    if num_directions == 1:
+        if len(activations) != 2:
+            raise UnsupportedOpError("GRU activations must have length 2")
+    else:
+        if len(activations) == 2:
+            activations = activations * 2
+        elif len(activations) != 4:
+            raise UnsupportedOpError("Bidirectional GRU activations must be length 4")
+    activation_alpha = attrs.get("activation_alpha")
+    activation_beta = attrs.get("activation_beta")
+    return _resolve_activation_params(
+        activations,
+        activation_alpha,
+        activation_beta,
+    )
+
+
+def _expect_shape(
+    name: str, shape: tuple[int, ...], expected: tuple[int, ...]
+) -> None:
+    if shape != expected:
+        raise UnsupportedOpError(
+            f"GRU input {name} must have shape {expected}, got {shape}"
+        )
+
+
+def _validate_direction(direction: str, num_directions: int) -> None:
+    if direction == "bidirectional" and num_directions != 2:
+        raise UnsupportedOpError(
+            "GRU expects num_directions=2 for bidirectional models"
+        )
+    if direction in {"forward", "reverse"} and num_directions != 1:
+        raise UnsupportedOpError(
+            "GRU expects num_directions=1 for forward/reverse models"
+        )
+    if direction not in {"forward", "reverse", "bidirectional"}:
+        raise UnsupportedOpError(f"Unsupported GRU direction {direction}")
+
+
+def resolve_gru_spec(graph: Graph, node: Node) -> GruSpec:
+    if len(node.inputs) < 3 or len(node.inputs) > 6:
+        raise UnsupportedOpError("GRU expects between 3 and 6 inputs")
+    if len(node.outputs) < 1 or len(node.outputs) > 2:
+        raise UnsupportedOpError("GRU expects between 1 and 2 outputs")
+    input_x = node.inputs[0]
+    input_w = node.inputs[1]
+    input_r = node.inputs[2]
+    input_b = optional_name(node.inputs, 3)
+    input_sequence_lens = optional_name(node.inputs, 4)
+    input_initial_h = optional_name(node.inputs, 5)
+    output_y = optional_name(node.outputs, 0)
+    output_y_h = optional_name(node.outputs, 1)
+    if output_y is None and output_y_h is None:
+        raise UnsupportedOpError("GRU expects at least one output")
+    op_dtype = node_dtype(
+        graph,
+        node,
+        input_x,
+        input_w,
+        input_r,
+        *(name for name in (input_b, input_initial_h) if name),
+        *(name for name in (output_y, output_y_h) if name),
+    )
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "GRU supports float16, float, and double inputs only"
+        )
+    x_shape = value_shape(graph, input_x, node)
+    if len(x_shape) != 3:
+        raise UnsupportedOpError("GRU input X must be rank 3")
+    layout = int(node.attrs.get("layout", 0))
+    if layout not in {0, 1}:
+        raise UnsupportedOpError("GRU layout must be 0 or 1")
+    if layout == 0:
+        seq_length, batch_size, input_size = x_shape
+    else:
+        batch_size, seq_length, input_size = x_shape
+    w_shape = value_shape(graph, input_w, node)
+    if len(w_shape) != 3:
+        raise UnsupportedOpError("GRU input W must be rank 3")
+    num_directions = w_shape[0]
+    hidden_size_attr = node.attrs.get("hidden_size")
+    if hidden_size_attr is None:
+        if w_shape[1] % 3 != 0:
+            raise UnsupportedOpError("GRU W shape is not divisible by 3")
+        hidden_size = w_shape[1] // 3
+    else:
+        hidden_size = int(hidden_size_attr)
+    direction = str(node.attrs.get("direction", "forward"))
+    _validate_direction(direction, num_directions)
+    expected_w_shape = (num_directions, 3 * hidden_size, input_size)
+    _expect_shape(input_w, w_shape, expected_w_shape)
+    r_shape = value_shape(graph, input_r, node)
+    expected_r_shape = (num_directions, 3 * hidden_size, hidden_size)
+    _expect_shape(input_r, r_shape, expected_r_shape)
+    if input_b is not None:
+        b_shape = value_shape(graph, input_b, node)
+        _expect_shape(input_b, b_shape, (num_directions, 6 * hidden_size))
+    if input_sequence_lens is not None:
+        seq_dtype = value_dtype(graph, input_sequence_lens, node)
+        if seq_dtype not in {ScalarType.I32, ScalarType.I64}:
+            raise UnsupportedOpError("GRU sequence_lens must be int32 or int64")
+        seq_shape = value_shape(graph, input_sequence_lens, node)
+        if seq_shape != (batch_size,):
+            raise UnsupportedOpError("GRU sequence_lens must match batch size")
+    state_shape = (
+        (num_directions, batch_size, hidden_size)
+        if layout == 0
+        else (batch_size, num_directions, hidden_size)
+    )
+    if input_initial_h is not None:
+        _expect_shape(
+            input_initial_h,
+            value_shape(graph, input_initial_h, node),
+            state_shape,
+        )
+    if output_y is not None:
+        expected_y_shape = (
+            (seq_length, num_directions, batch_size, hidden_size)
+            if layout == 0
+            else (batch_size, seq_length, num_directions, hidden_size)
+        )
+        _expect_shape(output_y, value_shape(graph, output_y, node), expected_y_shape)
+    if output_y_h is not None:
+        _expect_shape(
+            output_y_h,
+            value_shape(graph, output_y_h, node),
+            state_shape,
+        )
+    linear_before_reset = int(node.attrs.get("linear_before_reset", 0))
+    if linear_before_reset not in {0, 1}:
+        raise UnsupportedOpError("GRU linear_before_reset must be 0 or 1")
+    clip = node.attrs.get("clip")
+    if clip is not None:
+        clip = float(clip)
+        if clip < 0:
+            raise UnsupportedOpError("GRU clip must be non-negative")
+    activation_kinds, activation_alphas, activation_betas = _resolve_activations(
+        num_directions, node.attrs
+    )
+    sequence_lens_dtype = (
+        value_dtype(graph, input_sequence_lens, node)
+        if input_sequence_lens is not None
+        else None
+    )
+    return GruSpec(
+        input_x=input_x,
+        input_w=input_w,
+        input_r=input_r,
+        input_b=input_b,
+        input_sequence_lens=input_sequence_lens,
+        input_initial_h=input_initial_h,
+        output_y=output_y,
+        output_y_h=output_y_h,
+        seq_length=seq_length,
+        batch_size=batch_size,
+        input_size=input_size,
+        hidden_size=hidden_size,
+        num_directions=num_directions,
+        direction=direction,
+        layout=layout,
+        linear_before_reset=linear_before_reset,
+        clip=clip,
+        activation_kinds=activation_kinds,
+        activation_alphas=activation_alphas,
+        activation_betas=activation_betas,
+        dtype=op_dtype,
+        sequence_lens_dtype=sequence_lens_dtype,
+    )
+
+
+@register_lowering("GRU")
+def lower_gru(graph: Graph, node: Node) -> "GruOp":
+    from ..ir.ops import GruOp
+
+    spec = resolve_gru_spec(graph, node)
+    return GruOp(
+        input_x=spec.input_x,
+        input_w=spec.input_w,
+        input_r=spec.input_r,
+        input_b=spec.input_b,
+        input_sequence_lens=spec.input_sequence_lens,
+        input_initial_h=spec.input_initial_h,
+        output_y=spec.output_y,
+        output_y_h=spec.output_y_h,
+        seq_length=spec.seq_length,
+        batch_size=spec.batch_size,
+        input_size=spec.input_size,
+        hidden_size=spec.hidden_size,
+        num_directions=spec.num_directions,
+        direction=spec.direction,
+        layout=spec.layout,
+        linear_before_reset=spec.linear_before_reset,
+        clip=spec.clip,
+        activation_kinds=spec.activation_kinds,
+        activation_alphas=spec.activation_alphas,
+        activation_betas=spec.activation_betas,
+        dtype=spec.dtype,
+        sequence_lens_dtype=spec.sequence_lens_dtype,
+    )
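Note: the shape checks in resolve_gru_spec above mirror the ONNX GRU operator. As a quick illustration (the sizes below are made up and not taken from the package), these are the layout=0 tensors the lowering accepts; W and R pack the three gates (z, r, h), hence the 3 * hidden_size rows, and B concatenates the input and recurrence biases, hence 6 * hidden_size:

import numpy as np

# Hypothetical sizes for illustration only.
seq_length, batch_size, input_size, hidden_size = 5, 2, 4, 8
num_directions = 1  # "forward" or "reverse"; 2 for "bidirectional"

# layout == 0: X is [seq_length, batch_size, input_size]
x = np.zeros((seq_length, batch_size, input_size), dtype=np.float32)
w = np.zeros((num_directions, 3 * hidden_size, input_size), dtype=np.float32)
r = np.zeros((num_directions, 3 * hidden_size, hidden_size), dtype=np.float32)
b = np.zeros((num_directions, 6 * hidden_size), dtype=np.float32)
sequence_lens = np.full((batch_size,), seq_length, dtype=np.int32)
initial_h = np.zeros((num_directions, batch_size, hidden_size), dtype=np.float32)

# Outputs accepted by the lowering for layout == 0.
y = np.zeros((seq_length, num_directions, batch_size, hidden_size), dtype=np.float32)
y_h = np.zeros((num_directions, batch_size, hidden_size), dtype=np.float32)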
emx_onnx_cgen/lowering/hamming_window.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..dtypes import dtype_info
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..ir.ops import HammingWindowOp
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+_SUPPORTED_INPUT_DTYPES = {ScalarType.I32, ScalarType.I64}
+_SUPPORTED_OUTPUT_DTYPES = {
+    ScalarType.U8,
+    ScalarType.U16,
+    ScalarType.U32,
+    ScalarType.U64,
+    ScalarType.I8,
+    ScalarType.I16,
+    ScalarType.I32,
+    ScalarType.I64,
+    ScalarType.F16,
+    ScalarType.F32,
+    ScalarType.F64,
+}
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_scalar_initializer(
+    graph: Graph, name: str, node: Node
+) -> int | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    data = np.array(initializer.data)
+    if data.size != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} size input must be a scalar"
+        )
+    return int(data.reshape(-1)[0].item())
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+@register_lowering("HammingWindow")
+def lower_hamming_window(graph: Graph, node: Node) -> HammingWindowOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("HammingWindow must have 1 input and 1 output")
+    size_shape = value_shape(graph, node.inputs[0], node)
+    if not _is_scalar_shape(size_shape):
+        raise UnsupportedOpError("HammingWindow size input must be a scalar")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    if input_dtype not in _SUPPORTED_INPUT_DTYPES:
+        raise UnsupportedOpError(
+            f"HammingWindow size input must be int32 or int64, got {input_dtype.onnx_name}"
+        )
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(output_shape) != 1:
+        raise ShapeInferenceError("HammingWindow output must be 1D")
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
+        raise UnsupportedOpError(
+            "HammingWindow output dtype must be numeric, "
+            f"got {output_dtype.onnx_name}"
+        )
+    output_datatype = node.attrs.get("output_datatype")
+    if output_datatype is not None:
+        attr_dtype = dtype_info(int(output_datatype))
+        if attr_dtype != output_dtype:
+            raise UnsupportedOpError(
+                "HammingWindow output_datatype does not match output dtype"
+            )
+    periodic = int(node.attrs.get("periodic", 1))
+    if periodic not in {0, 1}:
+        raise UnsupportedOpError("HammingWindow periodic must be 0 or 1")
+    size_value = _read_scalar_initializer(graph, node.inputs[0], node)
+    if size_value is not None:
+        if size_value < 0:
+            raise ShapeInferenceError(
+                "HammingWindow size must be non-negative"
+            )
+        if output_shape[0] != size_value:
+            raise ShapeInferenceError(
+                "HammingWindow output length does not match size input"
+            )
+    return HammingWindowOp(
+        size=node.inputs[0],
+        output=node.outputs[0],
+        output_shape=output_shape,
+        periodic=periodic == 1,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
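The lowering above only validates shapes, dtypes, and the periodic/output_datatype attributes; the window values themselves come from templates/hamming_window_op.c.j2, which this listing does not show. As a rough reference only, assuming the standard ONNX HammingWindow definition (alpha = 25/46, beta = 1 - alpha, denominator size for periodic=1 and size-1 otherwise), a NumPy sketch of the expected values:

import numpy as np

def hamming_window(size: int, periodic: bool = True, dtype=np.float32) -> np.ndarray:
    # Assumed ONNX-style coefficients: alpha = 25/46, beta = 1 - alpha.
    alpha = 25.0 / 46.0
    beta = 1.0 - alpha
    # periodic=1 divides by size; periodic=0 (symmetric) divides by size - 1.
    denom = max(size if periodic else size - 1, 1)
    n = np.arange(size, dtype=np.float64)
    return (alpha - beta * np.cos(2.0 * np.pi * n / denom)).astype(dtype)

print(hamming_window(8, periodic=True))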
emx_onnx_cgen/lowering/hardmax.py
@@ -1,53 +1,17 @@
 from __future__ import annotations
 
-from shared.scalar_types import ScalarType
-
 from ..ir.ops import HardmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from .common import node_dtype as _node_dtype
-from .common import onnx_opset_version as _onnx_opset_version
-from .common import shape_product as _shape_product
-from .common import value_shape as _value_shape
 from .registry import register_lowering
-from ..validation import ensure_output_shape_matches_input
-from ..validation import normalize_axis as _normalize_axis
 
 
 @register_lowering("Hardmax")
 def lower_hardmax(graph: Graph, node: Node) -> HardmaxOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
         raise UnsupportedOpError("Hardmax must have 1 input and 1 output")
-    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-    if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
-        raise UnsupportedOpError(
-            "Hardmax supports float16, float, and double inputs only"
-        )
-    input_shape = _value_shape(graph, node.inputs[0], node)
-    output_shape = _value_shape(graph, node.outputs[0], node)
-    ensure_output_shape_matches_input(node, input_shape, output_shape)
-    opset_version = _onnx_opset_version(graph)
-    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
-    axis_attr = node.attrs.get("axis", default_axis)
-    axis = _normalize_axis(
-        int(axis_attr),
-        input_shape,
-        node,
-    )
-    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
-    axis_size = input_shape[axis]
-    inner = (
-        _shape_product(input_shape[axis + 1 :])
-        if axis + 1 < len(input_shape)
-        else 1
-    )
     return HardmaxOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-
-        axis_size=axis_size,
-        inner=inner,
-        axis=axis,
-        shape=input_shape,
-        dtype=op_dtype,
+        axis=int(node.attrs["axis"]) if "axis" in node.attrs else None,
     )
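The rewritten lowering now forwards the raw axis attribute (or None when it is absent) and no longer precomputes dtype, shape, or outer/axis/inner strides. For context, Hardmax emits a one-hot encoding of the first maximum along the chosen axis; an independent NumPy reference, not taken from the package:

import numpy as np

def hardmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # One-hot of the first maximum along `axis`; ties resolve to the lowest
    # index, matching np.argmax.
    out = np.zeros_like(x)
    idx = np.argmax(x, axis=axis)
    np.put_along_axis(out, np.expand_dims(idx, axis), 1, axis=axis)
    return out

print(hardmax(np.array([[1.0, 3.0, 2.0], [0.5, 0.5, 0.1]]), axis=-1))
# [[0. 1. 0.]
#  [1. 0. 0.]]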
emx_onnx_cgen/lowering/identity.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from ..ir.ops import IdentityOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
-from .
+from ..ir.ops import IdentityOp
+from .common import value_dtype, value_has_dim_params, value_shape
 from .registry import register_lowering
 
 
@@ -13,9 +14,10 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
         raise UnsupportedOpError("Identity must have 1 input and 1 output")
     input_shape = value_shape(graph, node.inputs[0], node)
     output_shape = value_shape(graph, node.outputs[0], node)
+    if value_has_dim_params(graph, node.outputs[0]) or not output_shape:
+        output_shape = ()
     input_dim_params = graph.find_value(node.inputs[0]).type.dim_params
     output_dim_params = graph.find_value(node.outputs[0]).type.dim_params
-    resolved_shape = output_shape or input_shape
     if input_shape and output_shape:
         if len(input_shape) != len(output_shape):
             raise ShapeInferenceError("Identity input and output shapes must match")
@@ -35,10 +37,9 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
             "Identity expects matching input/output dtypes, "
             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
         )
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], input_shape)
     return IdentityOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-        shape=resolved_shape,
-        dtype=output_dtype,
-        input_dtype=input_dtype,
     )
emx_onnx_cgen/lowering/logsoftmax.py
@@ -3,49 +3,15 @@ from __future__ import annotations
 from ..ir.ops import LogSoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from .common import node_dtype as _node_dtype
-from .common import onnx_opset_version as _onnx_opset_version
-from .common import shape_product as _shape_product
-from .common import value_shape as _value_shape
 from .registry import register_lowering
-from ..validation import ensure_output_shape_matches_input
-from ..validation import normalize_axis as _normalize_axis
 
 
 @register_lowering("LogSoftmax")
 def lower_logsoftmax(graph: Graph, node: Node) -> LogSoftmaxOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
         raise UnsupportedOpError("LogSoftmax must have 1 input and 1 output")
-    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-    if not op_dtype.is_float:
-        raise UnsupportedOpError(
-            "LogSoftmax supports float16, float, and double inputs only"
-        )
-    input_shape = _value_shape(graph, node.inputs[0], node)
-    output_shape = _value_shape(graph, node.outputs[0], node)
-    ensure_output_shape_matches_input(node, input_shape, output_shape)
-    opset_version = _onnx_opset_version(graph)
-    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
-    axis_attr = node.attrs.get("axis", default_axis)
-    axis = _normalize_axis(
-        int(axis_attr),
-        input_shape,
-        node,
-    )
-    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
-    axis_size = input_shape[axis]
-    inner = (
-        _shape_product(input_shape[axis + 1 :])
-        if axis + 1 < len(input_shape)
-        else 1
-    )
     return LogSoftmaxOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-
-        axis_size=axis_size,
-        inner=inner,
-        axis=axis,
-        shape=input_shape,
-        dtype=op_dtype,
+        axis=int(node.attrs["axis"]) if "axis" in node.attrs else None,
     )
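As with Hardmax, the lowering now passes the raw axis attribute through instead of precomputing outer/axis/inner strides in the lowering itself. For reference, LogSoftmax along an axis is x - logsumexp(x), usually computed after subtracting the row maximum for numerical stability; a small NumPy sketch, independent of the generated C:

import numpy as np

def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable log-softmax: shift by the max before exponentiating.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

x = np.array([[1.0, 2.0, 3.0]])
print(log_softmax(x, axis=-1))                 # approx. [[-2.4076, -1.4076, -0.4076]]
print(np.exp(log_softmax(x, axis=-1)).sum())   # approx. 1.0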
emx_onnx_cgen/lowering/lp_pool.py
@@ -19,6 +19,8 @@ class LpPoolSpec:
     out_w: int
     kernel_h: int
     kernel_w: int
+    dilation_h: int
+    dilation_w: int
     stride_h: int
     stride_w: int
     pad_top: int
@@ -51,8 +53,10 @@ def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
     if ceil_mode != 0:
         raise UnsupportedOpError("LpPool supports ceil_mode=0 only")
     dilations = tuple(int(value) for value in node.attrs.get("dilations", (1, 1)))
-    if
-        raise UnsupportedOpError("LpPool
+    if len(dilations) != 2:
+        raise UnsupportedOpError("LpPool expects 2D dilations")
+    if any(value < 1 for value in dilations):
+        raise UnsupportedOpError("LpPool requires dilations >= 1")
     kernel_shape = node.attrs.get("kernel_shape")
     if kernel_shape is None:
         raise UnsupportedOpError("LpPool requires kernel_shape")
@@ -75,8 +79,11 @@ def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
         raise UnsupportedOpError("LpPool supports NCHW 2D inputs only")
     batch, channels, in_h, in_w = input_shape
     stride_h, stride_w = strides
-
-
+    dilation_h, dilation_w = dilations
+    effective_kernel_h = dilation_h * (kernel_h - 1) + 1
+    effective_kernel_w = dilation_w * (kernel_w - 1) + 1
+    out_h = (in_h + pad_top + pad_bottom - effective_kernel_h) // stride_h + 1
+    out_w = (in_w + pad_left + pad_right - effective_kernel_w) // stride_w + 1
     if out_h < 0 or out_w < 0:
         raise ShapeInferenceError("LpPool output shape must be non-negative")
     output_shape = _value_shape(graph, node.outputs[0], node)
@@ -95,6 +102,8 @@ def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
         out_w=out_w,
         kernel_h=kernel_h,
         kernel_w=kernel_w,
+        dilation_h=dilation_h,
+        dilation_w=dilation_w,
         stride_h=stride_h,
         stride_w=stride_w,
         pad_top=pad_top,
@@ -130,6 +139,8 @@ def lower_lp_pool(graph: Graph, node: Node) -> LpPoolOp:
        out_w=spec.out_w,
        kernel_h=spec.kernel_h,
        kernel_w=spec.kernel_w,
+       dilation_h=spec.dilation_h,
+       dilation_w=spec.dilation_w,
        stride_h=spec.stride_h,
        stride_w=spec.stride_w,
        pad_top=spec.pad_top,
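The added dilation support uses the usual effective-kernel-size formula; the snippet below replays the new out_h/out_w arithmetic with illustrative numbers (not taken from the package) as a sanity check:

# Illustrative numbers only: 32x32 input, 3x3 kernel, dilation 2, stride 1, no padding.
in_h = in_w = 32
kernel_h = kernel_w = 3
dilation_h = dilation_w = 2
stride_h = stride_w = 1
pad_top = pad_bottom = pad_left = pad_right = 0

effective_kernel_h = dilation_h * (kernel_h - 1) + 1   # 5
effective_kernel_w = dilation_w * (kernel_w - 1) + 1   # 5
out_h = (in_h + pad_top + pad_bottom - effective_kernel_h) // stride_h + 1   # 28
out_w = (in_w + pad_left + pad_right - effective_kernel_w) // stride_w + 1   # 28
print(out_h, out_w)  # 28 28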