emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of emx-onnx-cgen has been flagged as a potentially problematic release.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/expand.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ExpandOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape
emx_onnx_cgen/lowering/gather.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import GatherOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import normalize_axis
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import GatherElementsOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import normalize_axis

@@ -33,9 +33,7 @@ def lower_gather_elements(graph: Graph, node: Node) -> GatherElementsOp:
     for dim_index, (data_dim, index_dim) in enumerate(
         zip(data_shape, indices_shape)
     ):
-        if dim_index
-            continue
-        if data_dim != index_dim:
+        if dim_index != axis and data_dim != index_dim:
             raise ShapeInferenceError(
                 "GatherElements inputs must match on non-axis dimensions, "
                 f"got {data_shape} and {indices_shape}"
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import GatherNDOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("GatherND")
+def lower_gather_nd(graph: Graph, node: Node) -> GatherNDOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("GatherND must have 2 inputs and 1 output")
+    data_name, indices_name = node.inputs
+    output_name = node.outputs[0]
+    data_shape = _value_shape(graph, data_name, node)
+    indices_shape = _value_shape(graph, indices_name, node)
+    output_shape = _value_shape(graph, output_name, node)
+    if len(indices_shape) < 1:
+        raise ShapeInferenceError("GatherND indices must have rank >= 1")
+    batch_dims = int(node.attrs.get("batch_dims", 0))
+    if batch_dims < 0:
+        raise ShapeInferenceError(
+            f"GatherND batch_dims must be >= 0, got {batch_dims}"
+        )
+    if batch_dims > len(indices_shape) - 1:
+        raise ShapeInferenceError(
+            "GatherND batch_dims must be <= indices rank - 1, "
+            f"got {batch_dims} vs {len(indices_shape) - 1}"
+        )
+    if batch_dims > len(data_shape):
+        raise ShapeInferenceError(
+            "GatherND batch_dims must be <= data rank, "
+            f"got {batch_dims} vs {len(data_shape)}"
+        )
+    if tuple(data_shape[:batch_dims]) != tuple(indices_shape[:batch_dims]):
+        raise ShapeInferenceError(
+            "GatherND batch_dims must match on data/indices, "
+            f"got {data_shape} vs {indices_shape}"
+        )
+    index_depth = indices_shape[-1]
+    if index_depth <= 0:
+        raise ShapeInferenceError(
+            "GatherND indices final dimension must be >= 1"
+        )
+    if index_depth > len(data_shape) - batch_dims:
+        raise ShapeInferenceError(
+            "GatherND indices final dimension must be <= data rank - "
+            f"batch_dims, got {index_depth} vs {len(data_shape) - batch_dims}"
+        )
+    expected_output_shape = indices_shape[:-1] + data_shape[
+        batch_dims + index_depth :
+    ]
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "GatherND output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    data_dtype = _value_dtype(graph, data_name, node)
+    indices_dtype = _value_dtype(graph, indices_name, node)
+    if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "GatherND indices must be int32 or int64, "
+            f"got {indices_dtype.onnx_name}"
+        )
+    return GatherNDOp(
+        data=data_name,
+        indices=indices_name,
+        output=output_name,
+        batch_dims=batch_dims,
+        data_shape=data_shape,
+        indices_shape=indices_shape,
+        output_shape=output_shape,
+        dtype=data_dtype,
+        indices_dtype=indices_dtype,
+    )
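The shape rule that this new GatherND lowering enforces can be checked by hand. The sketch below (plain Python with hypothetical example values, not part of the package) mirrors the expected_output_shape computation above:

def gathernd_output_shape(data_shape, indices_shape, batch_dims):
    # Mirrors the check in lower_gather_nd: drop the final indices dimension,
    # then append the data dimensions that the index tuples do not consume.
    index_depth = indices_shape[-1]
    return indices_shape[:-1] + data_shape[batch_dims + index_depth:]

# Example: data (2, 3, 4), indices (2, 5, 1), batch_dims=1
# -> index_depth=1, so the output shape is (2, 5) + (4,) == (2, 5, 4)
assert gathernd_output_shape((2, 3, 4), (2, 5, 1), 1) == (2, 5, 4)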
emx_onnx_cgen/lowering/gemm.py
CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import GemmOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import ReduceOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("GlobalMaxPool")
+def lower_global_max_pool(graph: Graph, node: Node) -> ReduceOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("GlobalMaxPool must have 1 input and 1 output")
+    if node.attrs:
+        raise UnsupportedOpError("GlobalMaxPool has unsupported attributes")
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "GlobalMaxPool expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "GlobalMaxPool supports float16, float, and double inputs only"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError(
+            "GlobalMaxPool expects input rank of at least 3"
+        )
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (input_shape[0], input_shape[1]) + (
+        1,
+    ) * (len(input_shape) - 2)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "GlobalMaxPool output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    axes = tuple(range(2, len(input_shape)))
+    return ReduceOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        axes=axes,
+        axes_input=None,
+        axes_input_shape=None,
+        axes_input_dtype=None,
+        keepdims=True,
+        noop_with_empty_axes=False,
+        reduce_kind="max",
+        reduce_count=None,
+        dtype=op_dtype,
+    )
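This lowering maps GlobalMaxPool onto the existing ReduceOp with reduce_kind="max", axes=range(2, rank), and keepdims=True. A minimal NumPy sketch of the equivalent computation (illustrative only, not the generated C code):

import numpy as np

def global_max_pool_ref(x: np.ndarray) -> np.ndarray:
    # Reduce every axis after N and C, keeping them as size-1 dims,
    # matching the axes and keepdims choices in the lowering above.
    return np.max(x, axis=tuple(range(2, x.ndim)), keepdims=True)

x = np.random.rand(1, 3, 8, 8).astype(np.float32)
assert global_max_pool_ref(x).shape == (1, 3, 1, 1)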
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import GridSampleOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import GroupNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import HardmaxOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
+from .common import shape_product as _shape_product
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+from ..validation import ensure_output_shape_matches_input
+from ..validation import normalize_axis as _normalize_axis
+
+
+@register_lowering("Hardmax")
+def lower_hardmax(graph: Graph, node: Node) -> HardmaxOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Hardmax must have 1 input and 1 output")
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "Hardmax supports float16, float, and double inputs only"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
+    axis = _normalize_axis(
+        int(axis_attr),
+        input_shape,
+        node,
+    )
+    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+    axis_size = input_shape[axis]
+    inner = (
+        _shape_product(input_shape[axis + 1 :])
+        if axis + 1 < len(input_shape)
+        else 1
+    )
+    return HardmaxOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        outer=outer,
+        axis_size=axis_size,
+        inner=inner,
+        axis=axis,
+        shape=input_shape,
+        dtype=op_dtype,
+    )
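The Hardmax lowering flattens the tensor into (outer, axis_size, inner) blocks around the normalized axis. A rough NumPy reference of the operator itself, as a sketch only (np.argmax resolves ties to the first maximum, which matches ONNX Hardmax semantics):

import numpy as np

def hardmax_ref(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # One-hot encoding of the argmax along `axis`.
    out = np.zeros_like(x)
    idx = np.argmax(x, axis=axis)
    np.put_along_axis(out, np.expand_dims(idx, axis), 1, axis=axis)
    return out

x = np.array([[1.0, 3.0, 2.0], [0.0, 0.0, 5.0]], dtype=np.float32)
# axis=-1 is the default from opset 13 onwards, as resolved in the lowering above.
print(hardmax_ref(x))  # [[0. 1. 0.] [0. 0. 1.]]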
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import IdentityOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape

@@ -22,11 +22,12 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
     for index, (input_dim, output_dim) in enumerate(
         zip(input_shape, output_shape)
     ):
-        if input_dim
-
-
-
-
+        if input_dim != output_dim and not (
+            input_dim_params[index] or output_dim_params[index]
+        ):
+            raise ShapeInferenceError(
+                "Identity input and output shapes must match"
+            )
     input_dtype = value_dtype(graph, node.inputs[0], node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import InstanceNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import LayerNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import LogSoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
 from .common import shape_product as _shape_product
 from .common import value_shape as _value_shape
 from .registry import register_lowering

@@ -23,8 +24,11 @@ def lower_logsoftmax(graph: Graph, node: Node) -> LogSoftmaxOp:
     input_shape = _value_shape(graph, node.inputs[0], node)
     output_shape = _value_shape(graph, node.outputs[0], node)
     ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
     axis = _normalize_axis(
-        int(
+        int(axis_attr),
         input_shape,
         node,
     )
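The added default_axis handling tracks the ONNX opset 13 change: Softmax, LogSoftmax, and Hardmax switched their default axis from 1 to -1 in opset 13, so a model exported against an older opset with no explicit axis attribute keeps the legacy default of 1, while newer models default to the last dimension.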
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import LpNormalizationOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -0,0 +1,141 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ..ir.ops import LpPoolOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .registry import register_lowering
+from .common import value_dtype as _value_dtype, value_shape as _value_shape
+
+
+@dataclass(frozen=True)
+class LpPoolSpec:
+    batch: int
+    channels: int
+    in_h: int
+    in_w: int
+    out_h: int
+    out_w: int
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_top: int
+    pad_left: int
+    pad_bottom: int
+    pad_right: int
+    p: int
+
+
+def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("LpPool must have 1 input and 1 output")
+    supported_attrs = {
+        "auto_pad",
+        "ceil_mode",
+        "dilations",
+        "kernel_shape",
+        "pads",
+        "p",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("LpPool has unsupported attributes")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad not in ("", "NOTSET"):
+        raise UnsupportedOpError("LpPool supports auto_pad=NOTSET only")
+    ceil_mode = int(node.attrs.get("ceil_mode", 0))
+    if ceil_mode != 0:
+        raise UnsupportedOpError("LpPool supports ceil_mode=0 only")
+    dilations = tuple(int(value) for value in node.attrs.get("dilations", (1, 1)))
+    if any(value != 1 for value in dilations):
+        raise UnsupportedOpError("LpPool supports dilations=1 only")
+    kernel_shape = node.attrs.get("kernel_shape")
+    if kernel_shape is None:
+        raise UnsupportedOpError("LpPool requires kernel_shape")
+    kernel_shape = tuple(int(value) for value in kernel_shape)
+    if len(kernel_shape) != 2:
+        raise UnsupportedOpError("LpPool expects 2D kernel_shape")
+    kernel_h, kernel_w = kernel_shape
+    strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
+    if len(strides) != 2:
+        raise UnsupportedOpError("LpPool expects 2D strides")
+    pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
+    if len(pads) != 4:
+        raise UnsupportedOpError("LpPool expects 4D pads")
+    pad_top, pad_left, pad_bottom, pad_right = pads
+    p = int(node.attrs.get("p", 2))
+    if p < 1:
+        raise UnsupportedOpError("LpPool p must be >= 1")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) != 4:
+        raise UnsupportedOpError("LpPool supports NCHW 2D inputs only")
+    batch, channels, in_h, in_w = input_shape
+    stride_h, stride_w = strides
+    out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
+    out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
+    if out_h < 0 or out_w < 0:
+        raise ShapeInferenceError("LpPool output shape must be non-negative")
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, channels, out_h, out_w)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "LpPool output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return LpPoolSpec(
+        batch=batch,
+        channels=channels,
+        in_h=in_h,
+        in_w=in_w,
+        out_h=out_h,
+        out_w=out_w,
+        kernel_h=kernel_h,
+        kernel_w=kernel_w,
+        stride_h=stride_h,
+        stride_w=stride_w,
+        pad_top=pad_top,
+        pad_left=pad_left,
+        pad_bottom=pad_bottom,
+        pad_right=pad_right,
+        p=p,
+    )
+
+
+@register_lowering("LpPool")
+def lower_lp_pool(graph: Graph, node: Node) -> LpPoolOp:
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "LpPool expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "LpPool supports float16, float, and double inputs only"
+        )
+    spec = _resolve_lp_pool_spec(graph, node)
+    return LpPoolOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        batch=spec.batch,
+        channels=spec.channels,
+        in_h=spec.in_h,
+        in_w=spec.in_w,
+        out_h=spec.out_h,
+        out_w=spec.out_w,
+        kernel_h=spec.kernel_h,
+        kernel_w=spec.kernel_w,
+        stride_h=spec.stride_h,
+        stride_w=spec.stride_w,
+        pad_top=spec.pad_top,
+        pad_left=spec.pad_left,
+        pad_bottom=spec.pad_bottom,
+        pad_right=spec.pad_right,
+        p=spec.p,
+        dtype=op_dtype,
+    )
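With auto_pad=NOTSET, ceil_mode=0, and unit dilations, the output size computed in _resolve_lp_pool_spec follows the usual floor formula. A small worked example with hypothetical sizes (not taken from the package):

# out = (in + pad_begin + pad_end - kernel) // stride + 1
in_h, kernel_h, stride_h, pad_top, pad_bottom = 32, 3, 2, 1, 1
out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
print(out_h)  # (32 + 1 + 1 - 3) // 2 + 1 == 16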
emx_onnx_cgen/lowering/lrn.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import LrnOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/lstm.py
CHANGED
@@ -323,7 +323,7 @@ def resolve_lstm_spec(graph: Graph, node: Node) -> LstmSpec:
 
 @register_lowering("LSTM")
 def lower_lstm(graph: Graph, node: Node) -> "LstmOp":
-    from ..
+    from ..ir.ops import LstmOp
 
     spec = resolve_lstm_spec(graph, node)
     return LstmOp(
emx_onnx_cgen/lowering/matmul.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import MatMulOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
@@ -87,13 +87,12 @@ def _broadcast_batch_shapes(
     right_padded = (1,) * (max_rank - len(right)) + right
     broadcast_shape = []
     for left_dim, right_dim in zip(left_padded, right_padded):
-        if left_dim == right_dim or left_dim == 1 or right_dim == 1:
-
-
-
-
-
-            )
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            raise ShapeInferenceError(
+                "MatMul batch dimensions must be broadcastable, "
+                f"got {left} x {right}"
+            )
+        broadcast_shape.append(max(left_dim, right_dim))
     return tuple(broadcast_shape), left_padded, right_padded
 
 
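The rewritten check keeps the standard broadcasting rule for MatMul batch dimensions: after left-padding the shorter shape with 1s, each pair of dims must be equal or contain a 1, and the broadcast dim is their maximum. A minimal sketch that mirrors _broadcast_batch_shapes (simplified names, not the package code):

def broadcast_batch(left, right):
    # Pad with leading 1s, require equal-or-1 per dim, take the max.
    rank = max(len(left), len(right))
    left = (1,) * (rank - len(left)) + left
    right = (1,) * (rank - len(right)) + right
    if any(l != r and l != 1 and r != 1 for l, r in zip(left, right)):
        raise ValueError("batch dims not broadcastable")
    return tuple(max(l, r) for l, r in zip(left, right))

assert broadcast_batch((2, 1), (5,)) == (2, 5)   # (2, 3) vs (4, 5) would raise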
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import MaxPoolOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import MeanVarianceNormalizationOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import NegativeLogLikelihoodLossOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .common import shape_product as _shape_product

@@ -43,18 +43,18 @@ def _resolve_target_shape(
                 raise ShapeInferenceError("Reshape allows only one -1 dimension")
             unknown_index = index
             output_dims.append(-1)
-
-
-
-
-
-
-
-
-
-
-
-
+        else:
+            if dim == 0:
+                if allowzero == 0:
+                    if index >= len(input_shape):
+                        raise ShapeInferenceError(
+                            "Reshape zero dim must index into input shape"
+                        )
+                    dim = input_shape[index]
+            if dim < 0:
+                raise ShapeInferenceError("Reshape dims must be >= -1")
+            output_dims.append(dim)
+            known_product *= dim
     input_product = _shape_product(input_shape)
    if unknown_index is not None:
        if known_product == 0 or input_product % known_product != 0: