emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +50 -23
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +30 -387
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +36 -18
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +1 -1
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +1 -1
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +1 -1
- emx_onnx_cgen/lowering/gather_nd.py +1 -1
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +1 -1
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +1 -1
- emx_onnx_cgen/lowering/identity.py +1 -1
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +1 -1
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +1 -1
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +1 -1
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +1 -1
- emx_onnx_cgen/lowering/one_hot.py +1 -1
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +1 -1
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +1 -1
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +1 -1
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +1 -1
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +1 -1
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +1 -1
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +25 -7
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +1 -1
- emx_onnx_cgen/lowering/unsqueeze.py +1 -1
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/runtime/evaluator.py +325 -1
- emx_onnx_cgen/verification.py +9 -39
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
- emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +11 -0
- shared/ulp.py +17 -0
- emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/__init__.py
CHANGED

@@ -1,3 +1,81 @@
+from __future__ import annotations
+
+import importlib
+
 from .registry import get_lowering, register_lowering
 
-
+_LOWERING_MODULES = [
+    "adagrad",
+    "arg_reduce",
+    "attention",
+    "average_pool",
+    "batch_normalization",
+    "cast",
+    "concat",
+    "constant_of_shape",
+    "conv",
+    "conv_transpose",
+    "cumsum",
+    "depth_space",
+    "dropout",
+    "einsum",
+    "elementwise",
+    "expand",
+    "eye_like",
+    "flatten",
+    "gather",
+    "gather_elements",
+    "gather_nd",
+    "gemm",
+    "global_max_pool",
+    "grid_sample",
+    "group_normalization",
+    "hardmax",
+    "identity",
+    "instance_normalization",
+    "layer_normalization",
+    "logsoftmax",
+    "lp_normalization",
+    "lp_pool",
+    "lrn",
+    "lstm",
+    "matmul",
+    "maxpool",
+    "mean_variance_normalization",
+    "negative_log_likelihood_loss",
+    "non_max_suppression",
+    "nonzero",
+    "one_hot",
+    "pad",
+    "qlinear_matmul",
+    "quantize_linear",
+    "range",
+    "reduce",
+    "reshape",
+    "resize",
+    "rms_normalization",
+    "rotary_embedding",
+    "scatter_nd",
+    "shape",
+    "size",
+    "slice",
+    "softmax",
+    "softmax_cross_entropy_loss",
+    "split",
+    "squeeze",
+    "tensor_scatter",
+    "tile",
+    "topk",
+    "transpose",
+    "trilu",
+    "unsqueeze",
+    "variadic",
+    "where",
+]
+
+
+def load_lowering_registry() -> None:
+    for module_name in _LOWERING_MODULES:
+        importlib.import_module(f"{__name__}.{module_name}")
+
+__all__ = ["get_lowering", "register_lowering", "load_lowering_registry"]
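The lowering registry is now populated lazily. A minimal usage sketch, assuming only the names exported above; get_lowering's argument is presumed to be the ONNX op type string, which this diff does not show:

# Sketch only: load_lowering_registry() imports every module listed in
# _LOWERING_MODULES so each @register_lowering decorator runs before any
# lookup; the "Gemm" lookup below is a hypothetical example.
from emx_onnx_cgen.lowering import get_lowering, load_lowering_registry

load_lowering_registry()
gemm_lowering = get_lowering("Gemm")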
emx_onnx_cgen/lowering/adagrad.py
ADDED

@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import AdagradOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+@register_lowering("Adagrad")
+def lower_adagrad(graph: Graph, node: Node) -> AdagradOp:
+    if len(node.inputs) < 5:
+        raise UnsupportedOpError("Adagrad must have at least 5 inputs")
+    if len(node.outputs) < 2:
+        raise UnsupportedOpError("Adagrad must have at least 2 outputs")
+    if (len(node.inputs) - 2) % 3 != 0:
+        raise UnsupportedOpError(
+            "Adagrad inputs must be R, T, Xs, Gs, Hs with matching counts"
+        )
+    tensor_count = (len(node.inputs) - 2) // 3
+    if len(node.outputs) != tensor_count * 2:
+        raise UnsupportedOpError(
+            "Adagrad outputs must be X_news followed by H_news"
+        )
+    rate_name = node.inputs[0]
+    timestep_name = node.inputs[1]
+    rate_shape = value_shape(graph, rate_name, node)
+    timestep_shape = value_shape(graph, timestep_name, node)
+    if not _is_scalar_shape(rate_shape):
+        raise UnsupportedOpError("Adagrad R input must be a scalar")
+    if not _is_scalar_shape(timestep_shape):
+        raise UnsupportedOpError("Adagrad T input must be a scalar")
+    rate_dtype = value_dtype(graph, rate_name, node)
+    if rate_dtype not in {ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "Adagrad R input must be float or double"
+        )
+    timestep_dtype = value_dtype(graph, timestep_name, node)
+    if timestep_dtype != ScalarType.I64:
+        raise UnsupportedOpError("Adagrad T input must be int64")
+
+    inputs = node.inputs[2 : 2 + tensor_count]
+    gradients = node.inputs[2 + tensor_count : 2 + tensor_count * 2]
+    accumulators = node.inputs[2 + tensor_count * 2 : 2 + tensor_count * 3]
+    outputs = node.outputs[:tensor_count]
+    accumulator_outputs = node.outputs[tensor_count:]
+    if not inputs or not gradients or not accumulators:
+        raise UnsupportedOpError("Adagrad requires X, G, H inputs")
+    dtype = value_dtype(graph, inputs[0], node)
+    if dtype not in {ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError("Adagrad supports float and double tensors only")
+    if rate_dtype != dtype:
+        raise UnsupportedOpError(
+            "Adagrad R input dtype must match tensor dtype"
+        )
+    input_shapes: list[tuple[int, ...]] = []
+    output_shapes: list[tuple[int, ...]] = []
+    for index, (x_name, g_name, h_name, out_name, h_out_name) in enumerate(
+        zip(inputs, gradients, accumulators, outputs, accumulator_outputs)
+    ):
+        x_dtype = value_dtype(graph, x_name, node)
+        g_dtype = value_dtype(graph, g_name, node)
+        h_dtype = value_dtype(graph, h_name, node)
+        out_dtype = value_dtype(graph, out_name, node)
+        h_out_dtype = value_dtype(graph, h_out_name, node)
+        if {x_dtype, g_dtype, h_dtype, out_dtype, h_out_dtype} != {dtype}:
+            raise UnsupportedOpError(
+                "Adagrad inputs and outputs must share the same dtype"
+            )
+        x_shape = value_shape(graph, x_name, node)
+        g_shape = value_shape(graph, g_name, node)
+        h_shape = value_shape(graph, h_name, node)
+        out_shape = value_shape(graph, out_name, node)
+        h_out_shape = value_shape(graph, h_out_name, node)
+        if x_shape != g_shape or x_shape != h_shape:
+            raise ShapeInferenceError(
+                f"Adagrad inputs X/G/H shapes must match for tensor {index}"
+            )
+        if out_shape != x_shape or h_out_shape != x_shape:
+            raise ShapeInferenceError(
+                f"Adagrad outputs must match X shape for tensor {index}"
+            )
+        input_shapes.append(x_shape)
+        output_shapes.append(out_shape)
+
+    norm_coefficient = float(node.attrs.get("norm_coefficient", 0.0))
+    epsilon = float(node.attrs.get("epsilon", 0.0))
+    decay_factor = float(node.attrs.get("decay_factor", 0.0))
+
+    return AdagradOp(
+        rate=rate_name,
+        timestep=timestep_name,
+        inputs=tuple(inputs),
+        gradients=tuple(gradients),
+        accumulators=tuple(accumulators),
+        outputs=tuple(outputs),
+        accumulator_outputs=tuple(accumulator_outputs),
+        rate_shape=rate_shape,
+        timestep_shape=timestep_shape,
+        tensor_shapes=tuple(input_shapes),
+        output_shapes=tuple(output_shapes),
+        dtype=dtype,
+        rate_dtype=rate_dtype,
+        timestep_dtype=timestep_dtype,
+        norm_coefficient=norm_coefficient,
+        epsilon=epsilon,
+        decay_factor=decay_factor,
+    )
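The checks above encode the ONNX Adagrad signature the lowering accepts: inputs are R (learning rate) and T (timestep) followed by tensor_count tensors each of X, G and H; outputs are the updated Xs followed by the updated Hs. An illustrative slice-out for a hypothetical two-tensor node (names invented for clarity):

# Mirrors the slicing in lower_adagrad for a node with tensor_count == 2.
node_inputs = ["lr", "step", "x0", "x1", "g0", "g1", "h0", "h1"]
tensor_count = (len(node_inputs) - 2) // 3                      # -> 2
xs = node_inputs[2 : 2 + tensor_count]                          # ['x0', 'x1']
gs = node_inputs[2 + tensor_count : 2 + tensor_count * 2]       # ['g0', 'g1']
hs = node_inputs[2 + tensor_count * 2 : 2 + tensor_count * 3]   # ['h0', 'h1']
# Outputs would then be ["x0_new", "x1_new", "h0_new", "h1_new"]:
# the first tensor_count names are the X_news, the remainder the H_news.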
emx_onnx_cgen/lowering/arg_reduce.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ArgReduceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape
emx_onnx_cgen/lowering/attention.py
CHANGED

@@ -5,7 +5,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import AttentionOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/average_pool.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import AveragePoolOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/batch_normalization.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import BatchNormOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/cast.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import onnx
 
-from ..
+from ..ir.ops import CastOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
emx_onnx_cgen/lowering/common.py
CHANGED

@@ -5,6 +5,7 @@ from collections.abc import Sequence
 from shared.scalar_types import ScalarType
 
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
 
 
@@ -14,7 +15,9 @@ def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
     return dtype
 
 
-def onnx_opset_version(graph: Graph, domain: str = "") -> int | None:
+def onnx_opset_version(graph: Graph | GraphContext, domain: str = "") -> int | None:
+    if isinstance(graph, GraphContext):
+        return graph.opset_version(domain)
     if domain in {"", "ai.onnx"}:
         domains = {"", "ai.onnx"}
     else:
@@ -25,7 +28,11 @@ def onnx_opset_version(graph: Graph, domain: str = "") -> int | None:
     return None
 
 
-def value_dtype(
+def value_dtype(
+    graph: Graph | GraphContext, name: str, node: Node | None = None
+) -> ScalarType:
+    if isinstance(graph, GraphContext):
+        return graph.dtype(name, node)
     try:
         value = graph.find_value(name)
     except KeyError as exc:
@@ -37,31 +44,42 @@ def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType
     return ensure_supported_dtype(value.type.dtype)
 
 
-def value_shape(
-
+def value_shape(
+    graph: Graph | GraphContext, name: str, node: Node | None = None
+) -> tuple[int, ...]:
+    if isinstance(graph, GraphContext):
+        shape = graph.shape(name, node)
         value = graph.find_value(name)
-
-
-
-
-
+    else:
+        try:
+            value = graph.find_value(name)
+        except KeyError as exc:
+            op_type = node.op_type if node is not None else "unknown"
+            raise ShapeInferenceError(
+                f"Missing shape for value '{name}' in op {op_type}. "
+                "Hint: run ONNX shape inference or export with static shapes."
+            ) from exc
+        shape = value.type.shape
     if any(value.type.dim_params):
         resolved = _resolve_value_shape(graph, name, node)
         if resolved is not None:
            return resolved
        return value.type.shape
-    return
+    return shape
 
 
-def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+def _find_initializer(graph: Graph | GraphContext, name: str) -> Initializer | None:
+    if isinstance(graph, GraphContext):
+        return graph.initializer(name)
     for initializer in graph.initializers:
         if initializer.name == name:
             return initializer
     return None
 
 
-def _find_node_by_output(graph: Graph, name: str) -> Node | None:
+def _find_node_by_output(graph: Graph | GraphContext, name: str) -> Node | None:
+    if isinstance(graph, GraphContext):
+        return graph.producer(name)
     for node in graph.nodes:
         if name in node.outputs:
             return node
@@ -69,7 +87,7 @@ def _find_node_by_output(graph: Graph, name: str) -> Node | None:
 
 
 def _shape_values_from_shape_node(
-    graph: Graph, shape_node: Node, node: Node | None
+    graph: Graph | GraphContext, shape_node: Node, node: Node | None
 ) -> list[int]:
     if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
@@ -88,7 +106,7 @@ def _shape_values_from_shape_node(
 
 
 def _shape_values_from_initializer(
-    graph: Graph,
+    graph: Graph | GraphContext,
     name: str,
 ) -> list[int] | None:
     initializer = _find_initializer(graph, name)
@@ -103,7 +121,7 @@ def _shape_values_from_initializer(
 
 
 def _shape_values_from_input(
-    graph: Graph,
+    graph: Graph | GraphContext,
     name: str,
     node: Node | None,
     *,
@@ -277,7 +295,7 @@ def _broadcast_shapes(
 
 
 def _resolve_value_shape(
-    graph: Graph,
+    graph: Graph | GraphContext,
     name: str,
     node: Node | None,
     *,
@@ -414,7 +432,7 @@ def _resolve_value_shape(
     _visited.remove(name)
 
 
-def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
+def node_dtype(graph: Graph | GraphContext, node: Node, *names: str) -> ScalarType:
     filtered = [name for name in names if name]
     if not filtered:
         raise UnsupportedOpError(
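All of these hunks apply the same dual-dispatch idiom: each helper now accepts either the existing Graph or the new GraphContext and short-circuits to the context's lookup when one is passed, so existing Graph-based callers keep working unchanged. A sketch of the caller-side effect, assuming only the GraphContext methods actually invoked in this diff (opset_version, dtype, shape, initializer, producer); lower_example is a hypothetical lowering used purely for illustration:

# Sketch: because the shared helpers dispatch on the graph argument, a
# lowering body reads the same whichever container it receives.
def lower_example(graph, node):  # graph may be a Graph or a GraphContext
    dtype = value_dtype(graph, node.inputs[0], node)   # dispatches internally
    shape = value_shape(graph, node.inputs[0], node)   # likewise
    opset = onnx_opset_version(graph)                  # likewise
    return dtype, shape, opset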
emx_onnx_cgen/lowering/constant_of_shape.py
CHANGED

@@ -4,7 +4,7 @@ from onnx import numpy_helper
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ConstantOfShapeOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
emx_onnx_cgen/lowering/conv.py
CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations
 import math
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import ConvOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype

emx_onnx_cgen/lowering/conv_transpose.py
CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations
 import math
 from dataclasses import dataclass
 
-from ..
+from ..ir.ops import ConvTransposeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/cumsum.py
CHANGED

@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import CumSumOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/depth_space.py
CHANGED

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import DepthToSpaceOp, SpaceToDepthOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/einsum.py
CHANGED

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import EinsumKind, EinsumOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/elementwise.py
CHANGED

@@ -1,13 +1,23 @@
 from __future__ import annotations
 
-from shared.scalar_functions import ScalarFunction
+from shared.scalar_functions import ScalarFunction, ScalarFunctionError
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import BinaryOp, ClipOp, UnaryOp
 from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
 from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
-from ..lowering.registry import register_lowering
+from ..lowering.registry import register_lowering, register_lowering_if_missing
+from ..ops import (
+    BINARY_OP_TYPES,
+    COMPARE_FUNCTIONS,
+    UNARY_OP_TYPES,
+    binary_op_symbol,
+    unary_op_symbol,
+    validate_unary_attrs,
+)
+from ..lowering.variadic import VARIADIC_OP_FUNCTIONS
 
 
 @register_lowering("Clip")
@@ -120,6 +130,138 @@ def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
     )
 
 
+def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
+    if node.op_type == "BitShift":
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
+        direction_attr = node.attrs.get("direction", "LEFT")
+        if isinstance(direction_attr, bytes):
+            direction = direction_attr.decode()
+        else:
+            direction = str(direction_attr)
+        if direction not in {"LEFT", "RIGHT"}:
+            raise UnsupportedOpError(
+                "BitShift direction must be LEFT or RIGHT"
+            )
+        op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+        if not op_dtype.is_integer:
+            raise UnsupportedOpError("BitShift expects integer inputs")
+        function = (
+            ScalarFunction.BITWISE_LEFT_SHIFT
+            if direction == "LEFT"
+            else ScalarFunction.BITWISE_RIGHT_SHIFT
+        )
+        op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+        if op_spec is None:
+            raise UnsupportedOpError("Unsupported op BitShift")
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=op_dtype,
+            input_dtype=op_dtype,
+        )
+    if node.op_type == "Mod":
+        fmod = int(node.attrs.get("fmod", 0))
+        if fmod not in {0, 1}:
+            raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
+        function = (
+            ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
+        )
+    else:
+        try:
+            function = ScalarFunction.from_onnx_op(node.op_type)
+        except ScalarFunctionError as exc:
+            raise UnsupportedOpError(
+                f"Unsupported op {node.op_type}"
+            ) from exc
+    validate_unary_attrs(node.op_type, node.attrs)
+    if function in COMPARE_FUNCTIONS:
+        input_dtype = node_dtype(graph, node, *node.inputs)
+        output_dtype = value_dtype(graph, node.outputs[0], node)
+        op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
+        if op_spec is None:
+            raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{node.op_type} must have 2 inputs and 1 output"
+            )
+        if output_dtype != ScalarType.BOOL:
+            raise UnsupportedOpError(
+                f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
+            )
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=output_dtype,
+            input_dtype=input_dtype,
+        )
+    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+    unary_symbol = unary_op_symbol(function, dtype=op_dtype)
+    if op_spec is None and unary_symbol is None:
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    if op_spec is not None:
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{node.op_type} must have 2 inputs and 1 output"
+            )
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=op_dtype,
+            input_dtype=op_dtype,
+        )
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 1 input and 1 output"
+        )
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=function,
+        shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+        params=(),
+    )
+
+
+_DEFAULT_ELEMENTWISE_TYPES = (
+    BINARY_OP_TYPES.union(UNARY_OP_TYPES) - set(VARIADIC_OP_FUNCTIONS.keys())
+)
+
+for _op_type in _DEFAULT_ELEMENTWISE_TYPES:
+    register_lowering_if_missing(_op_type)(_lower_binary_unary)
+
+
 @register_lowering("IsInf")
 def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
@@ -130,6 +272,12 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
         raise UnsupportedOpError("IsInf only supports floating-point inputs")
     if output_dtype != ScalarType.BOOL:
         raise UnsupportedOpError("IsInf output must be bool")
+    detect_negative = int(node.attrs.get("detect_negative", 1))
+    detect_positive = int(node.attrs.get("detect_positive", 1))
+    if detect_negative not in {0, 1} or detect_positive not in {0, 1}:
+        raise UnsupportedOpError(
+            "IsInf detect_negative and detect_positive must be 0 or 1"
+        )
     output_shape = value_shape(graph, node.outputs[0], node)
     return UnaryOp(
         input0=node.inputs[0],
@@ -138,7 +286,7 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
         shape=output_shape,
         dtype=output_dtype,
         input_dtype=input_dtype,
-        params=(),
+        params=(float(detect_negative), float(detect_positive)),
     )
 
 
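The registration loop at the end of the new hunk gives every non-variadic op type in BINARY_OP_TYPES and UNARY_OP_TYPES a default path through _lower_binary_unary, while op types that already have a dedicated @register_lowering handler keep it. A hedged sketch of what register_lowering_if_missing presumably does; its real body lives in registry.py, which is not part of this hunk, and _REGISTRY below is a hypothetical stand-in for its lookup table:

# Assumed behaviour only; not the actual registry.py implementation.
_REGISTRY: dict = {}

def register_lowering_if_missing(op_type):
    def decorator(fn):
        if op_type not in _REGISTRY:   # keep any dedicated lowering
            _REGISTRY[op_type] = fn    # otherwise install the generic handler
        return fn
    return decorator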
emx_onnx_cgen/lowering/expand.py
CHANGED

@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ExpandOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/flatten.py
CHANGED

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape
emx_onnx_cgen/lowering/gather.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import GatherOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import normalize_axis