emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +340 -59
- emx_onnx_cgen/codegen/c_emitter.py +2369 -111
- emx_onnx_cgen/compiler.py +188 -5
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/lowering/common.py +379 -2
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/gather_elements.py +1 -3
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +6 -5
- emx_onnx_cgen/lowering/logsoftmax.py +5 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/matmul.py +6 -7
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/reduce.py +5 -6
- emx_onnx_cgen/lowering/reshape.py +223 -51
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/softmax.py +5 -1
- emx_onnx_cgen/lowering/squeeze.py +5 -5
- emx_onnx_cgen/lowering/topk.py +116 -0
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +5 -5
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +460 -42
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +61 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
- shared/scalar_functions.py +49 -17
- shared/ulp.py +48 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/einsum.py ADDED
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import EinsumKind, EinsumOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _normalize_equation(equation: str) -> str:
+    return equation.replace(" ", "")
+
+
+@register_lowering("Einsum")
+def lower_einsum(graph: Graph, node: Node) -> EinsumOp:
+    if not node.inputs or len(node.outputs) != 1:
+        raise UnsupportedOpError("Einsum must have 1 output and at least 1 input")
+    equation_value = node.attrs.get("equation")
+    if equation_value is None:
+        raise UnsupportedOpError("Einsum equation attribute is required")
+    equation = (
+        equation_value.decode()
+        if isinstance(equation_value, (bytes, bytearray))
+        else str(equation_value)
+    )
+    normalized = _normalize_equation(equation)
+    input_shapes = tuple(
+        _value_shape(graph, name, node) for name in node.inputs
+    )
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if normalized == "->":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum '->' must have 1 input")
+        if output_shape:
+            raise ShapeInferenceError(
+                "Einsum '->' output must be scalar, "
+                f"got shape {output_shape}"
+            )
+        kind = EinsumKind.REDUCE_ALL
+    elif normalized == "ij->i":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum 'ij->i' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) != 2:
+            raise ShapeInferenceError(
+                "Einsum 'ij->i' input must be 2D, "
+                f"got shape {input_shape}"
+            )
+        expected = (input_shape[0],)
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'ij->i' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.SUM_J
+    elif normalized == "ij->ji":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum 'ij->ji' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) != 2:
+            raise ShapeInferenceError(
+                "Einsum 'ij->ji' input must be 2D, "
+                f"got shape {input_shape}"
+            )
+        expected = (input_shape[1], input_shape[0])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'ij->ji' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.TRANSPOSE
+    elif normalized in {"i,i", "i,i->"}:
+        if len(node.inputs) != 2:
+            raise UnsupportedOpError("Einsum 'i,i' must have 2 inputs")
+        left_shape, right_shape = input_shapes
+        if len(left_shape) != 1 or len(right_shape) != 1:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' inputs must be vectors, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[0] != right_shape[0]:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' inputs must have the same length, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if output_shape:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' output must be scalar, "
+                f"got shape {output_shape}"
+            )
+        kind = EinsumKind.DOT
+    elif normalized == "bij,bjk->bik":
+        if len(node.inputs) != 2:
+            raise UnsupportedOpError("Einsum 'bij,bjk->bik' must have 2 inputs")
+        left_shape, right_shape = input_shapes
+        if len(left_shape) != 3 or len(right_shape) != 3:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' inputs must be 3D, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[0] != right_shape[0]:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' batch dimensions must match, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[2] != right_shape[1]:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' contraction dimensions must match, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        expected = (left_shape[0], left_shape[1], right_shape[2])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'bij,bjk->bik' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.BATCH_MATMUL
+    elif normalized == "...ii->...i":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum '...ii->...i' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) < 2:
+            raise ShapeInferenceError(
+                "Einsum '...ii->...i' input must be at least 2D, "
+                f"got shape {input_shape}"
+            )
+        if input_shape[-1] != input_shape[-2]:
+            raise ShapeInferenceError(
+                "Einsum '...ii->...i' requires last two dims to match, "
+                f"got shape {input_shape}"
+            )
+        expected = (*input_shape[:-2], input_shape[-1])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum '...ii->...i' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.BATCH_DIAGONAL
+    else:
+        raise UnsupportedOpError(
+            f"Unsupported Einsum equation '{equation}'"
+        )
+    return EinsumOp(
+        inputs=tuple(node.inputs),
+        output=node.outputs[0],
+        kind=kind,
+        input_shapes=input_shapes,
+        output_shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+    )
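For context on the new Einsum lowering above: each supported equation string maps to a fixed kernel kind. A minimal sketch (illustrative values only, checked against numpy.einsum rather than the generated C) of what the supported forms compute:

import numpy as np

a = np.arange(6.0).reshape(2, 3)
v = np.array([1.0, 2.0, 3.0])
b = np.arange(24.0).reshape(2, 3, 4)
c = np.arange(24.0).reshape(2, 4, 3)
m = np.arange(18.0).reshape(2, 3, 3)

print(np.einsum("ij->i", a).shape)            # (2,)      -> EinsumKind.SUM_J
print(np.einsum("ij->ji", a).shape)           # (3, 2)    -> EinsumKind.TRANSPOSE
print(np.einsum("i,i->", v, v))               # 14.0      -> EinsumKind.DOT
print(np.einsum("bij,bjk->bik", b, c).shape)  # (2, 3, 3) -> EinsumKind.BATCH_MATMUL
print(np.einsum("...ii->...i", m).shape)      # (2, 3)    -> EinsumKind.BATCH_DIAGONAL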
emx_onnx_cgen/lowering/gather_elements.py CHANGED
@@ -33,9 +33,7 @@ def lower_gather_elements(graph: Graph, node: Node) -> GatherElementsOp:
     for dim_index, (data_dim, index_dim) in enumerate(
         zip(data_shape, indices_shape)
     ):
-        if dim_index
-            continue
-        if data_dim != index_dim:
+        if dim_index != axis and data_dim != index_dim:
             raise ShapeInferenceError(
                 "GatherElements inputs must match on non-axis dimensions, "
                 f"got {data_shape} and {indices_shape}"
emx_onnx_cgen/lowering/gather_nd.py ADDED
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import GatherNDOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("GatherND")
+def lower_gather_nd(graph: Graph, node: Node) -> GatherNDOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("GatherND must have 2 inputs and 1 output")
+    data_name, indices_name = node.inputs
+    output_name = node.outputs[0]
+    data_shape = _value_shape(graph, data_name, node)
+    indices_shape = _value_shape(graph, indices_name, node)
+    output_shape = _value_shape(graph, output_name, node)
+    if len(indices_shape) < 1:
+        raise ShapeInferenceError("GatherND indices must have rank >= 1")
+    batch_dims = int(node.attrs.get("batch_dims", 0))
+    if batch_dims < 0:
+        raise ShapeInferenceError(
+            f"GatherND batch_dims must be >= 0, got {batch_dims}"
+        )
+    if batch_dims > len(indices_shape) - 1:
+        raise ShapeInferenceError(
+            "GatherND batch_dims must be <= indices rank - 1, "
+            f"got {batch_dims} vs {len(indices_shape) - 1}"
+        )
+    if batch_dims > len(data_shape):
+        raise ShapeInferenceError(
+            "GatherND batch_dims must be <= data rank, "
+            f"got {batch_dims} vs {len(data_shape)}"
+        )
+    if tuple(data_shape[:batch_dims]) != tuple(indices_shape[:batch_dims]):
+        raise ShapeInferenceError(
+            "GatherND batch_dims must match on data/indices, "
+            f"got {data_shape} vs {indices_shape}"
+        )
+    index_depth = indices_shape[-1]
+    if index_depth <= 0:
+        raise ShapeInferenceError(
+            "GatherND indices final dimension must be >= 1"
+        )
+    if index_depth > len(data_shape) - batch_dims:
+        raise ShapeInferenceError(
+            "GatherND indices final dimension must be <= data rank - "
+            f"batch_dims, got {index_depth} vs {len(data_shape) - batch_dims}"
+        )
+    expected_output_shape = indices_shape[:-1] + data_shape[
+        batch_dims + index_depth :
+    ]
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "GatherND output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    data_dtype = _value_dtype(graph, data_name, node)
+    indices_dtype = _value_dtype(graph, indices_name, node)
+    if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "GatherND indices must be int32 or int64, "
+            f"got {indices_dtype.onnx_name}"
+        )
+    return GatherNDOp(
+        data=data_name,
+        indices=indices_name,
+        output=output_name,
+        batch_dims=batch_dims,
+        data_shape=data_shape,
+        indices_shape=indices_shape,
+        output_shape=output_shape,
+        dtype=data_dtype,
+        indices_dtype=indices_dtype,
+    )
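The output-shape rule the new GatherND lowering enforces is indices_shape[:-1] + data_shape[batch_dims + index_depth:]. A small sketch (illustrative values, batch_dims=0 only; not the package's evaluator) that reproduces it with numpy:

import numpy as np

data = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # data_shape = (2, 3, 4)
indices = np.array([[0, 1], [1, 2]])           # indices_shape = (2, 2), index_depth = 2
batch_dims = 0

expected = indices.shape[:-1] + data.shape[batch_dims + indices.shape[-1]:]
print(expected)  # (2, 4)

# Reference gather: each row of `indices` selects one slice of `data`.
out = np.stack([data[tuple(idx)] for idx in indices])
print(out.shape)  # (2, 4) -- matches expected_output_shape in the lowering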
emx_onnx_cgen/lowering/global_max_pool.py ADDED
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ReduceOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("GlobalMaxPool")
+def lower_global_max_pool(graph: Graph, node: Node) -> ReduceOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("GlobalMaxPool must have 1 input and 1 output")
+    if node.attrs:
+        raise UnsupportedOpError("GlobalMaxPool has unsupported attributes")
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "GlobalMaxPool expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "GlobalMaxPool supports float16, float, and double inputs only"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError(
+            "GlobalMaxPool expects input rank of at least 3"
+        )
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (input_shape[0], input_shape[1]) + (
+        1,
+    ) * (len(input_shape) - 2)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "GlobalMaxPool output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    axes = tuple(range(2, len(input_shape)))
+    return ReduceOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        axes=axes,
+        axes_input=None,
+        axes_input_shape=None,
+        axes_input_dtype=None,
+        keepdims=True,
+        noop_with_empty_axes=False,
+        reduce_kind="max",
+        reduce_count=None,
+        dtype=op_dtype,
+    )
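GlobalMaxPool is lowered onto the existing ReduceOp with reduce_kind="max", keepdims=True, and axes covering every spatial dimension. A hedged equivalent check (illustrative shapes only):

import numpy as np

x = np.random.rand(1, 3, 5, 7).astype(np.float32)   # NCHW input
axes = tuple(range(2, x.ndim))                       # same axes the lowering picks
y = x.max(axis=axes, keepdims=True)
print(y.shape)  # (1, 3, 1, 1) -- the expected_output_shape validated above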
emx_onnx_cgen/lowering/hardmax.py ADDED
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import HardmaxOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
+from .common import shape_product as _shape_product
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+from ..validation import ensure_output_shape_matches_input
+from ..validation import normalize_axis as _normalize_axis
+
+
+@register_lowering("Hardmax")
+def lower_hardmax(graph: Graph, node: Node) -> HardmaxOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Hardmax must have 1 input and 1 output")
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "Hardmax supports float16, float, and double inputs only"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
+    axis = _normalize_axis(
+        int(axis_attr),
+        input_shape,
+        node,
+    )
+    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+    axis_size = input_shape[axis]
+    inner = (
+        _shape_product(input_shape[axis + 1 :])
+        if axis + 1 < len(input_shape)
+        else 1
+    )
+    return HardmaxOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        outer=outer,
+        axis_size=axis_size,
+        inner=inner,
+        axis=axis,
+        shape=input_shape,
+        dtype=op_dtype,
+    )
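The Hardmax lowering flattens the tensor into outer x axis_size x inner slices around the normalized axis. A short sketch of that arithmetic (illustrative shape and axis, not the package's helpers):

import math

input_shape = (2, 3, 4, 5)
axis = 2  # already normalized to a non-negative index

outer = math.prod(input_shape[:axis]) if axis > 0 else 1
axis_size = input_shape[axis]
inner = math.prod(input_shape[axis + 1:]) if axis + 1 < len(input_shape) else 1
print(outer, axis_size, inner)  # 6 4 5: the kernel scans 6 * 5 runs of length 4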
emx_onnx_cgen/lowering/identity.py CHANGED
@@ -22,11 +22,12 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
     for index, (input_dim, output_dim) in enumerate(
         zip(input_shape, output_shape)
     ):
-        if input_dim
-
-
-
-
+        if input_dim != output_dim and not (
+            input_dim_params[index] or output_dim_params[index]
+        ):
+            raise ShapeInferenceError(
+                "Identity input and output shapes must match"
+            )
     input_dtype = value_dtype(graph, node.inputs[0], node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:
emx_onnx_cgen/lowering/logsoftmax.py CHANGED
@@ -4,6 +4,7 @@ from ..codegen.c_emitter import LogSoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
 from .common import shape_product as _shape_product
 from .common import value_shape as _value_shape
 from .registry import register_lowering
@@ -23,8 +24,11 @@ def lower_logsoftmax(graph: Graph, node: Node) -> LogSoftmaxOp:
     input_shape = _value_shape(graph, node.inputs[0], node)
     output_shape = _value_shape(graph, node.outputs[0], node)
     ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
     axis = _normalize_axis(
-        int(
+        int(axis_attr),
         input_shape,
         node,
     )
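The added lines make the default axis opset-dependent: ONNX moved LogSoftmax's default axis from 1 to -1 in opset 13, which is what the new default_axis expression encodes. A tiny standalone illustration of the same rule (the function name is only for this example):

def default_logsoftmax_axis(opset_version):
    # Mirrors the lowering: 1 before opset 13, otherwise (or when unknown) -1.
    return 1 if opset_version is not None and opset_version < 13 else -1

print(default_logsoftmax_axis(11))    # 1
print(default_logsoftmax_axis(13))    # -1
print(default_logsoftmax_axis(None))  # -1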
emx_onnx_cgen/lowering/lp_pool.py ADDED
@@ -0,0 +1,141 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ..codegen.c_emitter import LpPoolOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .registry import register_lowering
+from .common import value_dtype as _value_dtype, value_shape as _value_shape
+
+
+@dataclass(frozen=True)
+class LpPoolSpec:
+    batch: int
+    channels: int
+    in_h: int
+    in_w: int
+    out_h: int
+    out_w: int
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_top: int
+    pad_left: int
+    pad_bottom: int
+    pad_right: int
+    p: int
+
+
+def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("LpPool must have 1 input and 1 output")
+    supported_attrs = {
+        "auto_pad",
+        "ceil_mode",
+        "dilations",
+        "kernel_shape",
+        "pads",
+        "p",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("LpPool has unsupported attributes")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad not in ("", "NOTSET"):
+        raise UnsupportedOpError("LpPool supports auto_pad=NOTSET only")
+    ceil_mode = int(node.attrs.get("ceil_mode", 0))
+    if ceil_mode != 0:
+        raise UnsupportedOpError("LpPool supports ceil_mode=0 only")
+    dilations = tuple(int(value) for value in node.attrs.get("dilations", (1, 1)))
+    if any(value != 1 for value in dilations):
+        raise UnsupportedOpError("LpPool supports dilations=1 only")
+    kernel_shape = node.attrs.get("kernel_shape")
+    if kernel_shape is None:
+        raise UnsupportedOpError("LpPool requires kernel_shape")
+    kernel_shape = tuple(int(value) for value in kernel_shape)
+    if len(kernel_shape) != 2:
+        raise UnsupportedOpError("LpPool expects 2D kernel_shape")
+    kernel_h, kernel_w = kernel_shape
+    strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
+    if len(strides) != 2:
+        raise UnsupportedOpError("LpPool expects 2D strides")
+    pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
+    if len(pads) != 4:
+        raise UnsupportedOpError("LpPool expects 4D pads")
+    pad_top, pad_left, pad_bottom, pad_right = pads
+    p = int(node.attrs.get("p", 2))
+    if p < 1:
+        raise UnsupportedOpError("LpPool p must be >= 1")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) != 4:
+        raise UnsupportedOpError("LpPool supports NCHW 2D inputs only")
+    batch, channels, in_h, in_w = input_shape
+    stride_h, stride_w = strides
+    out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
+    out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
+    if out_h < 0 or out_w < 0:
+        raise ShapeInferenceError("LpPool output shape must be non-negative")
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, channels, out_h, out_w)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "LpPool output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return LpPoolSpec(
+        batch=batch,
+        channels=channels,
+        in_h=in_h,
+        in_w=in_w,
+        out_h=out_h,
+        out_w=out_w,
+        kernel_h=kernel_h,
+        kernel_w=kernel_w,
+        stride_h=stride_h,
+        stride_w=stride_w,
+        pad_top=pad_top,
+        pad_left=pad_left,
+        pad_bottom=pad_bottom,
+        pad_right=pad_right,
+        p=p,
+    )
+
+
+@register_lowering("LpPool")
+def lower_lp_pool(graph: Graph, node: Node) -> LpPoolOp:
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "LpPool expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "LpPool supports float16, float, and double inputs only"
+        )
+    spec = _resolve_lp_pool_spec(graph, node)
+    return LpPoolOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        batch=spec.batch,
+        channels=spec.channels,
+        in_h=spec.in_h,
+        in_w=spec.in_w,
+        out_h=spec.out_h,
+        out_w=spec.out_w,
+        kernel_h=spec.kernel_h,
+        kernel_w=spec.kernel_w,
+        stride_h=spec.stride_h,
+        stride_w=spec.stride_w,
+        pad_top=spec.pad_top,
+        pad_left=spec.pad_left,
+        pad_bottom=spec.pad_bottom,
+        pad_right=spec.pad_right,
+        p=spec.p,
+        dtype=op_dtype,
+    )
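Two pieces of arithmetic drive the new LpPool lowering: the floor-division output size and the p-norm of each window. A hedged check of both (illustrative numbers, not the emitted C):

import numpy as np

in_h, kernel_h, stride_h, pad_top, pad_bottom = 7, 3, 2, 0, 0
out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
print(out_h)  # 3

p = 2
window = np.array([1.0, -2.0, 3.0])
print(np.sum(np.abs(window) ** p) ** (1.0 / p))  # ~3.742, the Lp value of one pooling window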
emx_onnx_cgen/lowering/matmul.py CHANGED
@@ -87,13 +87,12 @@ def _broadcast_batch_shapes(
     right_padded = (1,) * (max_rank - len(right)) + right
     broadcast_shape = []
     for left_dim, right_dim in zip(left_padded, right_padded):
-        if left_dim == right_dim or left_dim == 1 or right_dim == 1:
-
-
-
-
-
-            )
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            raise ShapeInferenceError(
+                "MatMul batch dimensions must be broadcastable, "
+                f"got {left} x {right}"
+            )
+        broadcast_shape.append(max(left_dim, right_dim))
     return tuple(broadcast_shape), left_padded, right_padded
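The rewritten loop above raises as soon as a batch-dimension pair is not broadcastable instead of silently skipping it. A standalone sketch of the same rule (broadcast_batch is an illustrative name, not the package's function):

def broadcast_batch(left, right):
    max_rank = max(len(left), len(right))
    left_padded = (1,) * (max_rank - len(left)) + left
    right_padded = (1,) * (max_rank - len(right)) + right
    out = []
    for left_dim, right_dim in zip(left_padded, right_padded):
        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
            raise ValueError(f"not broadcastable: {left} x {right}")
        out.append(max(left_dim, right_dim))
    return tuple(out)

print(broadcast_batch((2, 1), (3,)))  # (2, 3)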
|
emx_onnx_cgen/lowering/reshape.py CHANGED
@@ -43,18 +43,18 @@ def _resolve_target_shape(
                 raise ShapeInferenceError("Reshape allows only one -1 dimension")
             unknown_index = index
             output_dims.append(-1)
-
-
-
-
-
-
-
-
-
-
-
-
+        else:
+            if dim == 0:
+                if allowzero == 0:
+                    if index >= len(input_shape):
+                        raise ShapeInferenceError(
+                            "Reshape zero dim must index into input shape"
+                        )
+                    dim = input_shape[index]
+            if dim < 0:
+                raise ShapeInferenceError("Reshape dims must be >= -1")
+            output_dims.append(dim)
+            known_product *= dim
     input_product = _shape_product(input_shape)
     if unknown_index is not None:
         if known_product == 0 or input_product % known_product != 0:
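The restored else-branch gives Reshape its ONNX semantics for 0 and -1 entries: with allowzero == 0 a zero copies the corresponding input dimension, and a single -1 is inferred from the remaining element count. A simplified stand-in (error checks omitted; not the package's _resolve_target_shape):

import math

def resolve_target_shape(input_shape, target, allowzero=0):
    dims = []
    for index, dim in enumerate(target):
        if dim == 0 and allowzero == 0:
            dim = input_shape[index]          # copy the input dimension
        dims.append(dim)
    if -1 in dims:
        known = math.prod(d for d in dims if d != -1)
        dims[dims.index(-1)] = math.prod(input_shape) // known
    return tuple(dims)

print(resolve_target_shape((2, 3, 4), (0, -1)))  # (2, 12)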
emx_onnx_cgen/lowering/nonzero.py ADDED
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import NonZeroOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+@register_lowering("NonZero")
+def lower_nonzero(graph: Graph, node: Node) -> NonZeroOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("NonZero must have 1 input and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    if len(input_shape) == 0:
+        raise UnsupportedOpError("NonZero does not support scalar inputs")
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(output_shape) != 2:
+        raise ShapeInferenceError("NonZero output must be 2D")
+    if output_shape[0] != len(input_shape):
+        raise ShapeInferenceError(
+            "NonZero output shape must be "
+            f"({len(input_shape)}, N), got {output_shape}"
+        )
+    if output_shape[0] < 0 or output_shape[1] < 0:
+        raise ShapeInferenceError(
+            "NonZero output shape must be non-negative"
+        )
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if output_dtype != ScalarType.I64:
+        raise UnsupportedOpError("NonZero output dtype must be int64")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    return NonZeroOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
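The new NonZero lowering checks the shape contract of the ONNX op: an int64 output of shape (input_rank, N), one column per non-zero element. A quick numpy illustration (values are only an example):

import numpy as np

x = np.array([[1, 0], [0, 3]])
coords = np.stack(np.nonzero(x)).astype(np.int64)
print(coords.shape)  # (2, 2): rank 2, two non-zero elements
print(coords)        # [[0 1]
                     #  [0 1]]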