emx_onnx_cgen-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/shape.py
@@ -0,0 +1,78 @@
from __future__ import annotations

from shared.scalar_types import ScalarType

from ..codegen.c_emitter import ShapeOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .registry import register_lowering


def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
    try:
        return graph.find_value(name).type.shape
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
    try:
        return graph.find_value(name).type.dtype
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _normalize_slice_bounds(
    rank: int, *, start: int | None, end: int | None
) -> tuple[int, int]:
    normalized_start = 0 if start is None else int(start)
    normalized_end = rank if end is None else int(end)
    if normalized_start < 0:
        normalized_start += rank
    if normalized_end < 0:
        normalized_end += rank
    normalized_start = max(0, min(normalized_start, rank))
    normalized_end = max(0, min(normalized_end, rank))
    return normalized_start, normalized_end
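The helper above resolves the optional ONNX Shape start/end attributes: negative indices are offset by the rank, then both bounds are clamped to [0, rank]. A few worked calls (plain Python mirroring the function above, not part of the package):

_normalize_slice_bounds(4, start=-2, end=None)  # -> (2, 4): -2 resolves to 2, end defaults to rank
_normalize_slice_bounds(4, start=0, end=100)    # -> (0, 4): out-of-range bounds clamp rather than raise
_normalize_slice_bounds(4, start=3, end=1)      # -> (3, 1): the caller turns start > end into an empty result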
@register_lowering("Shape")
def lower_shape(graph: Graph, node: Node) -> ShapeOp:
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Shape must have 1 input and 1 output")
    input_shape = _value_shape(graph, node.inputs[0], node)
    output_shape = _value_shape(graph, node.outputs[0], node)
    if len(output_shape) != 1:
        raise ShapeInferenceError("Shape output must be 1D")
    if output_shape[0] < 0:
        raise ShapeInferenceError("Shape output length must be non-negative")
    input_dtype = _value_dtype(graph, node.inputs[0], node)
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    if output_dtype != ScalarType.I64:
        raise UnsupportedOpError("Shape output dtype must be int64")
    start = node.attrs.get("start")
    end = node.attrs.get("end")
    start_index, end_index = _normalize_slice_bounds(
        len(input_shape), start=start, end=end
    )
    expected_shape = (max(0, end_index - start_index),)
    if expected_shape != output_shape:
        raise ShapeInferenceError(
            "Shape output shape must be "
            f"{expected_shape}, got {output_shape}"
        )
    return ShapeOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        input_shape=input_shape,
        output_shape=output_shape,
        values=input_shape[start_index:end_index],
        dtype=output_dtype,
        input_dtype=input_dtype,
    )
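Since every dimension must be static, lower_shape folds the operator into a compile-time constant: the selected dims land in `values` on the emitted ShapeOp. A minimal sketch of a node this path handles, using the standard `onnx.helper` API (tensor names here are hypothetical):

import onnx.helper as oh

# For an input of shape (2, 3, 5, 7), start=1 and end=-1 select dims
# (3, 5), so lowering records values=(3, 5) and expects an int64
# output of shape (2,).
node = oh.make_node("Shape", inputs=["x"], outputs=["x_shape"], start=1, end=-1)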
emx_onnx_cgen/lowering/size.py
@@ -0,0 +1,33 @@
from __future__ import annotations

from shared.scalar_types import ScalarType

from ..codegen.c_emitter import SizeOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .common import shape_product, value_dtype, value_shape
from .registry import register_lowering


@register_lowering("Size")
def lower_size(graph: Graph, node: Node) -> SizeOp:
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Size must have 1 input and 1 output")
    input_shape = value_shape(graph, node.inputs[0], node)
    output_shape = value_shape(graph, node.outputs[0], node)
    if len(output_shape) != 0:
        raise ShapeInferenceError("Size output must be a scalar")
    output_dtype = value_dtype(graph, node.outputs[0], node)
    if output_dtype != ScalarType.I64:
        raise UnsupportedOpError("Size output dtype must be int64")
    input_dtype = value_dtype(graph, node.inputs[0], node)
    element_count = shape_product(input_shape)
    return SizeOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        input_shape=input_shape,
        output_shape=output_shape,
        value=element_count,
        dtype=output_dtype,
        input_dtype=input_dtype,
    )
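Size folds the same way: `shape_product` is the static element count, so the emitted SizeOp carries a single constant. An illustrative check (assuming `shape_product` multiplies the dimensions, which is how it is used above):

import numpy as np

# A (2, 3, 5) input lowers to SizeOp(value=30), matching numpy's x.size.
assert np.zeros((2, 3, 5)).size == 2 * 3 * 5 == 30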
emx_onnx_cgen/lowering/slice.py
@@ -0,0 +1,425 @@
from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from shared.scalar_types import ScalarType

from ..codegen.c_emitter import SliceOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Initializer, Node
from ..lowering.common import value_dtype, value_shape
from ..validation import normalize_axis
from .registry import register_lowering


@dataclass(frozen=True)
class SliceSpec:
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    starts: tuple[int, ...]
    steps: tuple[int, ...]


@dataclass(frozen=True)
class SliceInputs:
    starts: list[int] | None
    ends: list[int] | None
    axes: list[int] | None
    steps: list[int] | None
    starts_input: str | None
    ends_input: str | None
    axes_input: str | None
    steps_input: str | None
    starts_shape: tuple[int, ...] | None
    ends_shape: tuple[int, ...] | None
    axes_shape: tuple[int, ...] | None
    steps_shape: tuple[int, ...] | None
    starts_dtype: ScalarType | None
    ends_dtype: ScalarType | None
    axes_dtype: ScalarType | None
    steps_dtype: ScalarType | None


def _find_initializer(graph: Graph, name: str) -> Initializer | None:
    for initializer in graph.initializers:
        if initializer.name == name:
            return initializer
    return None


def _read_int_list(
    graph: Graph, name: str, node: Node, *, label: str
) -> list[int]:
    initializer = _find_initializer(graph, name)
    if initializer is None:
        raise UnsupportedOpError(
            f"{node.op_type} {label} input must be a constant initializer"
        )
    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
        raise UnsupportedOpError(
            f"{node.op_type} {label} input must be int64 or int32"
        )
    data = np.array(initializer.data, dtype=np.int64).reshape(-1)
    return [int(value) for value in data]


def _maybe_read_int_list(
    graph: Graph, name: str, node: Node, *, label: str
) -> list[int] | None:
    initializer = _find_initializer(graph, name)
    if initializer is None:
        return None
    return _read_int_list(graph, name, node, label=label)


def _validate_int_input(
    graph: Graph, name: str, node: Node, *, label: str
) -> tuple[tuple[int, ...], ScalarType]:
    dtype = value_dtype(graph, name, node)
    if dtype not in {ScalarType.I64, ScalarType.I32}:
        raise UnsupportedOpError(
            f"{node.op_type} {label} input must be int64 or int32"
        )
    shape = value_shape(graph, name, node)
    if len(shape) != 1:
        raise UnsupportedOpError(
            f"{node.op_type} {label} input must be a 1D tensor"
        )
    return shape, dtype


def _resolve_inputs(
    graph: Graph, node: Node
) -> SliceInputs:
    if "starts" in node.attrs or "ends" in node.attrs:
        if len(node.inputs) != 1:
            raise UnsupportedOpError(
                f"{node.op_type} with starts/ends attributes expects 1 input"
            )
        if "starts" not in node.attrs or "ends" not in node.attrs:
            raise UnsupportedOpError(
                f"{node.op_type} must specify both starts and ends"
            )
        starts = [int(value) for value in node.attrs.get("starts", [])]
        ends = [int(value) for value in node.attrs.get("ends", [])]
        axes_attr = node.attrs.get("axes")
        axes = [int(value) for value in axes_attr] if axes_attr else None
        steps = None
        return SliceInputs(
            starts=starts,
            ends=ends,
            axes=axes,
            steps=steps,
            starts_input=None,
            ends_input=None,
            axes_input=None,
            steps_input=None,
            starts_shape=None,
            ends_shape=None,
            axes_shape=None,
            steps_shape=None,
            starts_dtype=None,
            ends_dtype=None,
            axes_dtype=None,
            steps_dtype=None,
        )
    if len(node.inputs) < 3:
        raise UnsupportedOpError(
            f"{node.op_type} expects at least 3 inputs"
        )
    starts_name = node.inputs[1]
    ends_name = node.inputs[2]
    axes_name = node.inputs[3] if len(node.inputs) >= 4 else ""
    steps_name = node.inputs[4] if len(node.inputs) >= 5 else ""
    starts = _maybe_read_int_list(graph, starts_name, node, label="starts")
    ends = _maybe_read_int_list(graph, ends_name, node, label="ends")
    axes = (
        _maybe_read_int_list(graph, axes_name, node, label="axes")
        if axes_name
        else None
    )
    steps = (
        _maybe_read_int_list(graph, steps_name, node, label="steps")
        if steps_name
        else None
    )
    if starts is not None and ends is not None:
        return SliceInputs(
            starts=starts,
            ends=ends,
            axes=axes,
            steps=steps,
            starts_input=None,
            ends_input=None,
            axes_input=None,
            steps_input=None,
            starts_shape=None,
            ends_shape=None,
            axes_shape=None,
            steps_shape=None,
            starts_dtype=None,
            ends_dtype=None,
            axes_dtype=None,
            steps_dtype=None,
        )
    if starts is None or ends is None:
        starts_shape, starts_dtype = _validate_int_input(
            graph, starts_name, node, label="starts"
        )
        ends_shape, ends_dtype = _validate_int_input(
            graph, ends_name, node, label="ends"
        )
        if starts_shape != ends_shape:
            raise ShapeInferenceError(
                f"{node.op_type} starts and ends must have matching shapes"
            )
        axes_shape = None
        axes_dtype = None
        steps_shape = None
        steps_dtype = None
        axes_input = None
        steps_input = None
        if axes_name:
            axes_shape, axes_dtype = _validate_int_input(
                graph, axes_name, node, label="axes"
            )
            if axes_shape != starts_shape:
                raise ShapeInferenceError(
                    f"{node.op_type} axes must match starts length"
                )
            axes_input = axes_name
        if steps_name:
            steps_shape, steps_dtype = _validate_int_input(
                graph, steps_name, node, label="steps"
            )
            if steps_shape != starts_shape:
                raise ShapeInferenceError(
                    f"{node.op_type} steps must match starts length"
                )
            steps_input = steps_name
        return SliceInputs(
            starts=None,
            ends=None,
            axes=None,
            steps=None,
            starts_input=starts_name,
            ends_input=ends_name,
            axes_input=axes_input,
            steps_input=steps_input,
            starts_shape=starts_shape,
            ends_shape=ends_shape,
            axes_shape=axes_shape,
            steps_shape=steps_shape,
            starts_dtype=starts_dtype,
            ends_dtype=ends_dtype,
            axes_dtype=axes_dtype,
            steps_dtype=steps_dtype,
        )
    raise UnsupportedOpError(
        f"{node.op_type} starts and ends inputs must both be constant initializers"
    )
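`_resolve_inputs` accepts both historical encodings of Slice: the opset-1 form, where starts/ends/axes are node attributes, and the opset-10+ form, where they are tensor inputs. A sketch of the two encodings with `onnx.helper` (tensor names are hypothetical):

import onnx.helper as oh

# Opset-1 style: bounds live in attributes and the node has one input.
old = oh.make_node("Slice", inputs=["x"], outputs=["y"],
                   starts=[0], ends=[2], axes=[0])

# Opset-10+ style: bounds are tensor inputs. When "starts" and "ends"
# are constant initializers they are folded; otherwise the lowering
# records the input names (starts_input/ends_input) for runtime use.
new = oh.make_node("Slice", inputs=["x", "starts", "ends", "axes", "steps"],
                   outputs=["y"])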
def _normalize_slices(
    input_shape: tuple[int, ...],
    starts: list[int],
    ends: list[int],
    axes: list[int] | None,
    steps: list[int] | None,
    node: Node,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    rank = len(input_shape)
    if rank == 0:
        raise ShapeInferenceError(
            f"{node.op_type} does not support scalar inputs"
        )
    if len(starts) != len(ends):
        raise ShapeInferenceError(
            f"{node.op_type} starts and ends must have matching lengths"
        )
    if axes is None:
        axes = list(range(len(starts)))
    if steps is None:
        steps = [1] * len(starts)
    if len(axes) != len(starts) or len(steps) != len(starts):
        raise ShapeInferenceError(
            f"{node.op_type} axes and steps must match starts length"
        )
    normalized_starts = [0] * rank
    normalized_steps = [1] * rank
    output_shape = list(input_shape)
    seen_axes: set[int] = set()
    for index, axis in enumerate(axes):
        normalized_axis = normalize_axis(int(axis), input_shape, node)
        if normalized_axis in seen_axes:
            raise ShapeInferenceError(
                f"{node.op_type} axes must be unique"
            )
        seen_axes.add(normalized_axis)
        dim = input_shape[normalized_axis]
        if dim < 0:
            raise ShapeInferenceError("Dynamic dims are not supported")
        step = int(steps[index])
        if step == 0:
            raise UnsupportedOpError(
                f"{node.op_type} steps must be non-zero"
            )
        if step < 0:
            raise UnsupportedOpError(
                f"{node.op_type} only supports positive steps"
            )
        start = int(starts[index])
        end = int(ends[index])
        if start < 0:
            start += dim
        if end < 0:
            end += dim
        start = max(0, min(start, dim))
        end = max(0, min(end, dim))
        length = max(0, (end - start + step - 1) // step)
        normalized_starts[normalized_axis] = start
        normalized_steps[normalized_axis] = step
        output_shape[normalized_axis] = length
    return (
        tuple(normalized_starts),
        tuple(normalized_steps),
        tuple(output_shape),
    )
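With only positive steps allowed, each sliced dimension reduces to a ceiling division, length = max(0, ceil((end - start) / step)), computed after clamping start and end into [0, dim]. Worked numbers (plain Python):

# dim=10, start=1, end=9, step=3 -> ceil(8/3) = 3 elements (indices 1, 4, 7).
assert max(0, (9 - 1 + 3 - 1) // 3) == 3
# dim=10, start=-3 (resolves to 7), end=100 (clamps to 10), step=1 -> 3 elements.
assert max(0, (10 - 7 + 1 - 1) // 1) == 3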
def resolve_slice_spec(graph: Graph, node: Node) -> SliceSpec:
    if len(node.inputs) < 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Slice must have 1 output")
    input_shape = value_shape(graph, node.inputs[0], node)
    output_shape = value_shape(graph, node.outputs[0], node)
    input_dtype = value_dtype(graph, node.inputs[0], node)
    output_dtype = value_dtype(graph, node.outputs[0], node)
    if input_dtype != output_dtype:
        raise UnsupportedOpError(
            f"{node.op_type} expects matching input/output dtypes, "
            f"got {input_dtype} and {output_dtype}"
        )
    if any(dim < 0 for dim in input_shape):
        raise ShapeInferenceError("Dynamic dims are not supported")
    if any(dim < 0 for dim in output_shape):
        raise ShapeInferenceError("Dynamic dims are not supported")
    inputs = _resolve_inputs(graph, node)
    if inputs.starts is None or inputs.ends is None:
        raise UnsupportedOpError(
            f"{node.op_type} starts/ends inputs must be constant for shape "
            "inference"
        )
    starts = inputs.starts
    ends = inputs.ends
    axes = inputs.axes
    steps = inputs.steps
    normalized_starts, normalized_steps, computed_output_shape = _normalize_slices(
        input_shape, starts, ends, axes, steps, node
    )
    if output_shape and computed_output_shape != output_shape:
        raise ShapeInferenceError(
            f"{node.op_type} output shape must be "
            f"{computed_output_shape}, got {output_shape}"
        )
    return SliceSpec(
        input_shape=input_shape,
        output_shape=computed_output_shape,
        starts=normalized_starts,
        steps=normalized_steps,
    )
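`resolve_slice_spec` normalizes every axis (untouched axes get start 0 and step 1), so the result maps directly onto basic NumPy slicing. An illustrative cross-check, assuming a SliceSpec already resolved for a concrete graph:

import numpy as np

def apply_spec(x: np.ndarray, spec) -> np.ndarray:
    # Rebuild each axis slice from the normalized starts/steps and the
    # computed output shape: stop = start + length * step.
    index = tuple(
        slice(start, start + length * step, step)
        for start, step, length in zip(spec.starts, spec.steps, spec.output_shape)
    )
    return x[index]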
@register_lowering("Slice")
|
|
335
|
+
def lower_slice(graph: Graph, node: Node) -> SliceOp:
|
|
336
|
+
input_shape = value_shape(graph, node.inputs[0], node)
|
|
337
|
+
output_shape = value_shape(graph, node.outputs[0], node)
|
|
338
|
+
input_dtype = value_dtype(graph, node.inputs[0], node)
|
|
339
|
+
output_dtype = value_dtype(graph, node.outputs[0], node)
|
|
340
|
+
if input_dtype != output_dtype:
|
|
341
|
+
raise UnsupportedOpError(
|
|
342
|
+
f"{node.op_type} expects matching input/output dtypes, "
|
|
343
|
+
f"got {input_dtype} and {output_dtype}"
|
|
344
|
+
)
|
|
345
|
+
if any(dim < 0 for dim in input_shape):
|
|
346
|
+
raise ShapeInferenceError("Dynamic dims are not supported")
|
|
347
|
+
if any(dim < 0 for dim in output_shape):
|
|
348
|
+
raise ShapeInferenceError("Dynamic dims are not supported")
|
|
349
|
+
inputs = _resolve_inputs(graph, node)
|
|
350
|
+
if inputs.starts is not None and inputs.ends is not None:
|
|
351
|
+
normalized_starts, normalized_steps, computed_output_shape = _normalize_slices(
|
|
352
|
+
input_shape, inputs.starts, inputs.ends, inputs.axes, inputs.steps, node
|
|
353
|
+
)
|
|
354
|
+
if output_shape and computed_output_shape != output_shape:
|
|
355
|
+
raise ShapeInferenceError(
|
|
356
|
+
f"{node.op_type} output shape must be "
|
|
357
|
+
f"{computed_output_shape}, got {output_shape}"
|
|
358
|
+
)
|
|
359
|
+
return SliceOp(
|
|
360
|
+
input0=node.inputs[0],
|
|
361
|
+
output=node.outputs[0],
|
|
362
|
+
input_shape=input_shape,
|
|
363
|
+
output_shape=computed_output_shape,
|
|
364
|
+
starts=normalized_starts,
|
|
365
|
+
steps=normalized_steps,
|
|
366
|
+
axes=None,
|
|
367
|
+
starts_input=None,
|
|
368
|
+
ends_input=None,
|
|
369
|
+
axes_input=None,
|
|
370
|
+
steps_input=None,
|
|
371
|
+
starts_shape=None,
|
|
372
|
+
ends_shape=None,
|
|
373
|
+
axes_shape=None,
|
|
374
|
+
steps_shape=None,
|
|
375
|
+
starts_dtype=None,
|
|
376
|
+
ends_dtype=None,
|
|
377
|
+
axes_dtype=None,
|
|
378
|
+
steps_dtype=None,
|
|
379
|
+
dtype=input_dtype,
|
|
380
|
+
input_dtype=input_dtype,
|
|
381
|
+
)
|
|
382
|
+
if len(output_shape) != len(input_shape):
|
|
383
|
+
raise ShapeInferenceError(
|
|
384
|
+
f"{node.op_type} output rank must match input rank"
|
|
385
|
+
)
|
|
386
|
+
if inputs.starts_shape is None or inputs.ends_shape is None:
|
|
387
|
+
raise UnsupportedOpError(
|
|
388
|
+
f"{node.op_type} starts and ends inputs must be provided"
|
|
389
|
+
)
|
|
390
|
+
if inputs.starts_shape != inputs.ends_shape:
|
|
391
|
+
raise ShapeInferenceError(
|
|
392
|
+
f"{node.op_type} starts and ends must have matching shapes"
|
|
393
|
+
)
|
|
394
|
+
starts_len = inputs.starts_shape[0]
|
|
395
|
+
if starts_len > len(input_shape):
|
|
396
|
+
raise ShapeInferenceError(
|
|
397
|
+
f"{node.op_type} starts length exceeds input rank"
|
|
398
|
+
)
|
|
399
|
+
if starts_len == 0 and output_shape != input_shape:
|
|
400
|
+
raise ShapeInferenceError(
|
|
401
|
+
f"{node.op_type} empty starts expects output shape to match input"
|
|
402
|
+
)
|
|
403
|
+
return SliceOp(
|
|
404
|
+
input0=node.inputs[0],
|
|
405
|
+
output=node.outputs[0],
|
|
406
|
+
input_shape=input_shape,
|
|
407
|
+
output_shape=output_shape,
|
|
408
|
+
starts=None,
|
|
409
|
+
steps=None,
|
|
410
|
+
axes=None,
|
|
411
|
+
starts_input=inputs.starts_input,
|
|
412
|
+
ends_input=inputs.ends_input,
|
|
413
|
+
axes_input=inputs.axes_input,
|
|
414
|
+
steps_input=inputs.steps_input,
|
|
415
|
+
starts_shape=inputs.starts_shape,
|
|
416
|
+
ends_shape=inputs.ends_shape,
|
|
417
|
+
axes_shape=inputs.axes_shape,
|
|
418
|
+
steps_shape=inputs.steps_shape,
|
|
419
|
+
starts_dtype=inputs.starts_dtype,
|
|
420
|
+
ends_dtype=inputs.ends_dtype,
|
|
421
|
+
axes_dtype=inputs.axes_dtype,
|
|
422
|
+
steps_dtype=inputs.steps_dtype,
|
|
423
|
+
dtype=input_dtype,
|
|
424
|
+
input_dtype=input_dtype,
|
|
425
|
+
)
|
|
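The second SliceOp return above is the fallback for bounds that are not initializers, for example when starts/ends arrive as graph inputs or are produced by an upstream node. Lowering then only validates ranks and dtypes and forwards the tensor names, leaving the actual indexing to the generated code at runtime. A sketch of such a node (hypothetical names; "starts" and "ends" name graph inputs with no initializer data):

import onnx.helper as oh

# No constant data for "starts"/"ends": lower_slice keeps
# starts_input="starts" and ends_input="ends" instead of folding.
node = oh.make_node("Slice", inputs=["x", "starts", "ends"], outputs=["y"])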
emx_onnx_cgen/lowering/softmax.py
@@ -0,0 +1,47 @@
from __future__ import annotations

from ..codegen.c_emitter import SoftmaxOp
from ..errors import UnsupportedOpError
from ..ir.model import Graph, Node
from .common import node_dtype as _node_dtype
from .common import shape_product as _shape_product
from .common import value_shape as _value_shape
from .registry import register_lowering
from ..validation import ensure_output_shape_matches_input
from ..validation import normalize_axis as _normalize_axis


@register_lowering("Softmax")
def lower_softmax(graph: Graph, node: Node) -> SoftmaxOp:
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Softmax must have 1 input and 1 output")
    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "Softmax supports float16, float, and double inputs only"
        )
    input_shape = _value_shape(graph, node.inputs[0], node)
    output_shape = _value_shape(graph, node.outputs[0], node)
    ensure_output_shape_matches_input(node, input_shape, output_shape)
    axis = _normalize_axis(
        int(node.attrs.get("axis", -1)),
        input_shape,
        node,
    )
    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
    axis_size = input_shape[axis]
    inner = (
        _shape_product(input_shape[axis + 1 :])
        if axis + 1 < len(input_shape)
        else 1
    )
    return SoftmaxOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        outer=outer,
        axis_size=axis_size,
        inner=inner,
        axis=axis,
        shape=input_shape,
        dtype=op_dtype,
    )
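The outer/axis_size/inner split flattens the tensor around the softmax axis so the generated loops can reduce over a single contiguous stride pattern. A NumPy reference with the same decomposition (illustrative only, not the package's evaluator):

import numpy as np

def softmax_ref(x: np.ndarray, outer: int, axis_size: int, inner: int) -> np.ndarray:
    # View as (outer, axis_size, inner) and normalize along the middle
    # dim, subtracting the max for numerical stability.
    v = x.reshape(outer, axis_size, inner)
    e = np.exp(v - v.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)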