emx_onnx_cgen-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/average_pool.py
@@ -0,0 +1,229 @@
from __future__ import annotations

from dataclasses import dataclass

from ..codegen.c_emitter import AveragePoolOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .registry import register_lowering


@dataclass(frozen=True)
class _AveragePoolSpec:
    batch: int
    channels: int
    in_h: int
    in_w: int
    out_h: int
    out_w: int
    kernel_h: int
    kernel_w: int
    stride_h: int
    stride_w: int
    pad_top: int
    pad_left: int
    pad_bottom: int
    pad_right: int
    count_include_pad: bool


def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
    try:
        return graph.find_value(name).type.shape
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _value_dtype(graph: Graph, name: str, node: Node) -> str:
    try:
        return graph.find_value(name).type.dtype
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("AveragePool must have 1 input and 1 output")
    supported_attrs = {
        "auto_pad",
        "ceil_mode",
        "count_include_pad",
        "kernel_shape",
        "pads",
        "strides",
    }
    if set(node.attrs) - supported_attrs:
        raise UnsupportedOpError("AveragePool has unsupported attributes")
    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
    if isinstance(auto_pad, bytes):
        auto_pad = auto_pad.decode("utf-8", errors="ignore")
    if auto_pad not in ("", "NOTSET"):
        raise UnsupportedOpError("AveragePool supports auto_pad=NOTSET only")
    ceil_mode = int(node.attrs.get("ceil_mode", 0))
    if ceil_mode != 0:
        raise UnsupportedOpError("AveragePool supports ceil_mode=0 only")
    count_include_pad = int(node.attrs.get("count_include_pad", 0))
    if count_include_pad not in (0, 1):
        raise UnsupportedOpError("AveragePool supports count_include_pad 0 or 1")
    kernel_shape = node.attrs.get("kernel_shape")
    if kernel_shape is None:
        raise UnsupportedOpError("AveragePool requires kernel_shape")
    kernel_shape = tuple(int(value) for value in kernel_shape)
    if len(kernel_shape) != 2:
        raise UnsupportedOpError("AveragePool expects 2D kernel_shape")
    kernel_h, kernel_w = kernel_shape
    strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
    if len(strides) != 2:
        raise UnsupportedOpError("AveragePool expects 2D strides")
    pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
    if len(pads) != 4:
        raise UnsupportedOpError("AveragePool expects 4D pads")
    pad_top, pad_left, pad_bottom, pad_right = pads
    input_shape = _value_shape(graph, node.inputs[0], node)
    if len(input_shape) != 4:
        raise UnsupportedOpError("AveragePool supports NCHW 2D inputs only")
    batch, channels, in_h, in_w = input_shape
    stride_h, stride_w = strides
    out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
    out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
    if out_h < 0 or out_w < 0:
        raise ShapeInferenceError(
            "AveragePool output shape must be non-negative"
        )
    output_shape = _value_shape(graph, node.outputs[0], node)
    expected_output_shape = (batch, channels, out_h, out_w)
    if output_shape != expected_output_shape:
        raise ShapeInferenceError(
            "AveragePool output shape must be "
            f"{expected_output_shape}, got {output_shape}"
        )
    return _AveragePoolSpec(
        batch=batch,
        channels=channels,
        in_h=in_h,
        in_w=in_w,
        out_h=out_h,
        out_w=out_w,
        kernel_h=kernel_h,
        kernel_w=kernel_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_top=pad_top,
        pad_left=pad_left,
        pad_bottom=pad_bottom,
        pad_right=pad_right,
        count_include_pad=bool(count_include_pad),
    )


def _resolve_global_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("GlobalAveragePool must have 1 input and 1 output")
    if node.attrs:
        raise UnsupportedOpError("GlobalAveragePool has unsupported attributes")
    input_shape = _value_shape(graph, node.inputs[0], node)
    if len(input_shape) != 4:
        raise UnsupportedOpError("GlobalAveragePool supports NCHW 2D inputs only")
    batch, channels, in_h, in_w = input_shape
    output_shape = _value_shape(graph, node.outputs[0], node)
    expected_output_shape = (batch, channels, 1, 1)
    if output_shape != expected_output_shape:
        raise ShapeInferenceError(
            "GlobalAveragePool output shape must be "
            f"{expected_output_shape}, got {output_shape}"
        )
    return _AveragePoolSpec(
        batch=batch,
        channels=channels,
        in_h=in_h,
        in_w=in_w,
        out_h=1,
        out_w=1,
        kernel_h=in_h,
        kernel_w=in_w,
        stride_h=1,
        stride_w=1,
        pad_top=0,
        pad_left=0,
        pad_bottom=0,
        pad_right=0,
        count_include_pad=False,
    )


@register_lowering("AveragePool")
def lower_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
    op_dtype = _value_dtype(graph, node.inputs[0], node)
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            "AveragePool expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "AveragePool supports float16, float, and double inputs only"
        )
    spec = _resolve_average_pool_spec(graph, node)
    return AveragePoolOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        batch=spec.batch,
        channels=spec.channels,
        in_h=spec.in_h,
        in_w=spec.in_w,
        out_h=spec.out_h,
        out_w=spec.out_w,
        kernel_h=spec.kernel_h,
        kernel_w=spec.kernel_w,
        stride_h=spec.stride_h,
        stride_w=spec.stride_w,
        pad_top=spec.pad_top,
        pad_left=spec.pad_left,
        pad_bottom=spec.pad_bottom,
        pad_right=spec.pad_right,
        count_include_pad=spec.count_include_pad,
        dtype=op_dtype,
    )


@register_lowering("GlobalAveragePool")
def lower_global_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
    op_dtype = _value_dtype(graph, node.inputs[0], node)
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            "GlobalAveragePool expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "GlobalAveragePool supports float16, float, and double inputs only"
        )
    spec = _resolve_global_average_pool_spec(graph, node)
    return AveragePoolOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        batch=spec.batch,
        channels=spec.channels,
        in_h=spec.in_h,
        in_w=spec.in_w,
        out_h=spec.out_h,
        out_w=spec.out_w,
        kernel_h=spec.kernel_h,
        kernel_w=spec.kernel_w,
        stride_h=spec.stride_h,
        stride_w=spec.stride_w,
        pad_top=spec.pad_top,
        pad_left=spec.pad_left,
        pad_bottom=spec.pad_bottom,
        pad_right=spec.pad_right,
        count_include_pad=spec.count_include_pad,
        dtype=op_dtype,
    )
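The lowering above rejects ceil_mode, so output sizes always come from the floor-division formula in _resolve_average_pool_spec. A minimal standalone check of that arithmetic (plain Python, no package imports; the concrete sizes are illustrative only):

# Floor-mode AveragePool output size, as computed in _resolve_average_pool_spec.
in_h, in_w = 32, 32
kernel_h, kernel_w = 3, 3
stride_h, stride_w = 2, 2
pad_top = pad_bottom = pad_left = pad_right = 1

out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
print(out_h, out_w)  # -> 16 16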
emx_onnx_cgen/lowering/batch_normalization.py
@@ -0,0 +1,116 @@
from __future__ import annotations

from dataclasses import dataclass

from ..codegen.c_emitter import BatchNormOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .registry import register_lowering


@dataclass(frozen=True)
class _BatchNormSpec:
    shape: tuple[int, ...]
    channels: int
    epsilon: float


def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
    try:
        return graph.find_value(name).type.shape
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _value_dtype(graph: Graph, name: str, node: Node) -> str:
    try:
        return graph.find_value(name).type.dtype
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _node_dtype(graph: Graph, node: Node, *names: str) -> str:
    dtypes = {_value_dtype(graph, name, node) for name in names}
    if len(dtypes) != 1:
        raise UnsupportedOpError(
            f"{node.op_type} expects matching dtypes, got {', '.join(sorted(dtypes))}"
        )
    return next(iter(dtypes))


def _resolve_batch_norm_spec(graph: Graph, node: Node) -> _BatchNormSpec:
    if len(node.inputs) != 5 or len(node.outputs) != 1:
        raise UnsupportedOpError(
            "BatchNormalization must have 5 inputs and 1 output"
        )
    supported_attrs = {
        "epsilon",
        "is_test",
        "momentum",
        "spatial",
        "training_mode",
    }
    if set(node.attrs) - supported_attrs:
        raise UnsupportedOpError("BatchNormalization has unsupported attributes")
    is_test = int(node.attrs.get("is_test", 1))
    if is_test != 1:
        raise UnsupportedOpError("BatchNormalization supports is_test=1 only")
    training_mode = int(node.attrs.get("training_mode", 0))
    if training_mode != 0:
        raise UnsupportedOpError("BatchNormalization supports training_mode=0 only")
    spatial = int(node.attrs.get("spatial", 1))
    if spatial != 1:
        raise UnsupportedOpError("BatchNormalization supports spatial=1 only")
    epsilon = float(node.attrs.get("epsilon", 1e-5))
    input_shape = _value_shape(graph, node.inputs[0], node)
    if len(input_shape) < 2:
        raise UnsupportedOpError(
            "BatchNormalization expects input rank of at least 2"
        )
    channels = input_shape[1]
    for name in node.inputs[1:]:
        shape = _value_shape(graph, name, node)
        if shape != (channels,):
            raise ShapeInferenceError(
                "BatchNormalization parameter shape must be "
                f"({channels},), got {shape}"
            )
    output_shape = _value_shape(graph, node.outputs[0], node)
    if output_shape != input_shape:
        raise ShapeInferenceError(
            "BatchNormalization output shape must match input shape, "
            f"got {output_shape}"
        )
    return _BatchNormSpec(
        shape=input_shape,
        channels=channels,
        epsilon=epsilon,
    )


@register_lowering("BatchNormalization")
def lower_batch_normalization(graph: Graph, node: Node) -> BatchNormOp:
    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "BatchNormalization supports float16, float, and double inputs only"
        )
    spec = _resolve_batch_norm_spec(graph, node)
    return BatchNormOp(
        input0=node.inputs[0],
        scale=node.inputs[1],
        bias=node.inputs[2],
        mean=node.inputs[3],
        variance=node.inputs[4],
        output=node.outputs[0],
        shape=spec.shape,
        channels=spec.channels,
        epsilon=spec.epsilon,
        dtype=op_dtype,
    )
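This lowering only validates attributes and shapes; the arithmetic itself is emitted by BatchNormOp in codegen/c_emitter.py, which is not part of this excerpt. For reference, a NumPy sketch of the inference-mode formula the ONNX BatchNormalization spec defines, y = scale * (x - mean) / sqrt(var + epsilon) + bias with per-channel broadcasting (the helper name batch_norm_inference is illustrative, not a package symbol):

import numpy as np

def batch_norm_inference(x, scale, bias, mean, var, epsilon=1e-5):
    # Broadcast the per-channel parameters over an N,C,... tensor.
    shape = (1, -1) + (1,) * (x.ndim - 2)
    x_hat = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + epsilon)
    return scale.reshape(shape) * x_hat + bias.reshape(shape)

x = np.random.randn(1, 3, 4, 4).astype(np.float32)
scale = np.ones(3, dtype=np.float32)
bias = np.zeros(3, dtype=np.float32)
mean = x.mean(axis=(0, 2, 3))
var = x.var(axis=(0, 2, 3))
y = batch_norm_inference(x, scale, bias, mean, var)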
emx_onnx_cgen/lowering/cast.py
@@ -0,0 +1,70 @@
from __future__ import annotations

import onnx

from ..codegen.c_emitter import CastOp
from ..dtypes import scalar_type_from_onnx
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .common import ensure_supported_dtype, value_dtype, value_shape
from .registry import register_lowering


@register_lowering("Cast")
def lower_cast(graph: Graph, node: Node) -> CastOp:
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Cast must have 1 input and 1 output")
    if "to" not in node.attrs:
        raise UnsupportedOpError("Cast requires a 'to' attribute")
    target_onnx_dtype = int(node.attrs["to"])
    target_dtype = scalar_type_from_onnx(target_onnx_dtype)
    if target_dtype is None:
        name = onnx.TensorProto.DataType.Name(target_onnx_dtype)
        raise UnsupportedOpError(
            f"Cast 'to' dtype {target_onnx_dtype} ({name}) is not supported"
        )
    target_dtype = ensure_supported_dtype(target_dtype)
    input_dtype = value_dtype(graph, node.inputs[0], node)
    output_dtype = value_dtype(graph, node.outputs[0], node)
    if output_dtype != target_dtype:
        raise UnsupportedOpError(
            "Cast output dtype must match 'to' attribute, "
            f"got {output_dtype.onnx_name} and {target_dtype.onnx_name}"
        )
    input_shape = value_shape(graph, node.inputs[0], node)
    output_shape = value_shape(graph, node.outputs[0], node)
    if input_shape != output_shape:
        raise ShapeInferenceError("Cast input and output shapes must match")
    return CastOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        shape=output_shape,
        input_dtype=input_dtype,
        dtype=output_dtype,
    )


@register_lowering("CastLike")
def lower_castlike(graph: Graph, node: Node) -> CastOp:
    if len(node.inputs) != 2 or len(node.outputs) != 1:
        raise UnsupportedOpError("CastLike must have 2 inputs and 1 output")
    input_dtype = value_dtype(graph, node.inputs[0], node)
    like_dtype = value_dtype(graph, node.inputs[1], node)
    target_dtype = ensure_supported_dtype(like_dtype)
    output_dtype = value_dtype(graph, node.outputs[0], node)
    if output_dtype != target_dtype:
        raise UnsupportedOpError(
            "CastLike output dtype must match like input dtype, "
            f"got {output_dtype.onnx_name} and {target_dtype.onnx_name}"
        )
    input_shape = value_shape(graph, node.inputs[0], node)
    output_shape = value_shape(graph, node.outputs[0], node)
    if input_shape != output_shape:
        raise ShapeInferenceError("CastLike input and output shapes must match")
    return CastOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        shape=output_shape,
        input_dtype=input_dtype,
        dtype=output_dtype,
    )
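The unsupported-dtype error path in lower_cast turns the integer 'to' attribute into a readable name with onnx.TensorProto.DataType.Name, which is the standard onnx protobuf enum helper. A quick standalone check of that mapping (the FLOAT16 choice is only an example; which dtypes scalar_type_from_onnx accepts is package-specific and not shown here):

import onnx

to = int(onnx.TensorProto.FLOAT16)
print(onnx.TensorProto.DataType.Name(to))  # prints "FLOAT16"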
emx_onnx_cgen/lowering/common.py
@@ -0,0 +1,72 @@
from __future__ import annotations

from collections.abc import Sequence

from shared.scalar_types import ScalarType

from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node


def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
    if not isinstance(dtype, ScalarType):
        raise UnsupportedOpError(f"Unsupported dtype {dtype}")
    return dtype


def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType:
    try:
        value = graph.find_value(name)
    except KeyError as exc:
        op_type = node.op_type if node is not None else "unknown"
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc
    return ensure_supported_dtype(value.type.dtype)


def value_shape(graph: Graph, name: str, node: Node | None = None) -> tuple[int, ...]:
    try:
        return graph.find_value(name).type.shape
    except KeyError as exc:
        op_type = node.op_type if node is not None else "unknown"
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
    filtered = [name for name in names if name]
    if not filtered:
        raise UnsupportedOpError(
            f"{node.op_type} expects at least one typed input or output"
        )
    dtypes = {value_dtype(graph, name, node) for name in filtered}
    if len(dtypes) != 1:
        dtype_names = ", ".join(dtype.onnx_name for dtype in sorted(dtypes, key=str))
        raise UnsupportedOpError(
            f"{node.op_type} expects matching dtypes, got {dtype_names}"
        )
    return next(iter(dtypes))


def shape_product(shape: tuple[int, ...]) -> int:
    if not shape:
        return 1
    product = 1
    for dim in shape:
        if dim < 0:
            raise ShapeInferenceError("Dynamic dims are not supported")
        if dim == 0:
            return 0
        product *= dim
    return product


def optional_name(names: Sequence[str], index: int) -> str | None:
    if index >= len(names):
        return None
    name = names[index]
    return name or None
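shape_product and optional_name encode two small conventions the lowerings rely on: a scalar shape () still has one element while any zero-sized dimension collapses the count to 0, and an empty or out-of-range input name stands for an omitted optional input. A short usage sketch, assuming the wheel is installed under its declared top-level package:

from emx_onnx_cgen.lowering.common import optional_name, shape_product

assert shape_product(()) == 1            # scalar tensor: one element
assert shape_product((2, 3, 4)) == 24
assert shape_product((2, 0, 4)) == 0     # zero-sized dim short-circuits to 0
assert optional_name(["x", ""], 1) is None   # empty name => optional input absent
assert optional_name(["x"], 5) is None       # index past the end => input omitted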
emx_onnx_cgen/lowering/concat.py
@@ -0,0 +1,31 @@
from __future__ import annotations

from ..codegen.c_emitter import ConcatOp
from ..errors import UnsupportedOpError
from ..ir.model import Graph, Node
from .common import node_dtype as _node_dtype
from .common import value_shape as _value_shape
from .registry import register_lowering
from ..validation import validate_concat_shapes


@register_lowering("Concat")
def lower_concat(graph: Graph, node: Node) -> ConcatOp:
    if len(node.inputs) < 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Concat must have at least 1 input and 1 output")
    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
    output_shape = _value_shape(graph, node.outputs[0], node)
    input_shapes = tuple(_value_shape(graph, name, node) for name in node.inputs)
    axis = validate_concat_shapes(
        input_shapes,
        output_shape,
        int(node.attrs.get("axis", 0)),
    )
    return ConcatOp(
        inputs=node.inputs,
        output=node.outputs[0],
        axis=axis,
        input_shapes=input_shapes,
        output_shape=output_shape,
        dtype=op_dtype,
    )
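validate_concat_shapes lives in emx_onnx_cgen/validation.py and is not part of this excerpt, so its exact behavior is not shown; per the ONNX Concat specification it must at least confirm that all non-axis dimensions match and that the output's axis dimension is the sum of the inputs'. A hypothetical standalone checker illustrating that rule (concat_axis_is_consistent is not a package function):

def concat_axis_is_consistent(input_shapes, output_shape, axis):
    # ONNX Concat: same rank everywhere, negative axes allowed,
    # non-axis dims equal, axis dim sums to the output's axis dim.
    rank = len(output_shape)
    axis = axis % rank
    for shape in input_shapes:
        if len(shape) != rank:
            return False
        if any(s != o for i, (s, o) in enumerate(zip(shape, output_shape)) if i != axis):
            return False
    return sum(shape[axis] for shape in input_shapes) == output_shape[axis]

assert concat_axis_is_consistent([(2, 3), (2, 5)], (2, 8), axis=-1)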
emx_onnx_cgen/lowering/constant_of_shape.py
@@ -0,0 +1,85 @@
from __future__ import annotations

from onnx import numpy_helper

from shared.scalar_types import ScalarType

from ..codegen.c_emitter import ConstantOfShapeOp
from ..dtypes import scalar_type_from_onnx
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .registry import register_lowering


def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
    try:
        return graph.find_value(name).type.shape
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
    try:
        return graph.find_value(name).type.dtype
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _parse_value_attr(node: Node) -> tuple[ScalarType, float | int | bool]:
    value_attr = node.attrs.get("value")
    if value_attr is None:
        return ScalarType.F32, 0.0
    dtype = scalar_type_from_onnx(value_attr.data_type)
    if dtype is None:
        raise UnsupportedOpError(
            f"ConstantOfShape has unsupported value dtype {value_attr.data_type}"
        )
    data = numpy_helper.to_array(value_attr)
    if data.size != 1:
        raise UnsupportedOpError("ConstantOfShape value must be a scalar")
    return dtype, data.reshape(-1)[0].item()


@register_lowering("ConstantOfShape")
def lower_constant_of_shape(graph: Graph, node: Node) -> ConstantOfShapeOp:
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("ConstantOfShape must have 1 input and 1 output")
    input_shape = _value_shape(graph, node.inputs[0], node)
    if len(input_shape) != 1:
        raise UnsupportedOpError("ConstantOfShape expects a 1D shape input")
    output_shape = _value_shape(graph, node.outputs[0], node)
    if input_shape[0] != len(output_shape):
        raise ShapeInferenceError(
            "ConstantOfShape input length must match output rank"
        )
    for dim in output_shape:
        if dim < 0:
            raise ShapeInferenceError("Dynamic dims are not supported")
    input_dtype = _value_dtype(graph, node.inputs[0], node)
    if input_dtype != ScalarType.I64:
        raise UnsupportedOpError(
            "ConstantOfShape expects int64 shape input, "
            f"got {input_dtype.onnx_name}"
        )
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    value_dtype, value = _parse_value_attr(node)
    if output_dtype != value_dtype:
        raise UnsupportedOpError(
            "ConstantOfShape output dtype must match value dtype, "
            f"got {output_dtype.onnx_name} and {value_dtype.onnx_name}"
        )
    return ConstantOfShapeOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        input_shape=input_shape,
        shape=output_shape,
        value=value,
        dtype=output_dtype,
        input_dtype=input_dtype,
    )
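_parse_value_attr decodes the optional scalar 'value' attribute with onnx.numpy_helper.to_array. A minimal standalone reproduction of that decoding step using a hand-built tensor attribute (the value 7 and the INT64 choice are arbitrary examples):

from onnx import TensorProto, helper, numpy_helper

value_attr = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[7])
data = numpy_helper.to_array(value_attr)
assert data.size == 1                      # must be a scalar, as the lowering enforces
print(data.reshape(-1)[0].item())          # -> 7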