emx-onnx-cgen 0.2.0 (py3-none-any.whl)
This diff shows the content of a publicly released package version as it appears in its public registry; it is provided for informational purposes only.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0

emx_onnx_cgen/lowering/unsqueeze.py
@@ -0,0 +1,157 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ReshapeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .registry import register_lowering
+
+
+def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+    try:
+        return graph.find_value(name).type.shape
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing shape for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+    try:
+        return graph.find_value(name).type.dtype
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing dtype for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _validate_shape(shape: tuple[int, ...], node: Node, label: str) -> None:
+    for dim in shape:
+        if dim < 0:
+            raise ShapeInferenceError(
+                f"{node.op_type} does not support dynamic dims in {label}"
+            )
+
+
+def _normalize_axes(
+    axes: list[int], output_rank: int, node: Node
+) -> tuple[int, ...]:
+    normalized: list[int] = []
+    for axis in axes:
+        if axis < 0:
+            axis += output_rank
+        if axis < 0 or axis >= output_rank:
+            raise ShapeInferenceError(
+                f"{node.op_type} axis {axis} is out of range for rank {output_rank}"
+            )
+        normalized.append(axis)
+    if len(set(normalized)) != len(normalized):
+        raise ShapeInferenceError(f"{node.op_type} axes must be unique")
+    return tuple(sorted(normalized))
+
+
+def _resolve_axes(graph: Graph, node: Node) -> tuple[int, ...] | None:
+    axes_attr = node.attrs.get("axes")
+    axes_values: list[int] | None = None
+    if len(node.inputs) == 2:
+        axes_initializer = _find_initializer(graph, node.inputs[1])
+        if axes_initializer is not None:
+            if axes_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Unsqueeze axes input must be int64 or int32, "
+                    f"got {axes_initializer.type.dtype.onnx_name}"
+                )
+            axes_values = [int(value) for value in axes_initializer.data.reshape(-1)]
+    elif axes_attr is not None:
+        axes_values = [int(value) for value in axes_attr]
+    if axes_values is None and axes_attr is None and len(node.inputs) != 2:
+        raise UnsupportedOpError("Unsqueeze requires axes")
+    if axes_values is None:
+        return None
+    if not axes_values:
+        raise UnsupportedOpError("Unsqueeze requires non-empty axes")
+    return tuple(axes_values)
+
+
+def _expected_output_shape(
+    input_shape: tuple[int, ...], axes: tuple[int, ...], node: Node
+) -> tuple[int, ...]:
+    output_rank = len(input_shape) + len(axes)
+    normalized_axes = _normalize_axes(list(axes), output_rank, node)
+    output_dims: list[int] = []
+    input_index = 0
+    for axis in range(output_rank):
+        if axis in normalized_axes:
+            output_dims.append(1)
+        else:
+            output_dims.append(input_shape[input_index])
+            input_index += 1
+    return tuple(output_dims)
+
+
+@register_lowering("Unsqueeze")
+def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
+    if len(node.outputs) != 1 or len(node.inputs) not in {1, 2}:
+        raise UnsupportedOpError("Unsqueeze must have 1 or 2 inputs and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    _validate_shape(input_shape, node, "input")
+    _validate_shape(output_shape, node, "output")
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Unsqueeze expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    axes = _resolve_axes(graph, node)
+    if axes is None:
+        if len(node.inputs) == 2:
+            axes_dtype = _value_dtype(graph, node.inputs[1], node)
+            if axes_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Unsqueeze axes input must be int64 or int32, "
+                    f"got {axes_dtype.onnx_name}"
+                )
+        if len(output_shape) <= len(input_shape):
+            raise ShapeInferenceError(
+                "Unsqueeze output rank must exceed input rank"
+            )
+        input_index = 0
+        for dim in output_shape:
+            if input_index < len(input_shape) and dim == input_shape[input_index]:
+                input_index += 1
+                continue
+            if dim != 1:
+                raise ShapeInferenceError(
+                    "Unsqueeze output shape must insert ones only"
+                )
+        if input_index != len(input_shape):
+            raise ShapeInferenceError(
+                "Unsqueeze output shape must contain input shape in order"
+            )
+    else:
+        expected_shape = _expected_output_shape(input_shape, axes, node)
+        if expected_shape != output_shape:
+            raise ShapeInferenceError(
+                "Unsqueeze output shape must be "
+                f"{expected_shape}, got {output_shape}"
+            )
+    return ReshapeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
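
For orientation, here is a minimal standalone sketch of the output-shape rule that _expected_output_shape above implements for Unsqueeze: the output rank grows by one per axis, each listed axis becomes a size-1 dimension, and the remaining slots take the input dimensions in order. The helper name and the example shapes are illustrative only, not part of the package.

def expected_unsqueeze_shape(input_shape, axes):
    # Mirrors _expected_output_shape: output rank grows by len(axes),
    # each listed axis becomes a 1, other slots take input dims in order.
    output_rank = len(input_shape) + len(axes)
    normalized = sorted(axis + output_rank if axis < 0 else axis for axis in axes)
    dims, input_index = [], 0
    for axis in range(output_rank):
        if axis in normalized:
            dims.append(1)
        else:
            dims.append(input_shape[input_index])
            input_index += 1
    return tuple(dims)

assert expected_unsqueeze_shape((3, 4), (0, 2)) == (1, 3, 1, 4)
assert expected_unsqueeze_shape((3, 4), (-1,)) == (3, 4, 1)
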
emx_onnx_cgen/lowering/variadic.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+from shared.scalar_functions import ScalarFunction
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import MultiInputBinaryOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..lowering.common import node_dtype, value_dtype, value_shape
+from ..lowering.registry import register_lowering
+from ..ops import binary_op_symbol
+
+VARIADIC_OP_FUNCTIONS: dict[str, ScalarFunction] = {
+    "Sum": ScalarFunction.ADD,
+    "Mean": ScalarFunction.MEAN,
+    "Max": ScalarFunction.MAXIMUM,
+    "Min": ScalarFunction.MINIMUM,
+    "And": ScalarFunction.LOGICAL_AND,
+    "Or": ScalarFunction.LOGICAL_OR,
+    "Xor": ScalarFunction.LOGICAL_XOR,
+    "BitwiseAnd": ScalarFunction.BITWISE_AND,
+    "BitwiseOr": ScalarFunction.BITWISE_OR,
+    "BitwiseXor": ScalarFunction.BITWISE_XOR,
+}
+
+BINARY_ONLY_OPS = {
+    "And",
+    "Or",
+    "Xor",
+    "BitwiseAnd",
+    "BitwiseOr",
+    "BitwiseXor",
+}
+
+
+def _validate_inputs(
+    graph: Graph, node: Node, *, function: ScalarFunction
+) -> tuple[ScalarType, tuple[int, ...]]:
+    if len(node.outputs) != 1:
+        raise UnsupportedOpError(f"{node.op_type} must have 1 output")
+    if node.op_type in BINARY_ONLY_OPS:
+        if len(node.inputs) != 2:
+            raise UnsupportedOpError(
+                f"{node.op_type} must have exactly 2 inputs"
+            )
+    elif len(node.inputs) < 2:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have at least 2 inputs"
+        )
+    for name in node.inputs:
+        if not name:
+            raise UnsupportedOpError(f"{node.op_type} input must be provided")
+    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            f"{node.op_type} expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    output_shape = value_shape(graph, node.outputs[0], node)
+    for name in node.inputs:
+        input_shape = value_shape(graph, name, node)
+        if input_shape != output_shape:
+            raise UnsupportedOpError(
+                f"{node.op_type} expects identical input/output shapes"
+            )
+    op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
+    if op_spec is None:
+        raise UnsupportedOpError(
+            f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
+        )
+    return op_dtype, output_shape
+
+
+def _lower_variadic(graph: Graph, node: Node) -> MultiInputBinaryOp:
+    function = VARIADIC_OP_FUNCTIONS[node.op_type]
+    op_dtype, output_shape = _validate_inputs(graph, node, function=function)
+    op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
+    if op_spec is None:
+        raise UnsupportedOpError(
+            f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
+        )
+    return MultiInputBinaryOp(
+        inputs=tuple(node.inputs),
+        output=node.outputs[0],
+        function=function,
+        operator_kind=op_spec.kind,
+        shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+    )
+
+
+for _op_type in VARIADIC_OP_FUNCTIONS:
+    register_lowering(_op_type)(_lower_variadic)
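
The trailing loop above registers the same _lower_variadic handler once per entry of VARIADIC_OP_FUNCTIONS. A hedged sketch of the decorator-call pattern this relies on follows; the registry here is a stand-in with assumed names, not the package's actual registry.py, which is not shown in this hunk.

# Hypothetical stand-in for the lowering registry; names are assumptions.
LOWERINGS: dict[str, object] = {}

def register_lowering(op_type):
    # Returns a decorator that records the handler under op_type.
    def decorator(func):
        LOWERINGS[op_type] = func
        return func
    return decorator

def _lower_variadic(graph, node):
    ...  # would build a MultiInputBinaryOp as in the module above

# One shared handler, registered once per variadic op type.
for op_type in ("Sum", "Mean", "Max", "Min"):
    register_lowering(op_type)(_lower_variadic)

assert LOWERINGS["Sum"] is LOWERINGS["Mean"] is _lower_variadic
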
emx_onnx_cgen/lowering/where.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import WhereOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _broadcast_shape(shapes: tuple[tuple[int, ...], ...], node: Node) -> tuple[int, ...]:
+    if not shapes:
+        return ()
+    max_rank = max(len(shape) for shape in shapes)
+    padded = [
+        (1,) * (max_rank - len(shape)) + shape
+        for shape in shapes
+    ]
+    broadcast: list[int] = []
+    for dims in zip(*padded):
+        dim = max(dims)
+        if any(item not in (1, dim) for item in dims):
+            raise ShapeInferenceError(
+                f"{node.op_type} inputs must be broadcastable, got {shapes}"
+            )
+        broadcast.append(dim)
+    return tuple(broadcast)
+
+
+@register_lowering("Where")
+def lower_where(graph: Graph, node: Node) -> WhereOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+    condition_name, x_name, y_name = node.inputs
+    output_name = node.outputs[0]
+    condition_dtype = _value_dtype(graph, condition_name, node)
+    if condition_dtype != ScalarType.BOOL:
+        raise UnsupportedOpError(
+            f"Where expects bool condition, got {condition_dtype.onnx_name}"
+        )
+    x_dtype = _value_dtype(graph, x_name, node)
+    y_dtype = _value_dtype(graph, y_name, node)
+    output_dtype = _value_dtype(graph, output_name, node)
+    if x_dtype != y_dtype or output_dtype != x_dtype:
+        raise UnsupportedOpError(
+            "Where expects matching input/output dtypes, "
+            f"got {x_dtype.onnx_name}, {y_dtype.onnx_name}, {output_dtype.onnx_name}"
+        )
+    condition_shape = _value_shape(graph, condition_name, node)
+    x_shape = _value_shape(graph, x_name, node)
+    y_shape = _value_shape(graph, y_name, node)
+    output_shape = _value_shape(graph, output_name, node)
+    broadcast_shape = _broadcast_shape(
+        (condition_shape, x_shape, y_shape),
+        node,
+    )
+    if output_shape != broadcast_shape:
+        raise ShapeInferenceError(
+            f"Where output shape must be {broadcast_shape}, got {output_shape}"
+        )
+    return WhereOp(
+        condition=condition_name,
+        input_x=x_name,
+        input_y=y_name,
+        output=output_name,
+        condition_shape=condition_shape,
+        x_shape=x_shape,
+        y_shape=y_shape,
+        output_shape=output_shape,
+        dtype=output_dtype,
+    )
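
_broadcast_shape above follows standard NumPy-style broadcasting over the three Where inputs: shapes are right-aligned, padded with 1s, and each output dimension is the per-position maximum, with every other entry required to be 1 or equal to that maximum. A worked example with illustrative shapes:

# Right-align and pad with 1s:
#   (2, 1, 4)
#   (1, 3, 1)
#   (4,)  ->  (1, 1, 4)
# Per-position maximum, every other entry must be 1 or equal to it:
#   result (2, 3, 4)
import numpy as np
assert np.broadcast_shapes((2, 1, 4), (1, 3, 1), (4,)) == (2, 3, 4)
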
emx_onnx_cgen/onnx_import.py
@@ -0,0 +1,261 @@
+from __future__ import annotations
+
+from typing import Iterable, Mapping
+
+import onnx
+import numpy as np
+from onnx import helper, numpy_helper, shape_inference
+
+from shared.scalar_types import ScalarType
+
+from .dtypes import scalar_type_from_onnx
+from .errors import ShapeInferenceError, UnsupportedOpError
+from .ir.model import Graph, Initializer, Node, TensorType, Value
+
+
+def _normalize_initializer_data(dtype: ScalarType, data: object) -> np.ndarray:
+    if isinstance(data, (onnx.TensorProto, onnx.SparseTensorProto)):
+        array = numpy_helper.to_array(data)
+    elif isinstance(data, np.ndarray):
+        array = data
+    else:
+        array = np.array(data)
+    return array.astype(dtype.np_dtype, copy=False)
+
+
+def _format_elem_type(elem_type: int) -> str:
+    try:
+        name = onnx.TensorProto.DataType.Name(elem_type)
+    except ValueError:
+        name = "UNKNOWN"
+    return f"{elem_type} ({name})"
+
+
+def _unsupported_value_type(value_info: onnx.ValueInfoProto) -> UnsupportedOpError:
+    value_kind = value_info.type.WhichOneof("value")
+    if value_kind is None:
+        value_kind = "unknown"
+    return UnsupportedOpError(
+        f"Unsupported value type '{value_kind}' for '{value_info.name}'. "
+        "Hint: export the model with tensor inputs/outputs."
+    )
+
+
+def _tensor_type(
+    value_info: onnx.ValueInfoProto,
+    *,
+    dim_param_override: tuple[str | None, ...] | None = None,
+) -> TensorType:
+    if value_info.type.WhichOneof("value") != "tensor_type":
+        raise _unsupported_value_type(value_info)
+    tensor_type = value_info.type.tensor_type
+    if not tensor_type.HasField("elem_type"):
+        raise ShapeInferenceError(f"Missing elem_type for tensor '{value_info.name}'")
+    dtype = scalar_type_from_onnx(tensor_type.elem_type)
+    if dtype is None:
+        raise UnsupportedOpError(
+            "Unsupported elem_type "
+            f"{_format_elem_type(tensor_type.elem_type)} for tensor '{value_info.name}'."
+        )
+    shape = []
+    dim_params = []
+    for dim_index, dim in enumerate(tensor_type.shape.dim):
+        dim_param = dim.dim_param if dim.HasField("dim_param") else ""
+        if (
+            dim_param_override is not None
+            and dim_index < len(dim_param_override)
+            and dim_param_override[dim_index]
+        ):
+            dim_param = dim_param_override[dim_index] or ""
+        dim_params.append(dim_param or None)
+        if not dim.HasField("dim_value"):
+            if dim_param:
+                shape.append(1)
+                continue
+            raise ShapeInferenceError(f"Dynamic dim for tensor '{value_info.name}'")
+        shape.append(dim.dim_value)
+    return TensorType(
+        dtype=dtype,
+        shape=tuple(shape),
+        dim_params=tuple(dim_params),
+    )
+
+
+def _values(
+    value_infos: Iterable[onnx.ValueInfoProto],
+    *,
+    dim_param_by_name: Mapping[str, tuple[str | None, ...]] | None = None,
+) -> tuple[Value, ...]:
+    dim_param_by_name = dim_param_by_name or {}
+    return tuple(
+        Value(
+            name=vi.name,
+            type=_tensor_type(
+                vi, dim_param_override=dim_param_by_name.get(vi.name)
+            ),
+        )
+        for vi in value_infos
+    )
+
+
+def _collect_dim_params(
+    value_infos: Iterable[onnx.ValueInfoProto],
+) -> dict[str, tuple[str | None, ...]]:
+    dim_params: dict[str, tuple[str | None, ...]] = {}
+    for value_info in value_infos:
+        dims = []
+        for dim in value_info.type.tensor_type.shape.dim:
+            dim_param = dim.dim_param if dim.HasField("dim_param") else ""
+            dims.append(dim_param or None)
+        if any(dims):
+            dim_params[value_info.name] = tuple(dims)
+    return dim_params
+
+
+def _initializer(value: onnx.TensorProto) -> Initializer:
+    dtype = scalar_type_from_onnx(value.data_type)
+    if dtype is None:
+        raise UnsupportedOpError(
+            "Unsupported elem_type "
+            f"{_format_elem_type(value.data_type)} for initializer '{value.name}'. "
+            "Hint: export the model with float32 initializers."
+        )
+    data = _normalize_initializer_data(dtype, value)
+    return Initializer(
+        name=value.name,
+        type=TensorType(
+            dtype=dtype,
+            shape=tuple(data.shape),
+            dim_params=(None,) * len(data.shape),
+        ),
+        data=data,
+    )
+
+
+def _node_attrs(node: onnx.NodeProto) -> dict[str, object]:
+    return {attr.name: helper.get_attribute_value(attr) for attr in node.attribute}
+
+
+def _constant_initializer(node: onnx.NodeProto) -> Initializer:
+    if len(node.output) != 1:
+        raise UnsupportedOpError("Constant must have exactly one output")
+    attrs = _node_attrs(node)
+    output_name = node.output[0]
+    if "value" in attrs:
+        tensor = attrs["value"]
+        dtype = scalar_type_from_onnx(tensor.data_type)
+        if dtype is None:
+            raise UnsupportedOpError(
+                "Unsupported elem_type "
+                f"{_format_elem_type(tensor.data_type)} for Constant '{output_name}'."
+            )
+        data = _normalize_initializer_data(dtype, tensor)
+        return Initializer(
+            name=output_name,
+            type=TensorType(
+                dtype=dtype,
+                shape=tuple(data.shape),
+                dim_params=(None,) * len(data.shape),
+            ),
+            data=data,
+        )
+    if "sparse_value" in attrs:
+        tensor = attrs["sparse_value"]
+        dtype = scalar_type_from_onnx(tensor.values.data_type)
+        if dtype is None:
+            raise UnsupportedOpError(
+                "Unsupported elem_type "
+                f"{_format_elem_type(tensor.values.data_type)} for Constant '{output_name}'."
+            )
+        data = _normalize_initializer_data(dtype, tensor)
+        return Initializer(
+            name=output_name,
+            type=TensorType(
+                dtype=dtype,
+                shape=tuple(data.shape),
+                dim_params=(None,) * len(data.shape),
+            ),
+            data=data,
+        )
+    if "value_float" in attrs or "value_floats" in attrs:
+        values = attrs.get("value_floats", attrs.get("value_float"))
+        data = _normalize_initializer_data(ScalarType.F32, values)
+        return Initializer(
+            name=output_name,
+            type=TensorType(
+                dtype=ScalarType.F32,
+                shape=tuple(data.shape),
+                dim_params=(None,) * len(data.shape),
+            ),
+            data=data,
+        )
+    if "value_int" in attrs or "value_ints" in attrs:
+        values = attrs.get("value_ints", attrs.get("value_int"))
+        data = _normalize_initializer_data(ScalarType.I64, values)
+        return Initializer(
+            name=output_name,
+            type=TensorType(
+                dtype=ScalarType.I64,
+                shape=tuple(data.shape),
+                dim_params=(None,) * len(data.shape),
+            ),
+            data=data,
+        )
+    if "value_string" in attrs or "value_strings" in attrs:
+        raise UnsupportedOpError(
+            f"Constant '{output_name}' has unsupported string values"
+        )
+    raise UnsupportedOpError(f"Constant '{output_name}' requires a value attribute")
+
+
+def import_onnx(model: onnx.ModelProto) -> Graph:
+    dim_param_by_name = _collect_dim_params(
+        tuple(model.graph.input) + tuple(model.graph.output)
+    )
+    try:
+        model = shape_inference.infer_shapes(model, data_prop=True)
+    except Exception as exc:  # pragma: no cover - onnx inference errors
+        raise ShapeInferenceError("ONNX shape inference failed") from exc
+    graph = model.graph
+    base_initializers = [_initializer(value) for value in graph.initializer]
+    constant_initializers: list[Initializer] = []
+    input_names = {value_info.name for value_info in graph.input}
+    output_names = {value_info.name for value_info in graph.output}
+    nodes: list[Node] = []
+    for node in graph.node:
+        if node.op_type == "Constant":
+            constant_initializers.append(_constant_initializer(node))
+            continue
+        nodes.append(
+            Node(
+                op_type=node.op_type,
+                name=node.name or None,
+                inputs=tuple(node.input),
+                outputs=tuple(node.output),
+                attrs=_node_attrs(node),
+            )
+        )
+    initializers = tuple(base_initializers + constant_initializers)
+    initializer_names = {initializer.name for initializer in initializers}
+    inputs = _values(
+        (
+            value_info
+            for value_info in graph.input
+            if value_info.name not in initializer_names
+        ),
+        dim_param_by_name=dim_param_by_name,
+    )
+    outputs = _values(graph.output, dim_param_by_name=dim_param_by_name)
+    values = _values(
+        value_info
+        for value_info in graph.value_info
+        if value_info.name
+        not in initializer_names | input_names | output_names
+    )
+    return Graph(
+        inputs=inputs,
+        outputs=outputs,
+        nodes=nodes,
+        initializers=initializers,
+        values=values,
+    )
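
For orientation, a hedged usage sketch of import_onnx: build a tiny single-node model with the public onnx.helper API and convert it into the package's IR Graph. The model contents below are illustrative; the import path follows the file listing above, and only documented onnx.helper calls are used.

from onnx import TensorProto, helper

from emx_onnx_cgen.onnx_import import import_onnx

# A one-node (Relu) model with static shapes, built with the official helpers.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])
relu = helper.make_node("Relu", ["x"], ["y"])
graph_proto = helper.make_graph([relu], "demo", [x], [y])
model = helper.make_model(graph_proto, opset_imports=[helper.make_opsetid("", 17)])

graph = import_onnx(model)  # runs ONNX shape inference internally, then builds the IR
print([value.name for value in graph.inputs], [value.name for value in graph.outputs])
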