emx-onnx-cgen 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen has been flagged by the registry; see the registry page for details.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/reshape.py
@@ -0,0 +1,188 @@
from __future__ import annotations

from shared.scalar_types import ScalarType

from ..codegen.c_emitter import ReshapeOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Initializer, Node
from .registry import register_lowering


def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
    try:
        return graph.find_value(name).type.shape
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
    try:
        return graph.find_value(name).type.dtype
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _shape_product(shape: tuple[int, ...]) -> int:
    product = 1
    for dim in shape:
        if dim < 0:
            raise ShapeInferenceError("Dynamic dims are not supported")
        product *= dim
    return product


def _find_initializer(graph: Graph, name: str) -> Initializer | None:
    for initializer in graph.initializers:
        if initializer.name == name:
            return initializer
    return None


def _find_node_by_output(graph: Graph, name: str) -> Node | None:
    for node in graph.nodes:
        if name in node.outputs:
            return node
    return None


def _shape_values_from_shape_node(
    graph: Graph, name: str, node: Node
) -> list[int] | None:
    shape_node = _find_node_by_output(graph, name)
    if shape_node is None or shape_node.op_type != "Shape":
        return None
    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
        raise UnsupportedOpError("Shape must have 1 input and 1 output")
    source_shape = _value_shape(graph, shape_node.inputs[0], node)
    return list(source_shape)


def _resolve_target_shape(
    input_shape: tuple[int, ...],
    shape_values: list[int],
    *,
    allowzero: int,
    node: Node,
) -> tuple[int, ...]:
    if allowzero not in (0, 1):
        raise UnsupportedOpError("Reshape allowzero must be 0 or 1")
    output_dims: list[int] = []
    unknown_index: int | None = None
    known_product = 1
    contains_zero = False
    for index, dim in enumerate(shape_values):
        if dim == -1:
            if unknown_index is not None:
                raise ShapeInferenceError("Reshape allows only one -1 dimension")
            unknown_index = index
            output_dims.append(-1)
            continue
        if dim == 0:
            contains_zero = True
            if allowzero == 0:
                if index >= len(input_shape):
                    raise ShapeInferenceError(
                        "Reshape zero dim must index into input shape"
                    )
                dim = input_shape[index]
        if dim < 0:
            raise ShapeInferenceError("Reshape dims must be >= -1")
        output_dims.append(dim)
        known_product *= dim
    if allowzero == 1 and contains_zero and unknown_index is not None:
        raise ShapeInferenceError(
            "Reshape allowzero cannot combine zero and -1 dimensions"
        )
    input_product = _shape_product(input_shape)
    if unknown_index is not None:
        if known_product == 0:
            if input_product != 0:
                raise ShapeInferenceError(
                    "Reshape cannot infer dimension from input shape"
                )
            output_dims[unknown_index] = 0
        else:
            if input_product % known_product != 0:
                raise ShapeInferenceError(
                    "Reshape cannot infer dimension from input shape"
                )
            output_dims[unknown_index] = input_product // known_product
    output_shape = tuple(output_dims)
    if _shape_product(output_shape) != input_product:
        raise ShapeInferenceError(
            "Reshape input and output element counts must match"
        )
    return output_shape


@register_lowering("Reshape")
def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
    if len(node.inputs) != 2 or len(node.outputs) != 1:
        raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
    input_shape = _value_shape(graph, node.inputs[0], node)
    input_dtype = _value_dtype(graph, node.inputs[0], node)
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    if input_dtype != output_dtype:
        raise UnsupportedOpError(
            "Reshape expects matching input/output dtypes, "
            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    output_shape = _value_shape(graph, node.outputs[0], node)
    allowzero = int(node.attrs.get("allowzero", 0))
    shape_initializer = _find_initializer(graph, node.inputs[1])
    resolved_shape: tuple[int, ...] | None = None
    if shape_initializer is None:
        shape_values = _shape_values_from_shape_node(
            graph, node.inputs[1], node
        )
        if shape_values is not None:
            resolved_shape = _resolve_target_shape(
                input_shape,
                shape_values,
                allowzero=allowzero,
                node=node,
            )
        else:
            if _shape_product(output_shape) != _shape_product(input_shape):
                raise ShapeInferenceError(
                    "Reshape input and output element counts must match"
                )
    else:
        if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
            raise UnsupportedOpError(
                "Reshape expects int64 or int32 shape input, "
                f"got {shape_initializer.type.dtype.onnx_name}"
            )
        if len(shape_initializer.type.shape) != 1:
            raise UnsupportedOpError("Reshape expects a 1D shape input")
        shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
        resolved_shape = _resolve_target_shape(
            input_shape,
            shape_values,
            allowzero=allowzero,
            node=node,
        )
    if resolved_shape is not None:
        if output_shape and resolved_shape != output_shape:
            raise ShapeInferenceError(
                "Reshape output shape must be "
                f"{resolved_shape}, got {output_shape}"
            )
        output_shape = resolved_shape
    for dim in output_shape:
        if dim < 0:
            raise ShapeInferenceError("Dynamic dims are not supported")
    return ReshapeOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        input_shape=input_shape,
        output_shape=output_shape,
        dtype=input_dtype,
        input_dtype=input_dtype,
    )
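Aside: the Reshape lowering above spends most of its effort resolving ONNX's special shape entries (`0` copies the input dimension at the same index unless allowzero is set, and a single `-1` is inferred from the remaining element count). A minimal standalone sketch of those rules, independent of the package; resolve below is a hypothetical helper mirroring _resolve_target_shape without its error handling:

# Standalone sketch of ONNX Reshape target-shape resolution; not part of
# the package, and it omits the validation _resolve_target_shape performs.
def resolve(input_shape, shape_values, allowzero=0):
    dims = list(shape_values)
    if allowzero == 0:
        # A 0 entry copies the input dimension at the same index.
        dims = [input_shape[i] if d == 0 else d for i, d in enumerate(dims)]
    total = 1
    for d in input_shape:
        total *= d
    if -1 in dims:
        known = 1
        for d in dims:
            if d != -1:
                known *= d
        dims[dims.index(-1)] = total // known  # the single -1 is inferred
    return tuple(dims)

assert resolve((2, 3, 4), [0, -1]) == (2, 12)      # 0 copies dim, -1 inferred
assert resolve((2, 3, 4), [4, 3, 2]) == (4, 3, 2)  # explicit reshape
assert resolve((2, 0, 4), [2, 0, 4], allowzero=1) == (2, 0, 4)  # literal zero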
emx_onnx_cgen/lowering/resize.py
@@ -0,0 +1,445 @@
from __future__ import annotations

from dataclasses import dataclass

from shared.scalar_types import ScalarType

from ..codegen.c_emitter import ResizeOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Initializer, Node
from .registry import register_lowering

_SUPPORTED_COORD_MODES = {
    "half_pixel",
    "half_pixel_symmetric",
    "asymmetric",
    "align_corners",
    "pytorch_half_pixel",
    "tf_crop_and_resize",
}
_SUPPORTED_MODES = {"nearest", "linear", "cubic"}
_SUPPORTED_NEAREST_MODES = {
    "round_prefer_floor",
    "round_prefer_ceil",
    "floor",
    "ceil",
}
_SUPPORTED_KEEP_ASPECT = {"stretch", "not_larger", "not_smaller"}


@dataclass(frozen=True)
class _ResizeInputs:
    roi: str | None
    scales: str | None
    sizes: str | None


@dataclass(frozen=True)
class _ResolvedScales:
    scales: tuple[float, ...]
    output_shape: tuple[int, ...]
    axes: tuple[int, ...]


@dataclass(frozen=True)
class _InputConfig:
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    input_dtype: ScalarType
    output_dtype: ScalarType


def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
    try:
        return graph.find_value(name).type.shape
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
    try:
        return graph.find_value(name).type.dtype
    except KeyError as exc:
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {node.op_type}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc


def _find_initializer(graph: Graph, name: str) -> Initializer | None:
    for initializer in graph.initializers:
        if initializer.name == name:
            return initializer
    return None


def _decode_attr(value: object, default: str) -> str:
    if value is None:
        return default
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="ignore")
    if isinstance(value, str):
        return value
    return str(value)


def _normalize_axes(
    axes: tuple[int, ...], rank: int, node: Node
) -> tuple[int, ...]:
    normalized: list[int] = []
    for axis in axes:
        axis = int(axis)
        if axis < 0:
            axis += rank
        if axis < 0 or axis >= rank:
            raise ShapeInferenceError(
                f"Resize axis {axis} is out of range for rank {rank}"
            )
        normalized.append(axis)
    if len(set(normalized)) != len(normalized):
        raise ShapeInferenceError("Resize axes must be unique")
    return tuple(normalized)


def _round_half_up(value: float) -> int:
    return int(value + 0.5)


def _parse_input_names(node: Node) -> _ResizeInputs:
    inputs = list(node.inputs)
    if len(inputs) > 4:
        raise UnsupportedOpError("Resize expects at most 4 inputs")
    while len(inputs) < 4:
        inputs.append("")
    _, roi, scales, sizes = inputs[:4]
    return _ResizeInputs(
        roi=roi or None,
        scales=scales or None,
        sizes=sizes or None,
    )


def _parse_axes(node: Node, rank: int) -> tuple[int, ...]:
    axes_attr = node.attrs.get("axes")
    if axes_attr is None:
        return tuple(range(rank))
    axes = tuple(int(value) for value in axes_attr)
    return _normalize_axes(axes, rank, node)


def _resolve_input_shapes(
    graph: Graph, node: Node, input_name: str
) -> _InputConfig:
    input_shape = _value_shape(graph, input_name, node)
    output_shape = _value_shape(graph, node.outputs[0], node)
    input_dtype = _value_dtype(graph, input_name, node)
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    return _InputConfig(
        input_shape=input_shape,
        output_shape=output_shape,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
    )


def _resolve_scales_from_sizes(
    sizes: tuple[int, ...],
    input_shape: tuple[int, ...],
    axes: tuple[int, ...],
    keep_aspect_ratio_policy: str,
) -> _ResolvedScales:
    rank = len(input_shape)
    full_sizes = list(input_shape)
    for index, axis in enumerate(axes):
        full_sizes[axis] = sizes[index]
    if keep_aspect_ratio_policy != "stretch":
        scales = [full_sizes[axis] / input_shape[axis] for axis in axes]
        if keep_aspect_ratio_policy == "not_larger":
            scale = min(scales)
        else:
            scale = max(scales)
        for axis in axes:
            full_sizes[axis] = _round_half_up(scale * input_shape[axis])
        return _ResolvedScales(
            scales=tuple(
                scale if axis in axes else 1.0
                for axis in range(rank)
            ),
            output_shape=tuple(full_sizes),
            axes=axes,
        )
    scales = tuple(
        full_sizes[axis] / input_shape[axis] if axis in axes else 1.0
        for axis in range(rank)
    )
    return _ResolvedScales(
        scales=scales,
        output_shape=tuple(full_sizes),
        axes=axes,
    )


def _resolve_scales_from_values(
    scales: tuple[float, ...],
    input_shape: tuple[int, ...],
    axes: tuple[int, ...],
) -> _ResolvedScales:
    rank = len(input_shape)
    full_scales = [1.0] * rank
    for index, axis in enumerate(axes):
        full_scales[axis] = scales[index]
    output_shape = tuple(
        int(input_shape[axis] * full_scales[axis])
        if axis in axes
        else input_shape[axis]
        for axis in range(rank)
    )
    return _ResolvedScales(
        scales=tuple(full_scales),
        output_shape=output_shape,
        axes=axes,
    )


def _load_initializer_values(
    graph: Graph, name: str, node: Node
) -> tuple[float | int, ...] | None:
    initializer = _find_initializer(graph, name)
    if initializer is None:
        return None
    data = initializer.data.reshape(-1)
    return tuple(data.tolist())


def _validate_tensor_1d(
    graph: Graph,
    name: str,
    node: Node,
    dtype_options: set[ScalarType],
) -> tuple[int, ScalarType]:
    shape = _value_shape(graph, name, node)
    if len(shape) != 1:
        raise UnsupportedOpError("Resize expects 1D auxiliary inputs")
    dtype = _value_dtype(graph, name, node)
    if dtype not in dtype_options:
        raise UnsupportedOpError(
            "Resize expects "
            f"{name} to have dtype in "
            f"{[option.onnx_name for option in sorted(dtype_options, key=str)]}"
        )
    return shape[0], dtype


def _resolve_scales(
    graph: Graph,
    node: Node,
    config: _InputConfig,
    inputs: _ResizeInputs,
    axes: tuple[int, ...],
    keep_aspect_ratio_policy: str,
) -> tuple[tuple[float, ...], tuple[int, ...]]:
    rank = len(config.input_shape)
    if inputs.scales:
        scale_len, _ = _validate_tensor_1d(
            graph,
            inputs.scales,
            node,
            {ScalarType.F16, ScalarType.F32, ScalarType.F64},
        )
        if scale_len not in {len(axes), rank}:
            raise UnsupportedOpError("Resize scales length mismatch")
        if scale_len == rank and axes != tuple(range(rank)):
            raise UnsupportedOpError(
                "Resize scales length conflicts with axes configuration"
            )
        scale_axes = axes if scale_len == len(axes) else tuple(range(rank))
        values = _load_initializer_values(graph, inputs.scales, node)
        if values is None:
            scales = tuple(
                config.output_shape[axis] / config.input_shape[axis]
                if axis in scale_axes
                else 1.0
                for axis in range(rank)
            )
            return scales, config.output_shape
        resolved = _resolve_scales_from_values(
            tuple(float(value) for value in values),
            config.input_shape,
            scale_axes,
        )
        return resolved.scales, resolved.output_shape
    if inputs.sizes:
        size_len, _ = _validate_tensor_1d(
            graph, inputs.sizes, node, {ScalarType.I64, ScalarType.I32}
        )
        if size_len not in {len(axes), rank}:
            raise UnsupportedOpError("Resize sizes length mismatch")
        if size_len == rank and axes != tuple(range(rank)):
            raise UnsupportedOpError(
                "Resize sizes length conflicts with axes configuration"
            )
        size_axes = axes if size_len == len(axes) else tuple(range(rank))
        values = _load_initializer_values(graph, inputs.sizes, node)
        if values is None:
            scales = tuple(
                config.output_shape[axis] / config.input_shape[axis]
                if axis in size_axes
                else 1.0
                for axis in range(rank)
            )
            return scales, config.output_shape
        resolved = _resolve_scales_from_sizes(
            tuple(int(value) for value in values),
            config.input_shape,
            size_axes,
            keep_aspect_ratio_policy,
        )
        return resolved.scales, resolved.output_shape
    raise UnsupportedOpError("Resize expects scales or sizes input")


def _validate_output_shape(
    expected: tuple[int, ...],
    actual: tuple[int, ...],
) -> None:
    if expected != actual:
        raise ShapeInferenceError(
            f"Resize output shape must be {expected}, got {actual}"
        )
    if any(dim < 0 for dim in actual):
        raise ShapeInferenceError("Resize output shape must be non-negative")


@register_lowering("Resize")
def lower_resize(graph: Graph, node: Node) -> ResizeOp:
    if len(node.outputs) != 1:
        raise UnsupportedOpError("Resize expects one output")
    inputs = _parse_input_names(node)
    if inputs.scales and inputs.sizes:
        raise UnsupportedOpError("Resize cannot set both scales and sizes")
    if not inputs.scales and not inputs.sizes:
        raise UnsupportedOpError("Resize expects scales or sizes input")
    mode = _decode_attr(node.attrs.get("mode"), "nearest")
    coordinate_mode = _decode_attr(
        node.attrs.get("coordinate_transformation_mode"), "half_pixel"
    )
    nearest_mode = _decode_attr(
        node.attrs.get("nearest_mode"), "round_prefer_floor"
    )
    keep_aspect_ratio_policy = _decode_attr(
        node.attrs.get("keep_aspect_ratio_policy"), "stretch"
    )
    antialias = bool(int(node.attrs.get("antialias", 0)))
    cubic_coeff_a = float(node.attrs.get("cubic_coeff_a", -0.75))
    exclude_outside = bool(int(node.attrs.get("exclude_outside", 0)))
    extrapolation_value = float(node.attrs.get("extrapolation_value", 0.0))
    if mode not in _SUPPORTED_MODES:
        raise UnsupportedOpError(f"Resize mode {mode!r} is not supported")
    if coordinate_mode not in _SUPPORTED_COORD_MODES:
        raise UnsupportedOpError(
            "Resize coordinate_transformation_mode "
            f"{coordinate_mode!r} is not supported"
        )
    if nearest_mode not in _SUPPORTED_NEAREST_MODES:
        raise UnsupportedOpError(
            f"Resize nearest_mode {nearest_mode!r} is not supported"
        )
    if keep_aspect_ratio_policy not in _SUPPORTED_KEEP_ASPECT:
        raise UnsupportedOpError(
            "Resize keep_aspect_ratio_policy "
            f"{keep_aspect_ratio_policy!r} is not supported"
        )
    if antialias and mode == "nearest":
        raise UnsupportedOpError("Resize antialias is not supported for nearest")
    config = _resolve_input_shapes(graph, node, node.inputs[0])
    if config.input_dtype != config.output_dtype:
        raise UnsupportedOpError(
            "Resize expects matching input/output dtypes, "
            f"got {config.input_dtype.onnx_name} and {config.output_dtype.onnx_name}"
        )
    rank = len(config.input_shape)
    axes = _parse_axes(node, rank)
    scales, expected_output = _resolve_scales(
        graph,
        node,
        config,
        inputs,
        axes,
        keep_aspect_ratio_policy,
    )
    _validate_output_shape(expected_output, config.output_shape)
    roi_shape = None
    roi_axes = None
    roi_dtype = None
    if inputs.roi:
        roi_len, roi_dtype = _validate_tensor_1d(
            graph,
            inputs.roi,
            node,
            {ScalarType.F16, ScalarType.F32, ScalarType.F64},
        )
        if roi_len == 2 * rank:
            roi_shape = (roi_len,)
        elif roi_len == 2 * len(axes):
            roi_shape = (roi_len,)
            roi_axes = axes
        else:
            raise UnsupportedOpError("Resize roi length mismatch")
        if coordinate_mode != "tf_crop_and_resize" and roi_len != 0:
            roi_axes = roi_axes if roi_len == 2 * len(axes) else None
    if coordinate_mode == "tf_crop_and_resize" and not inputs.roi:
        raise UnsupportedOpError("Resize requires roi for tf_crop_and_resize")
    scales_shape = None
    sizes_shape = None
    scales_dtype = None
    sizes_dtype = None
    scales_axes = None
    sizes_axes = None
    if inputs.scales:
        scale_len, scales_dtype = _validate_tensor_1d(
            graph,
            inputs.scales,
            node,
            {ScalarType.F16, ScalarType.F32, ScalarType.F64},
        )
        scales_shape = (scale_len,)
        if scale_len == len(axes) and len(axes) != rank:
            scales_axes = axes
    if inputs.sizes:
        size_len, sizes_dtype = _validate_tensor_1d(
            graph, inputs.sizes, node, {ScalarType.I64, ScalarType.I32}
        )
        sizes_shape = (size_len,)
        if size_len == len(axes) and len(axes) != rank:
            sizes_axes = axes
    return ResizeOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        input_shape=config.input_shape,
        output_shape=config.output_shape,
        scales=scales,
        scales_input=inputs.scales,
        sizes_input=inputs.sizes,
        roi_input=inputs.roi,
        axes=axes,
        scales_shape=scales_shape,
        sizes_shape=sizes_shape,
        roi_shape=roi_shape,
        scales_dtype=scales_dtype,
        sizes_dtype=sizes_dtype,
        roi_dtype=roi_dtype,
        scales_axes=scales_axes,
        sizes_axes=sizes_axes,
        roi_axes=roi_axes,
        mode=mode,
        coordinate_transformation_mode=coordinate_mode,
        nearest_mode=nearest_mode,
        cubic_coeff_a=cubic_coeff_a,
        exclude_outside=exclude_outside,
        extrapolation_value=extrapolation_value,
        antialias=antialias,
        keep_aspect_ratio_policy=keep_aspect_ratio_policy,
        dtype=config.input_dtype,
    )
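Aside: most of the Resize lowering above is spent normalizing the scales/sizes inputs against the axes attribute and the keep_aspect_ratio_policy. A minimal standalone sketch of the sizes path; resize_sizes is a hypothetical stand-in mirroring _resolve_scales_from_sizes, using the same round-half-up behavior as _round_half_up above:

# Standalone sketch of keep_aspect_ratio_policy handling for Resize sizes;
# not part of the package, and it skips all of the validation above.
def resize_sizes(input_shape, sizes, axes, policy="stretch"):
    full = list(input_shape)
    for i, axis in enumerate(axes):
        full[axis] = sizes[i]
    if policy != "stretch":
        per_axis = [full[a] / input_shape[a] for a in axes]
        # not_larger keeps aspect ratio via the smallest requested scale,
        # not_smaller via the largest.
        scale = min(per_axis) if policy == "not_larger" else max(per_axis)
        for a in axes:
            full[a] = int(scale * input_shape[a] + 0.5)  # round half up
    return tuple(full)

# Requesting 30x40 from a 20x10 spatial extent:
assert resize_sizes((1, 3, 20, 10), (30, 40), axes=(2, 3)) == (1, 3, 30, 40)
assert resize_sizes((1, 3, 20, 10), (30, 40), axes=(2, 3),
                    policy="not_larger") == (1, 3, 30, 15)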
emx_onnx_cgen/lowering/rms_normalization.py
@@ -0,0 +1,67 @@
from __future__ import annotations

from ..codegen.c_emitter import RMSNormalizationOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from ..validation import ensure_output_shape_matches_input, normalize_axis
from .common import node_dtype, shape_product, value_shape
from .registry import register_lowering


def _ensure_broadcastable(
    name: str,
    shape: tuple[int, ...],
    normalized_shape: tuple[int, ...],
) -> None:
    if len(shape) != len(normalized_shape):
        raise ShapeInferenceError(
            f"RMSNormalization {name} rank must match normalized rank"
        )
    for dim, expected in zip(shape, normalized_shape):
        if dim not in {1, expected}:
            raise ShapeInferenceError(
                f"RMSNormalization {name} shape {shape} must be broadcastable "
                f"to {normalized_shape}"
            )


@register_lowering("RMSNormalization")
def lower_rms_normalization(
    graph: Graph, node: Node
) -> RMSNormalizationOp:
    if len(node.inputs) != 2 or len(node.outputs) != 1:
        raise UnsupportedOpError("RMSNormalization must have 2 inputs and 1 output")
    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "RMSNormalization supports float16, float, and double inputs only"
        )
    input_shape = value_shape(graph, node.inputs[0], node)
    output_shape = value_shape(graph, node.outputs[0], node)
    ensure_output_shape_matches_input(node, input_shape, output_shape)
    axis = normalize_axis(int(node.attrs.get("axis", -1)), input_shape, node)
    normalized_shape = input_shape[axis:]
    scale_shape = value_shape(graph, node.inputs[1], node)
    _ensure_broadcastable("scale", scale_shape, normalized_shape)
    epsilon = float(node.attrs.get("epsilon", 1e-5))
    stash_type = int(node.attrs.get("stash_type", 1))
    if stash_type != 1:
        raise UnsupportedOpError(
            "RMSNormalization supports stash_type=1 only"
        )
    outer = shape_product(input_shape[:axis]) if axis > 0 else 1
    inner = shape_product(normalized_shape)
    return RMSNormalizationOp(
        input0=node.inputs[0],
        scale=node.inputs[1],
        output=node.outputs[0],
        shape=input_shape,
        normalized_shape=normalized_shape,
        scale_shape=scale_shape,
        outer=outer,
        inner=inner,
        axis=axis,
        epsilon=epsilon,
        dtype=op_dtype,
    )