emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
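Before relying on the operator coverage added in 0.3.1, it can help to confirm which release is actually installed. A minimal sketch using only the standard library, assuming the distribution name emx-onnx-cgen from the wheel file name above:

from importlib.metadata import version

# Distribution name taken from the wheel file name above (assumption).
print(version("emx-onnx-cgen"))  # expect "0.3.1" after upgrading from 0.2.0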
emx_onnx_cgen/lowering/constant_of_shape.py
CHANGED
@@ -4,7 +4,7 @@ from onnx import numpy_helper
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ConstantOfShapeOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
emx_onnx_cgen/lowering/conv.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import math
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import ConvOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/conv_transpose.py
ADDED
@@ -0,0 +1,301 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+from ..ir.ops import ConvTransposeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class ConvTransposeSpec:
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    output_padding: tuple[int, ...]
+    group: int
+
+
+def _split_padding(
+    total_padding: int, auto_pad: str, *, dim: int
+) -> tuple[int, int]:
+    if total_padding < 0:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be fully defined and non-negative"
+        )
+    pad_end = total_padding // 2
+    pad_begin = total_padding - pad_end
+    if auto_pad == "SAME_UPPER":
+        pad_begin, pad_end = pad_end, pad_begin
+    elif auto_pad not in {"SAME_LOWER", "NOTSET", ""}:
+        raise UnsupportedOpError(
+            f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+        )
+    if pad_begin < 0 or pad_end < 0:
+        raise ShapeInferenceError(
+            f"ConvTranspose pads must be non-negative for dim {dim}"
+        )
+    return pad_begin, pad_end
+
+
+def resolve_conv_transpose_spec(graph: Graph, node: Node) -> ConvTransposeSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    supported_attrs = {
+        "auto_pad",
+        "dilations",
+        "group",
+        "kernel_shape",
+        "output_padding",
+        "output_shape",
+        "pads",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("ConvTranspose has unsupported attributes")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    weight_shape = _value_shape(graph, node.inputs[1], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("ConvTranspose expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("ConvTranspose supports 1D/2D/3D inputs only")
+    if len(weight_shape) != spatial_rank + 2:
+        raise UnsupportedOpError(
+            "ConvTranspose weight rank must match spatial rank"
+        )
+    batch, in_channels = input_shape[0], input_shape[1]
+    in_spatial = input_shape[2:]
+    weight_in_channels, weight_out_channels, *kernel_shape = weight_shape
+    kernel_attr = node.attrs.get("kernel_shape")
+    if kernel_attr is not None:
+        kernel_attr = tuple(int(value) for value in kernel_attr)
+        if len(kernel_attr) != spatial_rank:
+            raise UnsupportedOpError(
+                "ConvTranspose kernel_shape rank must match input spatial rank"
+            )
+        if kernel_attr != tuple(kernel_shape):
+            raise ShapeInferenceError(
+                "ConvTranspose kernel_shape must match weights, "
+                f"got {kernel_attr} and {tuple(kernel_shape)}"
+            )
+        kernel_shape = list(kernel_attr)
+    else:
+        kernel_shape = list(kernel_shape)
+    group = int(node.attrs.get("group", 1))
+    if group <= 0:
+        raise UnsupportedOpError("ConvTranspose expects group >= 1")
+    if in_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide in channels, "
+            f"got group={group}, in_channels={in_channels}"
+        )
+    if weight_in_channels != in_channels:
+        raise ShapeInferenceError(
+            "ConvTranspose input channels must match weight channels, "
+            f"got {in_channels} and {weight_in_channels}"
+        )
+    out_channels = weight_out_channels * group
+    if out_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide out channels, "
+            f"got group={group}, out_channels={out_channels}"
+        )
+    if len(node.inputs) == 3:
+        bias_shape = _value_shape(graph, node.inputs[2], node)
+        if bias_shape != (out_channels,):
+            raise ShapeInferenceError(
+                f"ConvTranspose bias shape must be {(out_channels,)}, got {bias_shape}"
+            )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose stride rank mismatch")
+    dilations = tuple(
+        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose dilation rank mismatch")
+    output_padding = tuple(
+        int(value)
+        for value in node.attrs.get("output_padding", (0,) * spatial_rank)
+    )
+    if len(output_padding) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose output_padding rank mismatch")
+    for dim, (padding, stride) in enumerate(zip(output_padding, strides)):
+        if padding < 0:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be non-negative"
+            )
+        if padding >= stride:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be smaller than stride"
+            )
+    pads = tuple(
+        int(value)
+        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("ConvTranspose pads rank mismatch")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad == "":
+        auto_pad = "NOTSET"
+    output_shape_attr = node.attrs.get("output_shape")
+    output_shape: list[int] | None = None
+    if output_shape_attr is not None:
+        output_shape = [int(value) for value in output_shape_attr]
+        if len(output_shape) != spatial_rank:
+            raise UnsupportedOpError("ConvTranspose output_shape rank mismatch")
+    if output_shape is not None:
+        if auto_pad == "VALID":
+            auto_pad = "NOTSET"
+        pad_begin = []
+        pad_end = []
+        for dim, (in_dim, stride, dilation, kernel, out_dim, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                output_shape,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            total_padding = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - out_dim
+            )
+            pad_start, pad_finish = _split_padding(
+                total_padding, auto_pad, dim=dim
+            )
+            pad_begin.append(pad_start)
+            pad_end.append(pad_finish)
+        out_spatial = output_shape
+    else:
+        if auto_pad == "VALID":
+            pad_begin = [0] * spatial_rank
+            pad_end = [0] * spatial_rank
+        elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+            pad_begin = []
+            pad_end = []
+            for dim, (in_dim, stride, dilation, kernel, out_pad) in enumerate(
+                zip(in_spatial, strides, dilations, kernel_shape, output_padding)
+            ):
+                effective_kernel = dilation * (kernel - 1) + 1
+                out_dim = in_dim * stride
+                total_padding = (
+                    stride * (in_dim - 1)
+                    + out_pad
+                    + effective_kernel
+                    - out_dim
+                )
+                pad_start, pad_finish = _split_padding(
+                    total_padding, auto_pad, dim=dim
+                )
+                pad_begin.append(pad_start)
+                pad_end.append(pad_finish)
+        elif auto_pad in {"NOTSET"}:
+            pad_begin = list(pads[:spatial_rank])
+            pad_end = list(pads[spatial_rank:])
+        else:
+            raise UnsupportedOpError(
+                f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+            )
+        out_spatial = []
+        for dim, (in_dim, stride, dilation, kernel, pad_start, pad_finish, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                pad_begin,
+                pad_end,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - pad_start
+                - pad_finish
+            )
+            if out_dim < 0:
+                raise ShapeInferenceError(
+                    "ConvTranspose output shape must be non-negative"
+                )
+            out_spatial.append(out_dim)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, out_channels, *out_spatial)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return ConvTransposeSpec(
+        batch=batch,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        spatial_rank=spatial_rank,
+        in_spatial=in_spatial,
+        out_spatial=tuple(out_spatial),
+        kernel_shape=tuple(kernel_shape),
+        strides=strides,
+        pads=(*pad_begin, *pad_end),
+        dilations=dilations,
+        output_padding=output_padding,
+        group=group,
+    )
+
+
+@register_lowering("ConvTranspose")
+def lower_conv_transpose(graph: Graph, node: Node) -> ConvTransposeOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "ConvTranspose supports float16, float, and double inputs only"
+        )
+    spec = resolve_conv_transpose_spec(graph, node)
+    return ConvTransposeOp(
+        input0=node.inputs[0],
+        weights=node.inputs[1],
+        bias=node.inputs[2] if len(node.inputs) == 3 else None,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        output_padding=spec.output_padding,
+        group=spec.group,
+        dtype=op_dtype,
+    )
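For orientation, the key arithmetic in resolve_conv_transpose_spec above is the per-dimension output size: stride * (in - 1) + output_padding + dilation * (kernel - 1) + 1 - pad_begin - pad_end, which follows the ONNX ConvTranspose shape rules. A minimal standalone sketch of that formula; the helper name is illustrative and not part of the package:

def conv_transpose_out_dim(
    in_dim: int,
    stride: int,
    kernel: int,
    dilation: int = 1,
    pad_begin: int = 0,
    pad_end: int = 0,
    output_padding: int = 0,
) -> int:
    # Mirrors the lowering above: upsample by stride, add the dilated
    # kernel extent, then subtract the declared padding.
    effective_kernel = dilation * (kernel - 1) + 1
    return stride * (in_dim - 1) + output_padding + effective_kernel - pad_begin - pad_end


# Example: a length-4 input with stride 2 and kernel 3, no padding, gives length 9.
assert conv_transpose_out_dim(4, stride=2, kernel=3) == 9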
emx_onnx_cgen/lowering/cumsum.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import CumSumOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/depth_space.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import DepthToSpaceOp, SpaceToDepthOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/einsum.py
ADDED
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+from ..ir.ops import EinsumKind, EinsumOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _normalize_equation(equation: str) -> str:
+    return equation.replace(" ", "")
+
+
+@register_lowering("Einsum")
+def lower_einsum(graph: Graph, node: Node) -> EinsumOp:
+    if not node.inputs or len(node.outputs) != 1:
+        raise UnsupportedOpError("Einsum must have 1 output and at least 1 input")
+    equation_value = node.attrs.get("equation")
+    if equation_value is None:
+        raise UnsupportedOpError("Einsum equation attribute is required")
+    equation = (
+        equation_value.decode()
+        if isinstance(equation_value, (bytes, bytearray))
+        else str(equation_value)
+    )
+    normalized = _normalize_equation(equation)
+    input_shapes = tuple(
+        _value_shape(graph, name, node) for name in node.inputs
+    )
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if normalized == "->":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum '->' must have 1 input")
+        if output_shape:
+            raise ShapeInferenceError(
+                "Einsum '->' output must be scalar, "
+                f"got shape {output_shape}"
+            )
+        kind = EinsumKind.REDUCE_ALL
+    elif normalized == "ij->i":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum 'ij->i' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) != 2:
+            raise ShapeInferenceError(
+                "Einsum 'ij->i' input must be 2D, "
+                f"got shape {input_shape}"
+            )
+        expected = (input_shape[0],)
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'ij->i' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.SUM_J
+    elif normalized == "ij->ji":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum 'ij->ji' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) != 2:
+            raise ShapeInferenceError(
+                "Einsum 'ij->ji' input must be 2D, "
+                f"got shape {input_shape}"
+            )
+        expected = (input_shape[1], input_shape[0])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'ij->ji' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.TRANSPOSE
+    elif normalized in {"i,i", "i,i->"}:
+        if len(node.inputs) != 2:
+            raise UnsupportedOpError("Einsum 'i,i' must have 2 inputs")
+        left_shape, right_shape = input_shapes
+        if len(left_shape) != 1 or len(right_shape) != 1:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' inputs must be vectors, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[0] != right_shape[0]:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' inputs must have the same length, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if output_shape:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' output must be scalar, "
+                f"got shape {output_shape}"
+            )
+        kind = EinsumKind.DOT
+    elif normalized == "bij,bjk->bik":
+        if len(node.inputs) != 2:
+            raise UnsupportedOpError("Einsum 'bij,bjk->bik' must have 2 inputs")
+        left_shape, right_shape = input_shapes
+        if len(left_shape) != 3 or len(right_shape) != 3:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' inputs must be 3D, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[0] != right_shape[0]:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' batch dimensions must match, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[2] != right_shape[1]:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' contraction dimensions must match, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        expected = (left_shape[0], left_shape[1], right_shape[2])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'bij,bjk->bik' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.BATCH_MATMUL
+    elif normalized == "...ii->...i":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum '...ii->...i' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) < 2:
+            raise ShapeInferenceError(
+                "Einsum '...ii->...i' input must be at least 2D, "
+                f"got shape {input_shape}"
+            )
+        if input_shape[-1] != input_shape[-2]:
+            raise ShapeInferenceError(
+                "Einsum '...ii->...i' requires last two dims to match, "
+                f"got shape {input_shape}"
+            )
+        expected = (*input_shape[:-2], input_shape[-1])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum '...ii->...i' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.BATCH_DIAGONAL
+    else:
+        raise UnsupportedOpError(
+            f"Unsupported Einsum equation '{equation}'"
+        )
+    return EinsumOp(
+        inputs=tuple(node.inputs),
+        output=node.outputs[0],
+        kind=kind,
+        input_shapes=input_shapes,
+        output_shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+    )
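The Einsum lowering above accepts only a fixed whitelist of equations and tags each with an EinsumKind. As a rough NumPy illustration of what the recognized equations compute (reference semantics only, unrelated to the generated C):

import numpy as np

a = np.arange(6.0).reshape(2, 3)
v = np.arange(3.0)
b = np.arange(24.0).reshape(2, 3, 4)
c = np.arange(24.0).reshape(2, 4, 3)
d = np.arange(18.0).reshape(2, 3, 3)

np.einsum("ij->i", a)            # SUM_J: row sums
np.einsum("ij->ji", a)           # TRANSPOSE
np.einsum("i,i->", v, v)         # DOT: vector inner product
np.einsum("bij,bjk->bik", b, c)  # BATCH_MATMUL: batched matrix product
np.einsum("...ii->...i", d)      # BATCH_DIAGONAL: diagonal of each square matrix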
emx_onnx_cgen/lowering/elementwise.py
CHANGED
@@ -1,13 +1,23 @@
 from __future__ import annotations
 
-from shared.scalar_functions import ScalarFunction
+from shared.scalar_functions import ScalarFunction, ScalarFunctionError
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import BinaryOp, ClipOp, UnaryOp
 from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
 from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
-from ..lowering.registry import register_lowering
+from ..lowering.registry import register_lowering, register_lowering_if_missing
+from ..ops import (
+    BINARY_OP_TYPES,
+    COMPARE_FUNCTIONS,
+    UNARY_OP_TYPES,
+    binary_op_symbol,
+    unary_op_symbol,
+    validate_unary_attrs,
+)
+from ..lowering.variadic import VARIADIC_OP_FUNCTIONS
 
 
 @register_lowering("Clip")
@@ -120,6 +130,138 @@ def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
     )
 
 
+def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
+    if node.op_type == "BitShift":
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
+        direction_attr = node.attrs.get("direction", "LEFT")
+        if isinstance(direction_attr, bytes):
+            direction = direction_attr.decode()
+        else:
+            direction = str(direction_attr)
+        if direction not in {"LEFT", "RIGHT"}:
+            raise UnsupportedOpError(
+                "BitShift direction must be LEFT or RIGHT"
+            )
+        op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+        if not op_dtype.is_integer:
+            raise UnsupportedOpError("BitShift expects integer inputs")
+        function = (
+            ScalarFunction.BITWISE_LEFT_SHIFT
+            if direction == "LEFT"
+            else ScalarFunction.BITWISE_RIGHT_SHIFT
+        )
+        op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+        if op_spec is None:
+            raise UnsupportedOpError("Unsupported op BitShift")
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=op_dtype,
+            input_dtype=op_dtype,
+        )
+    if node.op_type == "Mod":
+        fmod = int(node.attrs.get("fmod", 0))
+        if fmod not in {0, 1}:
+            raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
+        function = (
+            ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
+        )
+    else:
+        try:
+            function = ScalarFunction.from_onnx_op(node.op_type)
+        except ScalarFunctionError as exc:
+            raise UnsupportedOpError(
+                f"Unsupported op {node.op_type}"
+            ) from exc
+    validate_unary_attrs(node.op_type, node.attrs)
+    if function in COMPARE_FUNCTIONS:
+        input_dtype = node_dtype(graph, node, *node.inputs)
+        output_dtype = value_dtype(graph, node.outputs[0], node)
+        op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
+        if op_spec is None:
+            raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{node.op_type} must have 2 inputs and 1 output"
+            )
+        if output_dtype != ScalarType.BOOL:
+            raise UnsupportedOpError(
+                f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
+            )
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=output_dtype,
+            input_dtype=input_dtype,
+        )
+    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+    unary_symbol = unary_op_symbol(function, dtype=op_dtype)
+    if op_spec is None and unary_symbol is None:
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    if op_spec is not None:
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{node.op_type} must have 2 inputs and 1 output"
+            )
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=op_dtype,
+            input_dtype=op_dtype,
+        )
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 1 input and 1 output"
+        )
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=function,
+        shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+        params=(),
+    )
+
+
+_DEFAULT_ELEMENTWISE_TYPES = (
+    BINARY_OP_TYPES.union(UNARY_OP_TYPES) - set(VARIADIC_OP_FUNCTIONS.keys())
+)
+
+for _op_type in _DEFAULT_ELEMENTWISE_TYPES:
+    register_lowering_if_missing(_op_type)(_lower_binary_unary)
+
+
 @register_lowering("IsInf")
 def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
@@ -130,6 +272,12 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
         raise UnsupportedOpError("IsInf only supports floating-point inputs")
     if output_dtype != ScalarType.BOOL:
         raise UnsupportedOpError("IsInf output must be bool")
+    detect_negative = int(node.attrs.get("detect_negative", 1))
+    detect_positive = int(node.attrs.get("detect_positive", 1))
+    if detect_negative not in {0, 1} or detect_positive not in {0, 1}:
+        raise UnsupportedOpError(
+            "IsInf detect_negative and detect_positive must be 0 or 1"
+        )
     output_shape = value_shape(graph, node.outputs[0], node)
     return UnaryOp(
        input0=node.inputs[0],
@@ -138,7 +286,7 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
         shape=output_shape,
         dtype=output_dtype,
         input_dtype=input_dtype,
-        params=(),
+        params=(float(detect_negative), float(detect_positive)),
     )
 
 
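The IsInf change above threads the ONNX detect_negative / detect_positive attributes through as the op's params. A small NumPy sketch of the attribute semantics (reference behaviour only; the helper name is illustrative, not part of the package):

import numpy as np


def isinf_reference(x: np.ndarray, detect_negative: int = 1, detect_positive: int = 1) -> np.ndarray:
    # ONNX IsInf: each flag enables detection of one sign of infinity.
    result = np.zeros(x.shape, dtype=bool)
    if detect_positive:
        result |= np.isposinf(x)
    if detect_negative:
        result |= np.isneginf(x)
    return result


x = np.array([1.0, np.inf, -np.inf, np.nan])
isinf_reference(x)                     # [False, True, True, False]
isinf_reference(x, detect_negative=0)  # [False, True, False, False]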