emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Note: the registry flags this as a potentially problematic release.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/__init__.py
CHANGED

@@ -1,3 +1,81 @@
+from __future__ import annotations
+
+import importlib
+
 from .registry import get_lowering, register_lowering
 
-
+_LOWERING_MODULES = [
+    "adagrad",
+    "arg_reduce",
+    "attention",
+    "average_pool",
+    "batch_normalization",
+    "cast",
+    "concat",
+    "constant_of_shape",
+    "conv",
+    "conv_transpose",
+    "cumsum",
+    "depth_space",
+    "dropout",
+    "einsum",
+    "elementwise",
+    "expand",
+    "eye_like",
+    "flatten",
+    "gather",
+    "gather_elements",
+    "gather_nd",
+    "gemm",
+    "global_max_pool",
+    "grid_sample",
+    "group_normalization",
+    "hardmax",
+    "identity",
+    "instance_normalization",
+    "layer_normalization",
+    "logsoftmax",
+    "lp_normalization",
+    "lp_pool",
+    "lrn",
+    "lstm",
+    "matmul",
+    "maxpool",
+    "mean_variance_normalization",
+    "negative_log_likelihood_loss",
+    "non_max_suppression",
+    "nonzero",
+    "one_hot",
+    "pad",
+    "qlinear_matmul",
+    "quantize_linear",
+    "range",
+    "reduce",
+    "reshape",
+    "resize",
+    "rms_normalization",
+    "rotary_embedding",
+    "scatter_nd",
+    "shape",
+    "size",
+    "slice",
+    "softmax",
+    "softmax_cross_entropy_loss",
+    "split",
+    "squeeze",
+    "tensor_scatter",
+    "tile",
+    "topk",
+    "transpose",
+    "trilu",
+    "unsqueeze",
+    "variadic",
+    "where",
+]
+
+
+def load_lowering_registry() -> None:
+    for module_name in _LOWERING_MODULES:
+        importlib.import_module(f"{__name__}.{module_name}")
+
+__all__ = ["get_lowering", "register_lowering", "load_lowering_registry"]
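The new `load_lowering_registry` exists because each lowering module registers itself as an import side effect: its `@register_lowering(...)` decorator runs when the module is imported, so the registry stays empty until every entry in `_LOWERING_MODULES` has been loaded. A minimal sketch of the pattern, assuming a plain dict-backed registry (`registry.py` itself is not part of this diff, so `_REGISTRY` and the exact signatures below are illustrative, not the package's real code):

    from typing import Callable

    # Sketch only: a dict-backed decorator registry.
    _REGISTRY: dict[str, Callable] = {}

    def register_lowering(op_type: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            _REGISTRY[op_type] = fn  # registration is an import-time side effect
            return fn
        return decorator

    def get_lowering(op_type: str) -> Callable:
        return _REGISTRY[op_type]

Under this reading, callers must run load_lowering_registry() once before the first get_lowering() lookup; otherwise every op appears unregistered.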
emx_onnx_cgen/lowering/adagrad.py
ADDED

@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import AdagradOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+@register_lowering("Adagrad")
+def lower_adagrad(graph: Graph, node: Node) -> AdagradOp:
+    if len(node.inputs) < 5:
+        raise UnsupportedOpError("Adagrad must have at least 5 inputs")
+    if len(node.outputs) < 2:
+        raise UnsupportedOpError("Adagrad must have at least 2 outputs")
+    if (len(node.inputs) - 2) % 3 != 0:
+        raise UnsupportedOpError(
+            "Adagrad inputs must be R, T, Xs, Gs, Hs with matching counts"
+        )
+    tensor_count = (len(node.inputs) - 2) // 3
+    if len(node.outputs) != tensor_count * 2:
+        raise UnsupportedOpError(
+            "Adagrad outputs must be X_news followed by H_news"
+        )
+    rate_name = node.inputs[0]
+    timestep_name = node.inputs[1]
+    rate_shape = value_shape(graph, rate_name, node)
+    timestep_shape = value_shape(graph, timestep_name, node)
+    if not _is_scalar_shape(rate_shape):
+        raise UnsupportedOpError("Adagrad R input must be a scalar")
+    if not _is_scalar_shape(timestep_shape):
+        raise UnsupportedOpError("Adagrad T input must be a scalar")
+    rate_dtype = value_dtype(graph, rate_name, node)
+    if rate_dtype not in {ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "Adagrad R input must be float or double"
+        )
+    timestep_dtype = value_dtype(graph, timestep_name, node)
+    if timestep_dtype != ScalarType.I64:
+        raise UnsupportedOpError("Adagrad T input must be int64")
+
+    inputs = node.inputs[2 : 2 + tensor_count]
+    gradients = node.inputs[2 + tensor_count : 2 + tensor_count * 2]
+    accumulators = node.inputs[2 + tensor_count * 2 : 2 + tensor_count * 3]
+    outputs = node.outputs[:tensor_count]
+    accumulator_outputs = node.outputs[tensor_count:]
+    if not inputs or not gradients or not accumulators:
+        raise UnsupportedOpError("Adagrad requires X, G, H inputs")
+    dtype = value_dtype(graph, inputs[0], node)
+    if dtype not in {ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError("Adagrad supports float and double tensors only")
+    if rate_dtype != dtype:
+        raise UnsupportedOpError(
+            "Adagrad R input dtype must match tensor dtype"
+        )
+    input_shapes: list[tuple[int, ...]] = []
+    output_shapes: list[tuple[int, ...]] = []
+    for index, (x_name, g_name, h_name, out_name, h_out_name) in enumerate(
+        zip(inputs, gradients, accumulators, outputs, accumulator_outputs)
+    ):
+        x_dtype = value_dtype(graph, x_name, node)
+        g_dtype = value_dtype(graph, g_name, node)
+        h_dtype = value_dtype(graph, h_name, node)
+        out_dtype = value_dtype(graph, out_name, node)
+        h_out_dtype = value_dtype(graph, h_out_name, node)
+        if {x_dtype, g_dtype, h_dtype, out_dtype, h_out_dtype} != {dtype}:
+            raise UnsupportedOpError(
+                "Adagrad inputs and outputs must share the same dtype"
+            )
+        x_shape = value_shape(graph, x_name, node)
+        g_shape = value_shape(graph, g_name, node)
+        h_shape = value_shape(graph, h_name, node)
+        out_shape = value_shape(graph, out_name, node)
+        h_out_shape = value_shape(graph, h_out_name, node)
+        if x_shape != g_shape or x_shape != h_shape:
+            raise ShapeInferenceError(
+                f"Adagrad inputs X/G/H shapes must match for tensor {index}"
+            )
+        if out_shape != x_shape or h_out_shape != x_shape:
+            raise ShapeInferenceError(
+                f"Adagrad outputs must match X shape for tensor {index}"
+            )
+        input_shapes.append(x_shape)
+        output_shapes.append(out_shape)
+
+    norm_coefficient = float(node.attrs.get("norm_coefficient", 0.0))
+    epsilon = float(node.attrs.get("epsilon", 0.0))
+    decay_factor = float(node.attrs.get("decay_factor", 0.0))
+
+    return AdagradOp(
+        rate=rate_name,
+        timestep=timestep_name,
+        inputs=tuple(inputs),
+        gradients=tuple(gradients),
+        accumulators=tuple(accumulators),
+        outputs=tuple(outputs),
+        accumulator_outputs=tuple(accumulator_outputs),
+        rate_shape=rate_shape,
+        timestep_shape=timestep_shape,
+        tensor_shapes=tuple(input_shapes),
+        output_shapes=tuple(output_shapes),
+        dtype=dtype,
+        rate_dtype=rate_dtype,
+        timestep_dtype=timestep_dtype,
+        norm_coefficient=norm_coefficient,
+        epsilon=epsilon,
+        decay_factor=decay_factor,
+    )
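The Adagrad node packs its inputs as `R, T, X1..Xn, G1..Gn, H1..Hn` and its outputs as `X1_new..Xn_new, H1_new..Hn_new`, which is why the lowering checks `(len(inputs) - 2) % 3 == 0` and `len(outputs) == tensor_count * 2`. For reference, the per-tensor update the op ultimately performs, following the ONNX Adagrad operator definition (the generated C lives in `c_emitter.py`; this NumPy sketch is only illustrative):

    import numpy as np

    # One tensor's Adagrad step, per the ONNX operator spec.
    def adagrad_step(R, T, X, G, H, norm_coefficient=0.0,
                     epsilon=0.0, decay_factor=0.0):
        r = R / (1.0 + T * decay_factor)       # decayed learning rate
        g = norm_coefficient * X + G           # regularized gradient
        H_new = H + g * g                      # squared-gradient accumulator
        X_new = X - r * g / (np.sqrt(H_new) + epsilon)
        return X_new, H_new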
emx_onnx_cgen/lowering/arg_reduce.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ArgReduceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape
emx_onnx_cgen/lowering/attention.py
CHANGED

@@ -5,7 +5,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import AttentionOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/average_pool.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import AveragePoolOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/batch_normalization.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import BatchNormOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/cast.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import onnx
 
-from ..
+from ..ir.ops import CastOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
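The five one-line hunks above (and the many matching `+1 -1` entries in the file list) are one mechanical refactor: op dataclasses now live in the new `emx_onnx_cgen/ir/ops` package (`elementwise.py`, `misc.py`, `nn.py`, `reduce.py` in the summary), so each lowering imports its op type from `..ir.ops` instead of carrying a local definition. The removed import lines are truncated to `from ..` in this view; their full original form is not recoverable from the diff.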
emx_onnx_cgen/lowering/common.py
CHANGED

@@ -5,7 +5,8 @@ from collections.abc import Sequence
 from shared.scalar_types import ScalarType
 
 from ..errors import ShapeInferenceError, UnsupportedOpError
-from ..ir.
+from ..ir.context import GraphContext
+from ..ir.model import Graph, Initializer, Node
 
 
 def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
@@ -14,7 +15,24 @@ def ensure_supported_dtype(dtype: ScalarType) -> ScalarType
     return dtype
 
 
-def
+def onnx_opset_version(graph: Graph | GraphContext, domain: str = "") -> int | None:
+    if isinstance(graph, GraphContext):
+        return graph.opset_version(domain)
+    if domain in {"", "ai.onnx"}:
+        domains = {"", "ai.onnx"}
+    else:
+        domains = {domain}
+    for opset_domain, version in graph.opset_imports:
+        if opset_domain in domains:
+            return int(version)
+    return None
+
+
+def value_dtype(
+    graph: Graph | GraphContext, name: str, node: Node | None = None
+) -> ScalarType:
+    if isinstance(graph, GraphContext):
+        return graph.dtype(name, node)
     try:
         value = graph.find_value(name)
     except KeyError as exc:
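`onnx_opset_version` treats the empty domain and `ai.onnx` as aliases because ONNX models may declare the default opset under either name. A small usage sketch (the `FakeGraph` stand-in is hypothetical; it only provides the one attribute the non-GraphContext path reads):

    class FakeGraph:
        # Mimics Graph.opset_imports: (domain, version) pairs.
        opset_imports = [("", 18), ("com.microsoft", 1)]

    assert onnx_opset_version(FakeGraph(), "") == 18
    assert onnx_opset_version(FakeGraph(), "ai.onnx") == 18  # alias of ""
    assert onnx_opset_version(FakeGraph(), "com.microsoft") == 1
    assert onnx_opset_version(FakeGraph(), "ai.onnx.ml") is None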
@@ -26,18 +44,395 @@ def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType
     return ensure_supported_dtype(value.type.dtype)
 
 
-def value_shape(
+def value_shape(
+    graph: Graph | GraphContext, name: str, node: Node | None = None
+) -> tuple[int, ...]:
+    if isinstance(graph, GraphContext):
+        shape = graph.shape(name, node)
+        value = graph.find_value(name)
+    else:
+        try:
+            value = graph.find_value(name)
+        except KeyError as exc:
+            op_type = node.op_type if node is not None else "unknown"
+            raise ShapeInferenceError(
+                f"Missing shape for value '{name}' in op {op_type}. "
+                "Hint: run ONNX shape inference or export with static shapes."
+            ) from exc
+        shape = value.type.shape
+    if any(value.type.dim_params):
+        resolved = _resolve_value_shape(graph, name, node)
+        if resolved is not None:
+            return resolved
+        return value.type.shape
+    return shape
+
+
+def _find_initializer(graph: Graph | GraphContext, name: str) -> Initializer | None:
+    if isinstance(graph, GraphContext):
+        return graph.initializer(name)
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _find_node_by_output(graph: Graph | GraphContext, name: str) -> Node | None:
+    if isinstance(graph, GraphContext):
+        return graph.producer(name)
+    for node in graph.nodes:
+        if name in node.outputs:
+            return node
+    return None
+
+
+def _shape_values_from_shape_node(
+    graph: Graph | GraphContext, shape_node: Node, node: Node | None
+) -> list[int]:
+    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    source_shape = value_shape(graph, shape_node.inputs[0], node)
+    start = int(shape_node.attrs.get("start", 0))
+    end = int(shape_node.attrs.get("end", len(source_shape)))
+    if start < 0:
+        start += len(source_shape)
+    if end < 0:
+        end += len(source_shape)
+    start = max(start, 0)
+    end = min(end, len(source_shape))
+    if start > end:
+        return []
+    return list(source_shape[start:end])
+
+
+def _shape_values_from_initializer(
+    graph: Graph | GraphContext,
+    name: str,
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "Reshape expects int64 or int32 shape input, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    return [int(value) for value in initializer.data.reshape(-1)]
+
+
+def _shape_values_from_input(
+    graph: Graph | GraphContext,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> list[int] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
     try:
-
-
-
-
-
-
-
+        shape_values = _shape_values_from_initializer(graph, name)
+        if shape_values is not None:
+            return shape_values
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Shape":
+            return _shape_values_from_shape_node(graph, source_node, node)
+        if source_node.op_type == "Concat":
+            axis = int(source_node.attrs.get("axis", 0))
+            if axis != 0:
+                raise UnsupportedOpError("Reshape shape concat must use axis 0")
+            values: list[int] = []
+            for input_name in source_node.inputs:
+                input_values = _shape_values_from_input(
+                    graph,
+                    input_name,
+                    node,
+                    _visited=_visited,
+                )
+                if input_values is None:
+                    return None
+                values.extend(input_values)
+            return values
+        if source_node.op_type == "Cast":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Cast must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Unsqueeze":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Identity":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Identity must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            if len(left) == 1 and len(right) != 1:
+                left = left * len(right)
+            if len(right) == 1 and len(left) != 1:
+                right = right * len(left)
+            if len(left) != len(right):
+                return None
+            if source_node.op_type == "Equal":
+                return [1 if l == r else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "And":
+                return [1 if (l and r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Or":
+                return [1 if (l or r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Div":
+                return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Mod":
+                return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+        if source_node.op_type == "Not":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Not must have 1 input and 1 output")
+            values = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if values is None:
+                return None
+            return [0 if value else 1 for value in values]
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            condition = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if condition is None:
+                return None
+            on_true = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _shape_values_from_input(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            if len(condition) == 1:
+                condition = condition * max(len(on_true), len(on_false))
+            if len(on_true) == 1 and len(condition) != 1:
+                on_true = on_true * len(condition)
+            if len(on_false) == 1 and len(condition) != 1:
+                on_false = on_false * len(condition)
+            if not (len(condition) == len(on_true) == len(on_false)):
+                return None
+            return [
+                t if cond else f
+                for cond, t, f in zip(condition, on_true, on_false)
+            ]
+        return None
+    finally:
+        _visited.remove(name)
+
+
+def _broadcast_shapes(
+    left: tuple[int, ...],
+    right: tuple[int, ...],
+) -> tuple[int, ...] | None:
+    result = []
+    left_rev = list(reversed(left))
+    right_rev = list(reversed(right))
+    for index in range(max(len(left_rev), len(right_rev))):
+        left_dim = left_rev[index] if index < len(left_rev) else 1
+        right_dim = right_rev[index] if index < len(right_rev) else 1
+        if left_dim == right_dim:
+            result.append(left_dim)
+        elif left_dim == 1:
+            result.append(right_dim)
+        elif right_dim == 1:
+            result.append(left_dim)
+        else:
+            return None
+    return tuple(reversed(result))
+
+
+def _resolve_value_shape(
+    graph: Graph | GraphContext,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> tuple[int, ...] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        value = graph.find_value(name)
+        shape = value.type.shape
+        if not any(value.type.dim_params):
+            return shape
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Expand":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Expand must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is not None and all(dim >= 0 for dim in shape_values):
+                return tuple(shape_values)
+            return None
+        if source_node.op_type == "Reshape":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is None:
+                return None
+            allowzero = int(source_node.attrs.get("allowzero", 0))
+            input_shape = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if input_shape is None:
+                return None
+            output_dims: list[int] = []
+            unknown_index: int | None = None
+            known_product = 1
+            contains_zero = False
+            for index, dim in enumerate(shape_values):
+                if dim == -1:
+                    if unknown_index is not None:
+                        return None
+                    unknown_index = len(output_dims)
+                    output_dims.append(-1)
+                else:
+                    if dim == 0:
+                        contains_zero = True
+                        if allowzero == 0:
+                            if index >= len(input_shape):
+                                return None
+                            dim = input_shape[index]
+                    if dim < 0:
+                        return None
+                    output_dims.append(dim)
+                    known_product *= dim
+            if allowzero == 1 and contains_zero and unknown_index is not None:
+                return None
+            input_product = shape_product(input_shape)
+            if unknown_index is not None:
+                if known_product == 0:
+                    if input_product != 0:
+                        return None
+                    output_dims[unknown_index] = 0
+                else:
+                    if input_product % known_product != 0:
+                        return None
+                    output_dims[unknown_index] = input_product // known_product
+            return tuple(output_dims)
+        if source_node.op_type in {
+            "Add",
+            "Sub",
+            "Mul",
+            "Div",
+            "Pow",
+            "Mod",
+            "And",
+            "Or",
+            "Xor",
+            "Equal",
+            "Greater",
+            "Less",
+            "GreaterOrEqual",
+            "LessOrEqual",
+        }:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            return _broadcast_shapes(left, right)
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            on_true = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _resolve_value_shape(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            return _broadcast_shapes(on_true, on_false)
+        return None
+    finally:
+        _visited.remove(name)
 
 
-def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
+def node_dtype(graph: Graph | GraphContext, node: Node, *names: str) -> ScalarType:
     filtered = [name for name in names if name]
     if not filtered:
         raise UnsupportedOpError(
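Most of the new code above is constant-folding for symbolic shapes: `_shape_values_from_input` folds the producer chain of a shape tensor (Shape, Concat, Cast, Where, ...), and the Reshape branch of `_resolve_value_shape` then applies the ONNX Reshape rules, where a target dim of `0` copies the corresponding input dim (unless `allowzero=1`) and `-1` is inferred from the remaining element count. A simplified sketch of just those two rules (the real helper additionally returns `None` on the zero-size and divisibility edge cases rather than assuming them away):

    def reshape_dims(input_shape, target, allowzero=0):
        # Simplified: assumes at most one -1 and a divisible element count.
        total = 1
        for dim in input_shape:
            total *= dim
        out, infer_at, known = [], None, 1
        for index, dim in enumerate(target):
            if dim == -1:
                infer_at = len(out)
                out.append(-1)
                continue
            if dim == 0 and allowzero == 0:
                dim = input_shape[index]  # 0 means "copy the input dim"
            out.append(dim)
            known *= dim
        if infer_at is not None:
            out[infer_at] = total // known  # -1 absorbs the remainder
        return tuple(out)

    assert reshape_dims((2, 3, 4), (0, -1)) == (2, 12)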
emx_onnx_cgen/lowering/concat.py
CHANGED