emx-onnx-cgen 0.2.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0

emx_onnx_cgen/lowering/conv.py
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+from ..codegen.c_emitter import ConvOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class ConvSpec:
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    group: int
+
+
+def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Conv must have 2 or 3 inputs and 1 output")
+    supported_attrs = {
+        "auto_pad",
+        "dilations",
+        "group",
+        "kernel_shape",
+        "pads",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("Conv has unsupported attributes")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    weight_shape = _value_shape(graph, node.inputs[1], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("Conv expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("Conv supports 1D/2D/3D inputs only")
+    if len(weight_shape) != spatial_rank + 2:
+        raise UnsupportedOpError("Conv weight rank must match spatial rank")
+    batch, in_channels = input_shape[0], input_shape[1]
+    in_spatial = input_shape[2:]
+    out_channels, weight_in_channels, *kernel_shape = weight_shape
+    kernel_shape = node.attrs.get("kernel_shape")
+    if kernel_shape is not None:
+        kernel_shape = tuple(int(value) for value in kernel_shape)
+        if len(kernel_shape) != spatial_rank:
+            raise UnsupportedOpError(
+                "Conv kernel_shape rank must match input spatial rank"
+            )
+        if kernel_shape != tuple(weight_shape[2:]):
+            raise ShapeInferenceError(
+                "Conv kernel_shape must match weights, "
+                f"got {kernel_shape} and {tuple(weight_shape[2:])}"
+            )
+    else:
+        kernel_shape = tuple(weight_shape[2:])
+    group = int(node.attrs.get("group", 1))
+    if group <= 0:
+        raise UnsupportedOpError("Conv expects group >= 1")
+    if in_channels % group != 0 or out_channels % group != 0:
+        raise ShapeInferenceError(
+            "Conv expects group to evenly divide in/out channels, "
+            f"got group={group}, in_channels={in_channels}, "
+            f"out_channels={out_channels}"
+        )
+    if weight_in_channels != in_channels // group:
+        raise ShapeInferenceError(
+            "Conv input channels must match weight channels, "
+            f"got {in_channels} and {weight_in_channels * group}"
+        )
+    if len(node.inputs) == 3:
+        bias_shape = _value_shape(graph, node.inputs[2], node)
+        if bias_shape != (out_channels,):
+            raise ShapeInferenceError(
+                f"Conv bias shape must be {(out_channels,)}, got {bias_shape}"
+            )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("Conv stride rank mismatch")
+    dilations = tuple(
+        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("Conv dilation rank mismatch")
+    pads = tuple(
+        int(value)
+        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("Conv pads rank mismatch")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad in ("", "NOTSET"):
+        pad_begin = pads[:spatial_rank]
+        pad_end = pads[spatial_rank:]
+    elif auto_pad == "VALID":
+        pad_begin = (0,) * spatial_rank
+        pad_end = (0,) * spatial_rank
+    elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+        pad_begin = []
+        pad_end = []
+        for dim, stride, dilation, kernel in zip(
+            in_spatial, strides, dilations, kernel_shape
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = math.ceil(dim / stride)
+            pad_needed = max(
+                0, (out_dim - 1) * stride + effective_kernel - dim
+            )
+            if auto_pad == "SAME_UPPER":
+                pad_start = pad_needed // 2
+            else:
+                pad_start = (pad_needed + 1) // 2
+            pad_begin.append(pad_start)
+            pad_end.append(pad_needed - pad_start)
+        pad_begin = tuple(pad_begin)
+        pad_end = tuple(pad_end)
+    else:
+        raise UnsupportedOpError("Conv has unsupported auto_pad mode")
+    out_spatial = []
+    for dim, stride, dilation, kernel, pad_start, pad_finish in zip(
+        in_spatial, strides, dilations, kernel_shape, pad_begin, pad_end
+    ):
+        effective_kernel = dilation * (kernel - 1) + 1
+        out_dim = (dim + pad_start + pad_finish - effective_kernel) // stride + 1
+        if out_dim < 0:
+            raise ShapeInferenceError("Conv output shape must be non-negative")
+        out_spatial.append(out_dim)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, out_channels, *out_spatial)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "Conv output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return ConvSpec(
+        batch=batch,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        spatial_rank=spatial_rank,
+        in_spatial=in_spatial,
+        out_spatial=tuple(out_spatial),
+        kernel_shape=kernel_shape,
+        strides=strides,
+        pads=(*pad_begin, *pad_end),
+        dilations=dilations,
+        group=group,
+    )
+
+
+@register_lowering("Conv")
+def lower_conv(graph: Graph, node: Node) -> ConvOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Conv must have 2 or 3 inputs and 1 output")
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "Conv supports float16, float, and double inputs only"
+        )
+    spec = resolve_conv_spec(graph, node)
+    return ConvOp(
+        input0=node.inputs[0],
+        weights=node.inputs[1],
+        bias=node.inputs[2] if len(node.inputs) == 3 else None,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        group=spec.group,
+        dtype=op_dtype,
+    )
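
For orientation, a minimal standalone sketch of the per-axis padding and output-size arithmetic that resolve_conv_spec applies above; the helper name and the sample numbers are illustrative only and not part of the package.

# Illustrative re-statement of the per-axis Conv arithmetic used above.
import math

def conv_out_dim(dim: int, kernel: int, stride: int, dilation: int,
                 auto_pad: str = "NOTSET",
                 pad_begin: int = 0, pad_end: int = 0) -> int:
    # Dilated ("effective") kernel extent.
    effective_kernel = dilation * (kernel - 1) + 1
    if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
        # SAME_* keeps ceil(dim / stride) outputs and derives the padding.
        out_dim = math.ceil(dim / stride)
        pad_needed = max(0, (out_dim - 1) * stride + effective_kernel - dim)
        pad_begin = pad_needed // 2 if auto_pad == "SAME_UPPER" else (pad_needed + 1) // 2
        pad_end = pad_needed - pad_begin
    return (dim + pad_begin + pad_end - effective_kernel) // stride + 1

# A 7-wide axis with a 3-wide kernel and stride 2: no padding gives 3 outputs,
# SAME_UPPER pads by 1 on each side to keep ceil(7 / 2) == 4 outputs.
assert conv_out_dim(7, 3, 2, 1) == 3
assert conv_out_dim(7, 3, 2, 1, auto_pad="SAME_UPPER") == 4
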
emx_onnx_cgen/lowering/cumsum.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import CumSumOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import value_dtype, value_shape
+from ..validation import ensure_output_shape_matches_input, normalize_axis
+from .registry import register_lowering
+
+
+_SUPPORTED_CUMSUM_DTYPES = {
+    ScalarType.F16,
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.I32,
+    ScalarType.I64,
+    ScalarType.U32,
+    ScalarType.U64,
+}
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+def _validate_static_shape(shape: tuple[int, ...], node: Node) -> None:
+    for dim in shape:
+        if dim < 0:
+            raise ShapeInferenceError(
+                f"{node.op_type} does not support dynamic dims"
+            )
+
+
+def _read_axis_initializer(
+    initializer: Initializer, node: Node
+) -> int:
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} axis input must be int64 or int32"
+        )
+    axis_data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if axis_data.size != 1:
+        raise UnsupportedOpError(f"{node.op_type} axis input must be scalar")
+    return int(axis_data[0])
+
+
+@register_lowering("CumSum")
+def lower_cumsum(graph: Graph, node: Node) -> CumSumOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("CumSum must have 2 inputs and 1 output")
+    input_name = node.inputs[0]
+    axis_name = node.inputs[1]
+    if not input_name or not axis_name:
+        raise UnsupportedOpError("CumSum requires input and axis values")
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    _validate_static_shape(input_shape, node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "CumSum expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if input_dtype not in _SUPPORTED_CUMSUM_DTYPES:
+        raise UnsupportedOpError(
+            f"CumSum does not support dtype {input_dtype.onnx_name}"
+        )
+    axis_initializer = _find_initializer(graph, axis_name)
+    axis_value = None
+    axis_input = None
+    axis_input_dtype = None
+    if axis_initializer is not None:
+        axis_value = normalize_axis(
+            _read_axis_initializer(axis_initializer, node),
+            input_shape,
+            node,
+        )
+    else:
+        axis_shape = value_shape(graph, axis_name, node)
+        if not _is_scalar_shape(axis_shape):
+            raise UnsupportedOpError("CumSum axis input must be scalar")
+        axis_input_dtype = value_dtype(graph, axis_name, node)
+        if axis_input_dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                "CumSum axis input must be int64 or int32"
+            )
+        axis_input = axis_name
+    exclusive = int(node.attrs.get("exclusive", 0))
+    reverse = int(node.attrs.get("reverse", 0))
+    if exclusive not in {0, 1}:
+        raise UnsupportedOpError("CumSum exclusive must be 0 or 1")
+    if reverse not in {0, 1}:
+        raise UnsupportedOpError("CumSum reverse must be 0 or 1")
+    return CumSumOp(
+        input0=input_name,
+        axis_input=axis_input,
+        axis_input_dtype=axis_input_dtype,
+        axis=axis_value,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+        exclusive=bool(exclusive),
+        reverse=bool(reverse),
+    )
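
As a reference for the exclusive/reverse flags validated above, a NumPy sketch of the CumSum variants per the ONNX operator definition; cumsum_ref is an illustrative helper, not part of the package (the generated C is the actual implementation).

# Reference semantics for CumSum's exclusive/reverse attributes (illustrative).
import numpy as np

def cumsum_ref(x: np.ndarray, axis: int = 0,
               exclusive: bool = False, reverse: bool = False) -> np.ndarray:
    if reverse:
        x = np.flip(x, axis=axis)
    out = np.cumsum(x, axis=axis)
    if exclusive:
        # Shift by one along the axis and prepend a zero: the running sum
        # excludes the current element.
        head = np.zeros_like(np.take(out, [0], axis=axis))
        body = np.take(out, range(out.shape[axis] - 1), axis=axis)
        out = np.concatenate([head, body], axis=axis)
    if reverse:
        out = np.flip(out, axis=axis)
    return out

x = np.array([1.0, 2.0, 3.0, 4.0])
print(cumsum_ref(x))                  # [ 1.  3.  6. 10.]
print(cumsum_ref(x, exclusive=True))  # [0. 1. 3. 6.]
print(cumsum_ref(x, reverse=True))    # [10.  9.  7.  4.]
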
emx_onnx_cgen/lowering/depth_space.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import DepthToSpaceOp, SpaceToDepthOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _blocksize(node: Node) -> int:
+    if "blocksize" not in node.attrs:
+        raise UnsupportedOpError(f"{node.op_type} requires blocksize attribute")
+    blocksize = int(node.attrs["blocksize"])
+    if blocksize <= 0:
+        raise UnsupportedOpError(
+            f"{node.op_type} blocksize must be > 0, got {blocksize}"
+        )
+    return blocksize
+
+
+@register_lowering("DepthToSpace")
+def lower_depth_to_space(graph: Graph, node: Node) -> DepthToSpaceOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("DepthToSpace must have 1 input and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(input_shape) != 4 or len(output_shape) != 4:
+        raise UnsupportedOpError("DepthToSpace only supports 4D inputs")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "DepthToSpace expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    blocksize = _blocksize(node)
+    mode_attr = node.attrs.get("mode", "DCR")
+    if isinstance(mode_attr, bytes):
+        mode = mode_attr.decode()
+    else:
+        mode = str(mode_attr)
+    if mode not in {"DCR", "CRD"}:
+        raise UnsupportedOpError(
+            "DepthToSpace only supports mode DCR or CRD"
+        )
+    n, c, h, w = input_shape
+    if c % (blocksize * blocksize) != 0:
+        raise ShapeInferenceError(
+            "DepthToSpace input channels must be divisible by blocksize^2"
+        )
+    expected_shape = (
+        n,
+        c // (blocksize * blocksize),
+        h * blocksize,
+        w * blocksize,
+    )
+    if output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "DepthToSpace output shape mismatch: "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    return DepthToSpaceOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        blocksize=blocksize,
+        mode=mode,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
+
+
+@register_lowering("SpaceToDepth")
+def lower_space_to_depth(graph: Graph, node: Node) -> SpaceToDepthOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("SpaceToDepth must have 1 input and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(input_shape) != 4 or len(output_shape) != 4:
+        raise UnsupportedOpError("SpaceToDepth only supports 4D inputs")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "SpaceToDepth expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    blocksize = _blocksize(node)
+    n, c, h, w = input_shape
+    if h % blocksize != 0 or w % blocksize != 0:
+        raise ShapeInferenceError(
+            "SpaceToDepth spatial dims must be divisible by blocksize"
+        )
+    expected_shape = (
+        n,
+        c * blocksize * blocksize,
+        h // blocksize,
+        w // blocksize,
+    )
+    if output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "SpaceToDepth output shape mismatch: "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    return SpaceToDepthOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        blocksize=blocksize,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
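
For orientation, a NumPy restatement of the DCR and CRD layouts that lower_depth_to_space accepts, following the ONNX DepthToSpace definition; depth_to_space_ref is illustrative only and not part of the package.

# Reference layouts for DepthToSpace modes (illustrative).
import numpy as np

def depth_to_space_ref(x: np.ndarray, b: int, mode: str = "DCR") -> np.ndarray:
    n, c, h, w = x.shape
    if mode == "DCR":
        # Depth-column-row: block indices come from the outer channel dims.
        t = x.reshape(n, b, b, c // (b * b), h, w).transpose(0, 3, 4, 1, 5, 2)
    else:
        # CRD: block indices come from the inner channel dims.
        t = x.reshape(n, c // (b * b), b, b, h, w).transpose(0, 1, 4, 2, 5, 3)
    return t.reshape(n, c // (b * b), h * b, w * b)

x = np.arange(1 * 8 * 2 * 2).reshape(1, 8, 2, 2)
assert depth_to_space_ref(x, 2, "DCR").shape == (1, 2, 4, 4)
assert depth_to_space_ref(x, 2, "CRD").shape == (1, 2, 4, 4)
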
emx_onnx_cgen/lowering/dropout.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import ReshapeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _is_value_used(graph: Graph, name: str) -> bool:
+    if any(value.name == name for value in graph.outputs):
+        return True
+    return any(name in node.inputs for node in graph.nodes)
+
+
+@register_lowering("Dropout")
+def lower_dropout(graph: Graph, node: Node) -> ReshapeOp:
+    if len(node.outputs) not in {1, 2} or len(node.inputs) != 1:
+        raise UnsupportedOpError(
+            "Dropout supports only the data input and 1 or 2 outputs"
+        )
+    if len(node.outputs) == 2 and _is_value_used(graph, node.outputs[1]):
+        raise UnsupportedOpError("Dropout mask output is not supported")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise ShapeInferenceError(
+            "Dropout output shape must match input shape, "
+            f"got {output_shape} for input {input_shape}"
+        )
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Dropout expects matching input/output dtypes, "
+            f"got {input_dtype} and {output_dtype}"
+        )
+    return ReshapeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )

emx_onnx_cgen/lowering/elementwise.py
@@ -0,0 +1,164 @@
+from __future__ import annotations
+
+from shared.scalar_functions import ScalarFunction
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ClipOp, UnaryOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
+from ..lowering.registry import register_lowering
+
+
+@register_lowering("Clip")
+def lower_clip(graph: Graph, node: Node) -> ClipOp:
+    if not node.inputs or len(node.outputs) != 1:
+        raise UnsupportedOpError("Clip must have 1 output")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Clip input must be provided")
+    min_name = optional_name(node.inputs, 1)
+    max_name = optional_name(node.inputs, 2)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Clip expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if min_name is not None:
+        min_dtype = value_dtype(graph, min_name, node)
+        if min_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Clip min dtype must match input dtype, "
+                f"got {min_dtype.onnx_name}"
+            )
+    if max_name is not None:
+        max_dtype = value_dtype(graph, max_name, node)
+        if max_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Clip max dtype must match input dtype, "
+                f"got {max_dtype.onnx_name}"
+            )
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise UnsupportedOpError("Clip input and output shapes must match")
+    min_shape = value_shape(graph, min_name, node) if min_name else None
+    max_shape = value_shape(graph, max_name, node) if max_name else None
+    return ClipOp(
+        input0=input_name,
+        input_min=min_name,
+        input_max=max_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        min_shape=min_shape,
+        max_shape=max_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+    )
+
+
+@register_lowering("Celu")
+def lower_celu(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Celu must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("Celu only supports floating-point inputs")
+    alpha = float(node.attrs.get("alpha", 1.0))
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.CELU,
+        shape=output_shape,
+        dtype=dtype,
+        input_dtype=dtype,
+        params=(alpha,),
+    )
+
+
+@register_lowering("Swish")
+def lower_swish(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Swish must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("Swish only supports floating-point inputs")
+    alpha = float(node.attrs.get("alpha", 1.0))
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.SWISH,
+        shape=output_shape,
+        dtype=dtype,
+        input_dtype=dtype,
+        params=(alpha,),
+    )
+
+
+@register_lowering("Shrink")
+def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Shrink must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("Shrink only supports floating-point inputs")
+    bias = float(node.attrs.get("bias", 0.0))
+    lambd = float(node.attrs.get("lambd", 0.5))
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.SHRINK,
+        shape=output_shape,
+        dtype=dtype,
+        input_dtype=dtype,
+        params=(bias, lambd),
+    )
+
+
+@register_lowering("IsInf")
+def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("IsInf must have 1 input and 1 output")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError("IsInf only supports floating-point inputs")
+    if output_dtype != ScalarType.BOOL:
+        raise UnsupportedOpError("IsInf output must be bool")
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.ISINF,
+        shape=output_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+        params=(),
+    )
+
+
+@register_lowering("IsNaN")
+def lower_isnan(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("IsNaN must have 1 input and 1 output")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError("IsNaN only supports floating-point inputs")
+    if output_dtype != ScalarType.BOOL:
+        raise UnsupportedOpError("IsNaN output must be bool")
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.ISNAN,
+        shape=output_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+        params=(),
+    )
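
The UnaryOp lowerings above only reference ScalarFunction members; the package's own formulas live in shared/scalar_functions.py, which this hunk does not show. As a reference, the standard scalar definitions are sketched below: Celu and Shrink per the ONNX operator specs, Swish in the common x * sigmoid(alpha * x) form. These are illustrative and may differ from the package's exact implementation.

# Reference scalar definitions (illustrative, not the package's code).
import math

def celu(x: float, alpha: float = 1.0) -> float:
    # max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    return max(0.0, x) + min(0.0, alpha * (math.exp(x / alpha) - 1.0))

def swish(x: float, alpha: float = 1.0) -> float:
    # x * sigmoid(alpha * x)
    return x / (1.0 + math.exp(-alpha * x))

def shrink(x: float, bias: float = 0.0, lambd: float = 0.5) -> float:
    # Soft-shrink with the ONNX bias/lambd attributes.
    if x < -lambd:
        return x + bias
    if x > lambd:
        return x - bias
    return 0.0

print(celu(-1.0), swish(1.0), shrink(0.4), shrink(1.0))
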