emx-onnx-cgen 0.2.0 (emx_onnx_cgen-0.2.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/mean_variance_normalization.py

@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import MeanVarianceNormalizationOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import ensure_output_shape_matches_input
+from .common import node_dtype, shape_product, value_shape
+from .reduce import normalize_reduce_axes
+from .registry import register_lowering
+
+
+@register_lowering("MeanVarianceNormalization")
+def lower_mean_variance_normalization(
+    graph: Graph, node: Node
+) -> MeanVarianceNormalizationOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "MeanVarianceNormalization must have 1 input and 1 output"
+        )
+    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "MeanVarianceNormalization supports float16, float, and double inputs only"
+        )
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    axes_attr = node.attrs.get("axes")
+    if axes_attr is None:
+        axes = (0, 2, 3)
+    else:
+        axes = tuple(int(axis) for axis in axes_attr)
+    axes = normalize_reduce_axes(axes, input_shape, node)
+    if not axes:
+        raise UnsupportedOpError(
+            "MeanVarianceNormalization requires non-empty reduction axes"
+        )
+    non_axes = tuple(i for i in range(len(input_shape)) if i not in axes)
+    reduce_count = shape_product(tuple(input_shape[axis] for axis in axes))
+    return MeanVarianceNormalizationOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        shape=input_shape,
+        axes=axes,
+        non_axes=non_axes,
+        reduce_count=reduce_count,
+        epsilon=1e-9,
+        dtype=op_dtype,
+    )
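Note: the computation this lowering feeds into the C emitter can be sketched in NumPy. The helper below is illustrative only (it is not part of the package); the default axes (0, 2, 3) and the epsilon of 1e-9 mirror the values fixed in the lowering above:

import numpy as np

def mvn_reference(x: np.ndarray, axes=(0, 2, 3), epsilon=1e-9) -> np.ndarray:
    # Zero-mean, unit-variance normalization over the reduction axes;
    # keepdims lets the statistics broadcast back over the input.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)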
emx_onnx_cgen/lowering/negative_log_likelihood_loss.py

@@ -0,0 +1,250 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import NegativeLogLikelihoodLossOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .common import shape_product as _shape_product
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _find_node_by_output(graph: Graph, name: str) -> Node | None:
+    for node in graph.nodes:
+        if name in node.outputs:
+            return node
+    return None
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _resolve_target_shape(
+    input_shape: tuple[int, ...],
+    shape_values: list[int],
+    *,
+    allowzero: int,
+    node: Node,
+) -> tuple[int, ...]:
+    if allowzero not in (0, 1):
+        raise UnsupportedOpError("Reshape allowzero must be 0 or 1")
+    output_dims: list[int] = []
+    unknown_index: int | None = None
+    known_product = 1
+    for index, dim in enumerate(shape_values):
+        if dim == -1:
+            if unknown_index is not None:
+                raise ShapeInferenceError("Reshape allows only one -1 dimension")
+            unknown_index = index
+            output_dims.append(-1)
+            continue
+        if dim == 0:
+            if allowzero == 0:
+                if index >= len(input_shape):
+                    raise ShapeInferenceError(
+                        "Reshape zero dim must index into input shape"
+                    )
+                dim = input_shape[index]
+        if dim < 0:
+            raise ShapeInferenceError("Reshape dims must be >= -1")
+        output_dims.append(dim)
+        known_product *= dim
+    input_product = _shape_product(input_shape)
+    if unknown_index is not None:
+        if known_product == 0 or input_product % known_product != 0:
+            raise ShapeInferenceError(
+                "Reshape cannot infer dimension from input shape"
+            )
+        output_dims[unknown_index] = input_product // known_product
+    output_shape = tuple(output_dims)
+    if _shape_product(output_shape) != input_product:
+        raise ShapeInferenceError(
+            "Reshape input and output element counts must match"
+        )
+    return output_shape
+
+
+def _shape_values_from_shape_node(
+    graph: Graph, name: str, node: Node
+) -> list[int] | None:
+    shape_node = _find_node_by_output(graph, name)
+    if shape_node is None or shape_node.op_type != "Shape":
+        return None
+    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    source_shape = _value_shape(graph, shape_node.inputs[0], node)
+    return list(source_shape)
+
+
+def _resolve_shape_from_reshape(
+    graph: Graph, name: str, node: Node
+) -> tuple[int, ...] | None:
+    reshape_node = _find_node_by_output(graph, name)
+    if reshape_node is None or reshape_node.op_type != "Reshape":
+        return None
+    if len(reshape_node.inputs) != 2 or len(reshape_node.outputs) != 1:
+        raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
+    input_shape = _value_shape(graph, reshape_node.inputs[0], node)
+    if not input_shape:
+        return None
+    allowzero = int(reshape_node.attrs.get("allowzero", 0))
+    shape_initializer = _find_initializer(graph, reshape_node.inputs[1])
+    if shape_initializer is not None:
+        if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                "Reshape expects int64 or int32 shape input, "
+                f"got {shape_initializer.type.dtype.onnx_name}"
+            )
+        if len(shape_initializer.type.shape) != 1:
+            raise UnsupportedOpError("Reshape expects a 1D shape input")
+        shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
+        return _resolve_target_shape(
+            input_shape,
+            shape_values,
+            allowzero=allowzero,
+            node=node,
+        )
+    shape_values = _shape_values_from_shape_node(
+        graph, reshape_node.inputs[1], node
+    )
+    if shape_values is None:
+        return None
+    return _resolve_target_shape(
+        input_shape,
+        shape_values,
+        allowzero=allowzero,
+        node=node,
+    )
+
+
+def _resolve_input_shape(
+    graph: Graph,
+    input_name: str,
+    target_shape: tuple[int, ...],
+    weight_name: str | None,
+    node: Node,
+) -> tuple[int, ...]:
+    input_shape = _value_shape(graph, input_name, node)
+    if input_shape:
+        return input_shape
+    reshaped = _resolve_shape_from_reshape(graph, input_name, node)
+    if reshaped is not None:
+        return reshaped
+    if weight_name is not None and target_shape:
+        weight_shape = _value_shape(graph, weight_name, node)
+        if len(weight_shape) != 1:
+            return input_shape
+        return (target_shape[0], weight_shape[0], *target_shape[1:])
+    return input_shape
+
+
+@register_lowering("NegativeLogLikelihoodLoss")
+def lower_negative_log_likelihood_loss(
+    graph: Graph, node: Node
+) -> NegativeLogLikelihoodLossOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss must have 2 or 3 inputs and 1 output"
+        )
+    input_name = node.inputs[0]
+    target_name = node.inputs[1]
+    weight_name = node.inputs[2] if len(node.inputs) > 2 else None
+    input_dtype = _value_dtype(graph, input_name, node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss supports float16, float, and double inputs only"
+        )
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if output_dtype != input_dtype:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss output dtype must match input dtype"
+        )
+    target_dtype = _value_dtype(graph, target_name, node)
+    if target_dtype not in {ScalarType.I32, ScalarType.I64}:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss target must be int32 or int64"
+        )
+    weight_dtype = None
+    weight_shape: tuple[int, ...] | None = None
+    if weight_name is not None:
+        weight_dtype = _value_dtype(graph, weight_name, node)
+        if weight_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "NegativeLogLikelihoodLoss weight dtype must match input dtype"
+            )
+    target_shape = _value_shape(graph, target_name, node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    input_shape = _resolve_input_shape(
+        graph, input_name, target_shape, weight_name, node
+    )
+    if len(input_shape) < 2:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss input must be at least 2D"
+        )
+    if len(target_shape) != len(input_shape) - 1:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss target rank must be input rank - 1"
+        )
+    if input_shape[0] != target_shape[0]:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss target batch dimension must match input"
+        )
+    if input_shape[2:] != target_shape[1:]:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss target spatial dimensions must match input"
+        )
+    if weight_name is not None:
+        weight_shape = _value_shape(graph, weight_name, node)
+        if len(weight_shape) != 1 or weight_shape[0] != input_shape[1]:
+            raise ShapeInferenceError(
+                "NegativeLogLikelihoodLoss weight must have shape (C,)"
+            )
+    reduction = node.attrs.get("reduction", "mean")
+    if isinstance(reduction, bytes):
+        reduction = reduction.decode("utf-8")
+    if reduction not in {"none", "mean", "sum"}:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss reduction must be none, mean, or sum"
+        )
+    if reduction == "none":
+        if not output_shape:
+            output_shape = target_shape
+        if output_shape != target_shape:
+            raise ShapeInferenceError(
+                "NegativeLogLikelihoodLoss output must match target shape "
+                "when reduction is none"
+            )
+    else:
+        if output_shape and output_shape not in {(), (1,)}:
+            raise ShapeInferenceError(
+                "NegativeLogLikelihoodLoss output must be scalar when reduced"
+            )
+    n = input_shape[0]
+    c = input_shape[1]
+    d = _shape_product(input_shape[2:]) if len(input_shape) > 2 else 1
+    ignore_index = int(node.attrs.get("ignore_index", -1))
+    return NegativeLogLikelihoodLossOp(
+        input0=input_name,
+        target=target_name,
+        weight=weight_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        target_shape=target_shape,
+        output_shape=output_shape,
+        n=n,
+        c=c,
+        d=d,
+        reduction=reduction,
+        ignore_index=ignore_index,
+        input_dtype=input_dtype,
+        weight_dtype=weight_dtype,
+        weight_shape=weight_shape,
+        dtype=input_dtype,
+        target_dtype=target_dtype,
+    )
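Note: two things stand out in this file. _resolve_target_shape re-implements ONNX Reshape semantics to recover a static input shape (a 0 copies the input dimension unless allowzero=1, and a single -1 is inferred from the element count), and the n/c/d decomposition at the end flattens all trailing spatial dimensions into one. A NumPy sketch of the operator semantics the emitted code has to match, per the ONNX spec (the function is illustrative, not part of the package; it assumes at least one non-ignored element when reduction is "mean"):

import numpy as np

def nll_loss_reference(x, target, weight=None, reduction="mean", ignore_index=-1):
    # x: (N, C, d1, ..., dk) log-probabilities; target: (N, d1, ..., dk) class ids.
    n, c = x.shape[0], x.shape[1]
    d = int(np.prod(x.shape[2:])) if x.ndim > 2 else 1
    x2 = x.reshape(n, c, d)
    t2 = target.reshape(n, d)
    loss = np.zeros((n, d), dtype=x.dtype)
    w = np.zeros((n, d), dtype=x.dtype)
    for i in range(n):
        for j in range(d):
            cls = int(t2[i, j])
            if cls == ignore_index:
                continue  # ignored elements contribute zero loss and zero weight
            wc = weight[cls] if weight is not None else 1.0
            loss[i, j] = -x2[i, cls, j] * wc
            w[i, j] = wc
    if reduction == "none":
        return loss.reshape(target.shape)
    if reduction == "sum":
        return loss.sum()
    return loss.sum() / w.sum()  # "mean" divides by the sum of applied weights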
emx_onnx_cgen/lowering/pad.py

@@ -0,0 +1,287 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import PadOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import optional_name, value_dtype, value_shape
+from ..validation import normalize_axis
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_int_initializer(
+    graph: Graph,
+    name: str,
+    node: Node,
+    *,
+    label: str,
+) -> tuple[int, ...] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"Pad {label} input must be int64 or int32"
+        )
+    if len(initializer.type.shape) != 1:
+        raise UnsupportedOpError(f"Pad {label} input must be a 1D tensor")
+    values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    return tuple(int(value) for value in values)
+
+
+def _read_scalar_initializer(
+    graph: Graph, name: str, node: Node, *, dtype: ScalarType
+) -> float | int | bool | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype != dtype:
+        raise UnsupportedOpError(
+            "Pad value input must match input dtype, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    values = np.array(initializer.data).reshape(-1)
+    if values.size != 1:
+        raise UnsupportedOpError("Pad value input must be a scalar")
+    return values.item()
+
+
+def _normalize_axes(
+    axes: tuple[int, ...], input_shape: tuple[int, ...], node: Node
+) -> tuple[int, ...]:
+    normalized = [normalize_axis(axis, input_shape, node) for axis in axes]
+    if len(set(normalized)) != len(normalized):
+        raise UnsupportedOpError("Pad axes must be unique")
+    return tuple(normalized)
+
+
+def _default_pad_value(dtype: ScalarType) -> float | int | bool:
+    if dtype.is_bool:
+        return False
+    if dtype.is_float:
+        return 0.0
+    return 0
+
+
+def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
+    strides: list[int] = []
+    stride = 1
+    for dim in reversed(shape):
+        strides.append(stride)
+        stride *= dim
+    return tuple(reversed(strides))
+
+
+@register_lowering("Pad")
+def lower_pad(graph: Graph, node: Node) -> PadOp:
+    if not node.inputs or len(node.outputs) != 1:
+        raise UnsupportedOpError("Pad must have 1 output")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Pad input must be provided")
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Pad expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    mode = node.attrs.get("mode", "constant")
+    if isinstance(mode, bytes):
+        mode = mode.decode("utf-8")
+    if mode not in {"constant", "edge", "reflect", "wrap"}:
+        raise UnsupportedOpError(f"Pad mode '{mode}' is not supported")
+    pads_name = optional_name(node.inputs, 1)
+    pads_attr = node.attrs.get("pads")
+    if pads_name and pads_attr:
+        raise UnsupportedOpError("Pad pads must be provided via input or attribute")
+    pads = None
+    pads_input = None
+    pads_shape = None
+    pads_dtype = None
+    if pads_name:
+        pads = _read_int_initializer(graph, pads_name, node, label="pads")
+        if pads is None:
+            pads_shape = value_shape(graph, pads_name, node)
+            pads_dtype = value_dtype(graph, pads_name, node)
+            if pads_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Pad pads input must be int64 or int32"
+                )
+            if len(pads_shape) != 1:
+                raise UnsupportedOpError("Pad pads input must be a 1D tensor")
+            pads_input = pads_name
+    elif pads_attr is not None:
+        pads = tuple(int(value) for value in pads_attr)
+    if pads is None and pads_input is None:
+        pads = tuple(0 for _ in range(2 * len(input_shape)))
+
+    axes_name = optional_name(node.inputs, 3)
+    axes = None
+    axes_input = None
+    axes_shape = None
+    axes_dtype = None
+    if axes_name:
+        axes = _read_int_initializer(graph, axes_name, node, label="axes")
+        if axes is None:
+            axes_shape = value_shape(graph, axes_name, node)
+            axes_dtype = value_dtype(graph, axes_name, node)
+            if axes_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Pad axes input must be int64 or int32"
+                )
+            if len(axes_shape) != 1:
+                raise UnsupportedOpError("Pad axes input must be a 1D tensor")
+            if axes_shape[0] < 0:
+                raise ShapeInferenceError(
+                    "Pad axes input must have a static length"
+                )
+            axes_input = axes_name
+        else:
+            axes = _normalize_axes(axes, input_shape, node)
+
+    pads_axis_map = None
+    pads_values = None
+    pads_begin = None
+    pads_end = None
+
+    if axes_input is None and axes is None:
+        if pads is None:
+            if pads_shape is None or pads_shape[0] != 2 * len(input_shape):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * rank of input"
+                )
+            pads_begin = None
+            pads_end = None
+        else:
+            if len(pads) != 2 * len(input_shape):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * rank of input"
+                )
+            pads_begin = list(pads[: len(input_shape)])
+            pads_end = list(pads[len(input_shape) :])
+    elif axes_input is None:
+        if pads_input is not None:
+            if pads_shape is None or pads_shape[0] != 2 * len(axes):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+            pads_axis_map = [None] * len(input_shape)
+            for index, axis in enumerate(axes):
+                pads_axis_map[axis] = index
+        else:
+            if len(pads) != 2 * len(axes):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+            pads_begin = [0] * len(input_shape)
+            pads_end = [0] * len(input_shape)
+            for index, axis in enumerate(axes):
+                pads_begin[axis] = pads[index]
+                pads_end[axis] = pads[index + len(axes)]
+    else:
+        axes_len = axes_shape[0] if axes_shape is not None else 0
+        if pads_input is not None:
+            if pads_shape is None or pads_shape[0] != 2 * axes_len:
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+        else:
+            if len(pads) != 2 * axes_len:
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+            pads_values = pads
+
+    if pads_begin is not None and pads_end is not None:
+        if any(value < 0 for value in pads_begin + pads_end):
+            raise UnsupportedOpError("Pad pads must be non-negative")
+
+        expected_shape = tuple(
+            dim + pad_before + pad_after
+            for dim, pad_before, pad_after in zip(
+                input_shape, pads_begin, pads_end
+            )
+        )
+        if output_shape != expected_shape:
+            raise ShapeInferenceError(
+                "Pad output shape mismatch: "
+                f"expected {expected_shape}, got {output_shape}"
+            )
+    elif pads_values is not None:
+        if any(value < 0 for value in pads_values):
+            raise UnsupportedOpError("Pad pads must be non-negative")
+
+    value_name = optional_name(node.inputs, 2)
+    pad_value = None
+    value_input = None
+    value_input_shape = None
+    if value_name:
+        pad_value = _read_scalar_initializer(
+            graph, value_name, node, dtype=input_dtype
+        )
+        if pad_value is None:
+            value_input_shape = value_shape(graph, value_name, node)
+            input_value_dtype = value_dtype(graph, value_name, node)
+            if input_value_dtype != input_dtype:
+                raise UnsupportedOpError(
+                    "Pad value input must match input dtype, "
+                    f"got {input_value_dtype.onnx_name}"
+                )
+            if value_input_shape:
+                raise UnsupportedOpError("Pad value input must be a scalar")
+            value_input = value_name
+    elif "value" in node.attrs:
+        pad_value = node.attrs["value"]
+    if pad_value is None and value_input is None:
+        pad_value = _default_pad_value(input_dtype)
+
+    return PadOp(
+        input0=input_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        pads_begin=(
+            tuple(int(value) for value in pads_begin)
+            if pads_begin is not None
+            else None
+        ),
+        pads_end=(
+            tuple(int(value) for value in pads_end)
+            if pads_end is not None
+            else None
+        ),
+        pads_input=pads_input,
+        pads_shape=pads_shape,
+        pads_dtype=pads_dtype,
+        pads_axis_map=(
+            tuple(pads_axis_map) if pads_axis_map is not None else None
+        ),
+        pads_values=(
+            tuple(int(value) for value in pads_values)
+            if pads_values is not None
+            else None
+        ),
+        axes_input=axes_input,
+        axes_shape=axes_shape,
+        axes_dtype=axes_dtype,
+        mode=mode,
+        value=pad_value,
+        value_input=value_input,
+        value_shape=value_input_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+        input_strides=_compute_strides(input_shape),
+    )
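Note: the pads layout validated above is the ONNX convention, a flat vector (x1_begin, ..., xn_begin, x1_end, ..., xn_end), optionally addressing only a subset of axes when the axes input is present. A small NumPy illustration of the begin/end split and the output-shape rule dim + pad_before + pad_after (illustrative, not package code):

import numpy as np

# ONNX Pad stores pads as (x1_begin, ..., xn_begin, x1_end, ..., xn_end);
# the lowering splits this into per-axis pads_begin / pads_end.
x = np.arange(6, dtype=np.float32).reshape(2, 3)
pads = (0, 1, 0, 2)          # rank 2 -> begin = (0, 1), end = (0, 2)
begin, end = pads[:x.ndim], pads[x.ndim:]
y = np.pad(x, list(zip(begin, end)), mode="constant", constant_values=0.0)
assert y.shape == (2, 6)     # dim + pad_before + pad_after, as checked above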
emx_onnx_cgen/lowering/range.py

@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+import math
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import RangeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import node_dtype, value_shape
+from .registry import register_lowering
+
+
+_SUPPORTED_RANGE_DTYPES = {
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.I16,
+    ScalarType.I32,
+    ScalarType.I64,
+}
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_scalar_initializer(
+    graph: Graph, name: str, node: Node, label: str
+) -> float | int | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    data = np.array(initializer.data)
+    if data.size != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be a scalar"
+        )
+    return data.reshape(-1)[0].item()
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+@register_lowering("Range")
+def lower_range(graph: Graph, node: Node) -> RangeOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Range must have 3 inputs and 1 output")
+    start_shape = value_shape(graph, node.inputs[0], node)
+    limit_shape = value_shape(graph, node.inputs[1], node)
+    delta_shape = value_shape(graph, node.inputs[2], node)
+    if not (
+        _is_scalar_shape(start_shape)
+        and _is_scalar_shape(limit_shape)
+        and _is_scalar_shape(delta_shape)
+    ):
+        raise UnsupportedOpError("Range inputs must be scalars")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if dtype not in _SUPPORTED_RANGE_DTYPES:
+        raise UnsupportedOpError(
+            f"Range does not support dtype {dtype.onnx_name}"
+        )
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(output_shape) != 1:
+        raise ShapeInferenceError("Range output must be 1D")
+    start_value = _read_scalar_initializer(graph, node.inputs[0], node, "start")
+    limit_value = _read_scalar_initializer(graph, node.inputs[1], node, "limit")
+    delta_value = _read_scalar_initializer(graph, node.inputs[2], node, "delta")
+    if (
+        start_value is not None
+        and limit_value is not None
+        and delta_value is not None
+    ):
+        if float(delta_value) == 0.0:
+            raise UnsupportedOpError("Range delta must be non-zero")
+        raw_count = (
+            float(limit_value) - float(start_value)
+        ) / float(delta_value)
+        length = max(int(math.ceil(raw_count)), 0)
+        if length < 0:
+            raise ShapeInferenceError("Range output length must be non-negative")
+        if output_shape[0] != length:
+            raise ShapeInferenceError(
+                f"Range output length must be {length}, got {output_shape[0]}"
+            )
+    else:
+        length = output_shape[0]
+        if length < 0:
+            raise ShapeInferenceError("Range output length must be non-negative")
+    return RangeOp(
+        start=node.inputs[0],
+        limit=node.inputs[1],
+        delta=node.inputs[2],
+        output=node.outputs[0],
+        output_shape=output_shape,
+        length=length,
+        dtype=dtype,
+        input_dtype=dtype,
+    )
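Note: the static length computed above follows the ONNX Range formula, max(ceil((limit - start) / delta), 0), which also makes the subsequent length < 0 check redundant but harmless. A minimal sketch of the same computation (illustrative, not package code):

import math

def range_length(start: float, limit: float, delta: float) -> int:
    # max(ceil((limit - start) / delta), 0), as in the lowering above.
    if delta == 0:
        raise ValueError("Range delta must be non-zero")
    return max(int(math.ceil((limit - start) / delta)), 0)

assert range_length(0, 5, 1) == 5      # [0, 1, 2, 3, 4]
assert range_length(10, 4, -2) == 3    # [10, 8, 6]
assert range_length(3, 3, 1) == 0      # empty output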