emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/matmul.py
CHANGED
@@ -1,119 +1,17 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-
 from ..ir.ops import MatMulOp
-from ..errors import
+from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from .common import node_dtype as _node_dtype
-from .common import value_shape as _value_shape
 from .registry import register_lowering
 
 
-@dataclass(frozen=True)
-class MatMulSpec:
-    input0_shape: tuple[int, ...]
-    input1_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    batch_shape: tuple[int, ...]
-    input0_batch_shape: tuple[int, ...]
-    input1_batch_shape: tuple[int, ...]
-    m: int
-    n: int
-    k: int
-    left_vector: bool
-    right_vector: bool
-
-
-def resolve_matmul_spec(graph: Graph, node: Node) -> MatMulSpec:
-    if len(node.inputs) != 2 or len(node.outputs) != 1:
-        raise UnsupportedOpError("MatMul must have 2 inputs and 1 output")
-    input0_shape = _value_shape(graph, node.inputs[0], node)
-    input1_shape = _value_shape(graph, node.inputs[1], node)
-    if len(input0_shape) < 1 or len(input1_shape) < 1:
-        raise UnsupportedOpError(
-            "MatMul inputs must be at least 1D, "
-            f"got {input0_shape} x {input1_shape}"
-        )
-    left_vector = len(input0_shape) == 1
-    right_vector = len(input1_shape) == 1
-    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
-    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
-    m, k_left = input0_effective[-2], input0_effective[-1]
-    k_right, n = input1_effective[-2], input1_effective[-1]
-    if k_left != k_right:
-        raise ShapeInferenceError(
-            f"MatMul inner dimensions must match, got {k_left} and {k_right}"
-        )
-    batch_shape, input0_batch_shape, input1_batch_shape = (
-        _broadcast_batch_shapes(
-            input0_effective[:-2], input1_effective[:-2], node
-        )
-    )
-    if left_vector and right_vector:
-        output_shape = batch_shape
-    elif left_vector:
-        output_shape = batch_shape + (n,)
-    elif right_vector:
-        output_shape = batch_shape + (m,)
-    else:
-        output_shape = batch_shape + (m, n)
-    expected_output_shape = _value_shape(graph, node.outputs[0], node)
-    if expected_output_shape != output_shape:
-        raise ShapeInferenceError(
-            "MatMul output shape must be "
-            f"{output_shape}, got {expected_output_shape}"
-        )
-    return MatMulSpec(
-        input0_shape=input0_shape,
-        input1_shape=input1_shape,
-        output_shape=output_shape,
-        batch_shape=batch_shape,
-        input0_batch_shape=input0_batch_shape,
-        input1_batch_shape=input1_batch_shape,
-        m=m,
-        n=n,
-        k=k_left,
-        left_vector=left_vector,
-        right_vector=right_vector,
-    )
-
-
-def _broadcast_batch_shapes(
-    left: tuple[int, ...], right: tuple[int, ...], node: Node
-) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
-    max_rank = max(len(left), len(right))
-    left_padded = (1,) * (max_rank - len(left)) + left
-    right_padded = (1,) * (max_rank - len(right)) + right
-    broadcast_shape = []
-    for left_dim, right_dim in zip(left_padded, right_padded):
-        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
-            raise ShapeInferenceError(
-                "MatMul batch dimensions must be broadcastable, "
-                f"got {left} x {right}"
-            )
-        broadcast_shape.append(max(left_dim, right_dim))
-    return tuple(broadcast_shape), left_padded, right_padded
-
-
 @register_lowering("MatMul")
 def lower_matmul(graph: Graph, node: Node) -> MatMulOp:
-
-
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("MatMul must have 2 inputs and 1 output")
     return MatMulOp(
         input0=node.inputs[0],
         input1=node.inputs[1],
         output=node.outputs[0],
-        input0_shape=spec.input0_shape,
-        input1_shape=spec.input1_shape,
-        output_shape=spec.output_shape,
-        batch_shape=spec.batch_shape,
-        input0_batch_shape=spec.input0_batch_shape,
-        input1_batch_shape=spec.input1_batch_shape,
-        m=spec.m,
-        n=spec.n,
-        k=spec.k,
-        left_vector=spec.left_vector,
-        right_vector=spec.right_vector,
-        dtype=op_dtype,
     )
emx_onnx_cgen/lowering/optional_has_element.py
ADDED
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
+from ..ir.model import Node
+from ..ir.ops import OptionalHasElementOp
+from .registry import register_lowering
+
+
+@register_lowering("OptionalHasElement")
+def lower_optional_has_element(
+    ctx: GraphContext, node: Node
+) -> OptionalHasElementOp:
+    if len(node.inputs) != 1 or not node.inputs[0]:
+        raise UnsupportedOpError(
+            "OptionalHasElement expects exactly one non-empty input."
+        )
+    if len(node.outputs) != 1 or not node.outputs[0]:
+        raise UnsupportedOpError(
+            "OptionalHasElement expects exactly one output."
+        )
+    input_name = node.inputs[0]
+    value = ctx.find_value(input_name)
+    if not value.type.is_optional:
+        raise UnsupportedOpError(
+            "OptionalHasElement expects an optional input."
+        )
+    return OptionalHasElementOp(input0=input_name, output=node.outputs[0])
emx_onnx_cgen/lowering/qlinear_mul.py
ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..ir.op_base import BroadcastingOpBase
+from ..ir.ops import QLinearMulOp
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _ensure_scalar_input(
+    graph: Graph, name: str, node: Node, label: str
+) -> tuple[int, ...]:
+    shape = _value_shape(graph, name, node)
+    if shape not in {(), (1,)}:
+        raise UnsupportedOpError(
+            f"QLinearMul {label} must be scalar, got shape {shape}"
+        )
+    return shape
+
+
+def _ensure_scale_dtype(dtype: ScalarType, label: str) -> None:
+    if not dtype.is_float:
+        raise UnsupportedOpError(
+            f"QLinearMul {label} must be float16/float/double"
+        )
+
+
+@register_lowering("QLinearMul")
+def lower_qlinear_mul(graph: Graph, node: Node) -> QLinearMulOp:
+    if len(node.inputs) != 8 or len(node.outputs) != 1:
+        raise UnsupportedOpError("QLinearMul must have 8 inputs and 1 output")
+    input0_shape = _value_shape(graph, node.inputs[0], node)
+    input1_shape = _value_shape(graph, node.inputs[3], node)
+    output_shape = BroadcastingOpBase.broadcast_shapes(
+        input0_shape, input1_shape
+    )
+    expected_output_shape = _value_shape(graph, node.outputs[0], node)
+    if expected_output_shape != output_shape:
+        raise ShapeInferenceError(
+            "QLinearMul output shape must be "
+            f"{output_shape}, got {expected_output_shape}"
+        )
+    input0_dtype = _value_dtype(graph, node.inputs[0], node)
+    input1_dtype = _value_dtype(graph, node.inputs[3], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input0_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("QLinearMul supports uint8/int8 inputs only")
+    if input1_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("QLinearMul supports uint8/int8 inputs only")
+    if output_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMul supports uint8/int8 outputs only"
+        )
+    input0_scale_dtype = _value_dtype(graph, node.inputs[1], node)
+    input1_scale_dtype = _value_dtype(graph, node.inputs[4], node)
+    output_scale_dtype = _value_dtype(graph, node.inputs[6], node)
+    _ensure_scale_dtype(input0_scale_dtype, "a_scale")
+    _ensure_scale_dtype(input1_scale_dtype, "b_scale")
+    _ensure_scale_dtype(output_scale_dtype, "y_scale")
+    input0_zero_dtype = _value_dtype(graph, node.inputs[2], node)
+    input1_zero_dtype = _value_dtype(graph, node.inputs[5], node)
+    output_zero_dtype = _value_dtype(graph, node.inputs[7], node)
+    if input0_zero_dtype != input0_dtype:
+        raise UnsupportedOpError("QLinearMul a_zero_point dtype must match a")
+    if input1_zero_dtype != input1_dtype:
+        raise UnsupportedOpError("QLinearMul b_zero_point dtype must match b")
+    if output_zero_dtype != output_dtype:
+        raise UnsupportedOpError("QLinearMul y_zero_point dtype must match y")
+    input0_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[1], node, "a_scale"
+    )
+    input1_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[4], node, "b_scale"
+    )
+    output_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[6], node, "y_scale"
+    )
+    input0_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[2], node, "a_zero_point"
+    )
+    input1_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[5], node, "b_zero_point"
+    )
+    output_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[7], node, "y_zero_point"
+    )
+    return QLinearMulOp(
+        input0=node.inputs[0],
+        input0_scale=node.inputs[1],
+        input0_zero_point=node.inputs[2],
+        input1=node.inputs[3],
+        input1_scale=node.inputs[4],
+        input1_zero_point=node.inputs[5],
+        output_scale=node.inputs[6],
+        output_zero_point=node.inputs[7],
+        output=node.outputs[0],
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
+        output_shape=output_shape,
+        input0_dtype=input0_dtype,
+        input1_dtype=input1_dtype,
+        dtype=output_dtype,
+        input0_scale_dtype=input0_scale_dtype,
+        input1_scale_dtype=input1_scale_dtype,
+        output_scale_dtype=output_scale_dtype,
+        input0_scale_shape=input0_scale_shape,
+        input1_scale_shape=input1_scale_shape,
+        output_scale_shape=output_scale_shape,
+        input0_zero_shape=input0_zero_shape,
+        input1_zero_shape=input1_zero_shape,
+        output_zero_shape=output_zero_shape,
+    )
emx_onnx_cgen/lowering/reduce.py
CHANGED
@@ -525,17 +525,12 @@ def lower_reduce(graph: Graph, node: Node) -> ReduceOp | ReshapeOp:
     return ReduceOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-        input_shape=input_shape,
-        output_shape=spec.output_shape,
         axes=spec.axes or (),
         axes_input=spec.axes_input,
-        axes_input_shape=spec.axes_input_shape,
-        axes_input_dtype=spec.axes_input_dtype,
         keepdims=spec.keepdims,
         noop_with_empty_axes=bool(int(node.attrs.get("noop_with_empty_axes", 0))),
         reduce_kind=REDUCE_KIND_BY_OP[node.op_type],
         reduce_count=spec.reduce_count,
-        dtype=op_dtype,
     )
 
 
emx_onnx_cgen/lowering/reshape.py
CHANGED
@@ -2,31 +2,20 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
-from .
+from ..ir.ops import ReshapeOp
+from .common import value_dtype, value_shape as resolved_value_shape
 from .registry import register_lowering
 
 
 def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
-    try:
-        return graph.find_value(name).type.shape
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing shape for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
+    return resolved_value_shape(graph, name, node)
 
 
 def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
-    try:
-        return graph.find_value(name).type.dtype
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing dtype for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
+    return value_dtype(graph, name, node)
 
 
 def _shape_product(shape: tuple[int, ...]) -> int:
@@ -350,6 +339,8 @@ def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
     for dim in output_shape:
         if dim < 0:
             raise ShapeInferenceError("Dynamic dims are not supported")
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], output_shape)
     return ReshapeOp(
         input0=node.inputs[0],
         output=node.outputs[0],
emx_onnx_cgen/lowering/shape.py
CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..ir.ops import ShapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
-from .
+from ..ir.ops import ShapeOp
+from .common import value_dtype, value_has_dim_params, value_shape
 from .registry import register_lowering
 
 
@@ -29,10 +30,13 @@ def lower_shape(graph: Graph, node: Node) -> ShapeOp:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
     input_shape = value_shape(graph, node.inputs[0], node)
     output_shape = value_shape(graph, node.outputs[0], node)
-    if
-
-    if output_shape
-
+    if value_has_dim_params(graph, node.outputs[0]) or not output_shape:
+        output_shape = ()
+    if output_shape:
+        if len(output_shape) != 1:
+            raise ShapeInferenceError("Shape output must be 1D")
+        if output_shape[0] < 0:
+            raise ShapeInferenceError("Shape output length must be non-negative")
     input_dtype = value_dtype(graph, node.inputs[0], node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if output_dtype != ScalarType.I64:
@@ -43,16 +47,18 @@ def lower_shape(graph: Graph, node: Node) -> ShapeOp:
         len(input_shape), start=start, end=end
     )
     expected_shape = (max(0, end_index - start_index),)
-    if expected_shape != output_shape:
+    if output_shape and expected_shape != output_shape:
         raise ShapeInferenceError(
             "Shape output shape must be "
            f"{expected_shape}, got {output_shape}"
        )
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], expected_shape)
     return ShapeOp(
         input0=node.inputs[0],
         output=node.outputs[0],
         input_shape=input_shape,
-        output_shape=
+        output_shape=expected_shape,
         values=input_shape[start_index:end_index],
         dtype=output_dtype,
         input_dtype=input_dtype,
emx_onnx_cgen/lowering/slice.py
CHANGED
@@ -6,10 +6,16 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..ir.ops import SliceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
-from ..
+from ..ir.ops import SliceOp
+from ..lowering.common import (
+    resolve_int_list_from_value,
+    value_has_dim_params,
+    value_dtype,
+    value_shape,
+)
 from ..validation import normalize_axis
 from .registry import register_lowering
 
@@ -70,7 +76,7 @@ def _maybe_read_int_list(
 ) -> list[int] | None:
     initializer = _find_initializer(graph, name)
     if initializer is None:
-        return
+        return resolve_int_list_from_value(graph, name, node)
     return _read_int_list(graph, name, node, label=label)
 
 
@@ -335,6 +341,8 @@ def resolve_slice_spec(graph: Graph, node: Node) -> SliceSpec:
 def lower_slice(graph: Graph, node: Node) -> SliceOp:
     input_shape = value_shape(graph, node.inputs[0], node)
     output_shape = value_shape(graph, node.outputs[0], node)
+    if value_has_dim_params(graph, node.outputs[0]):
+        output_shape = ()
     input_dtype = value_dtype(graph, node.inputs[0], node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:
@@ -356,6 +364,8 @@ def lower_slice(graph: Graph, node: Node) -> SliceOp:
             f"{node.op_type} output shape must be "
             f"{computed_output_shape}, got {output_shape}"
         )
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], computed_output_shape)
     return SliceOp(
         input0=node.inputs[0],
         output=node.outputs[0],
@@ -379,7 +389,7 @@ def lower_slice(graph: Graph, node: Node) -> SliceOp:
         dtype=input_dtype,
         input_dtype=input_dtype,
     )
-    if len(output_shape) != len(input_shape):
+    if output_shape and len(output_shape) != len(input_shape):
         raise ShapeInferenceError(
             f"{node.op_type} output rank must match input rank"
         )
emx_onnx_cgen/lowering/softmax.py
CHANGED
@@ -3,49 +3,15 @@ from __future__ import annotations
 from ..ir.ops import SoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from .common import node_dtype as _node_dtype
-from .common import onnx_opset_version as _onnx_opset_version
-from .common import shape_product as _shape_product
-from .common import value_shape as _value_shape
 from .registry import register_lowering
-from ..validation import ensure_output_shape_matches_input
-from ..validation import normalize_axis as _normalize_axis
 
 
 @register_lowering("Softmax")
 def lower_softmax(graph: Graph, node: Node) -> SoftmaxOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
         raise UnsupportedOpError("Softmax must have 1 input and 1 output")
-    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-    if not op_dtype.is_float:
-        raise UnsupportedOpError(
-            "Softmax supports float16, float, and double inputs only"
-        )
-    input_shape = _value_shape(graph, node.inputs[0], node)
-    output_shape = _value_shape(graph, node.outputs[0], node)
-    ensure_output_shape_matches_input(node, input_shape, output_shape)
-    opset_version = _onnx_opset_version(graph)
-    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
-    axis_attr = node.attrs.get("axis", default_axis)
-    axis = _normalize_axis(
-        int(axis_attr),
-        input_shape,
-        node,
-    )
-    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
-    axis_size = input_shape[axis]
-    inner = (
-        _shape_product(input_shape[axis + 1 :])
-        if axis + 1 < len(input_shape)
-        else 1
-    )
     return SoftmaxOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-
-        axis_size=axis_size,
-        inner=inner,
-        axis=axis,
-        shape=input_shape,
-        dtype=op_dtype,
+        axis=int(node.attrs["axis"]) if "axis" in node.attrs else None,
     )
emx_onnx_cgen/lowering/split.py
CHANGED
@@ -6,8 +6,14 @@ from shared.scalar_types import ScalarType
 
 from ..ir.ops import SplitOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
-from ..lowering.common import
+from ..lowering.common import (
+    optional_name,
+    resolve_int_list_from_value,
+    value_dtype,
+    value_shape,
+)
 from ..validation import normalize_axis
 from .registry import register_lowering
 
@@ -46,6 +52,22 @@ def _validate_static_dims(shape: tuple[int, ...], node: Node) -> None:
     )
 
 
+def _validate_output_ranks(
+    output_shapes: list[tuple[int, ...]],
+    input_shape: tuple[int, ...],
+    node: Node,
+) -> None:
+    expected_rank = len(input_shape)
+    for output_shape in output_shapes:
+        if not output_shape:
+            continue
+        if len(output_shape) != expected_rank:
+            raise ShapeInferenceError(
+                f"{node.op_type} output rank must match input rank "
+                f"{expected_rank}, got {len(output_shape)}"
+            )
+
+
 def _normalize_num_outputs(node: Node, output_count: int) -> int:
     num_outputs_attr = node.attrs.get("num_outputs")
     if num_outputs_attr is None:
@@ -75,6 +97,7 @@ def lower_split(graph: Graph, node: Node) -> SplitOp:
     output_shapes = [
         value_shape(graph, output, node) for output in node.outputs
     ]
+    _validate_output_ranks(output_shapes, input_shape, node)
     input_dtype = value_dtype(graph, input_name, node)
     output_dtypes = {value_dtype(graph, output, node) for output in node.outputs}
     if output_dtypes != {input_dtype}:
@@ -107,7 +130,15 @@ def lower_split(graph: Graph, node: Node) -> SplitOp:
         raise ShapeInferenceError(
             f"Split expects {len(node.outputs)} outputs, got {split_shape[0]}"
         )
-    split_sizes =
+    split_sizes = resolve_int_list_from_value(graph, split_name, node)
+    if split_sizes is None:
+        if all(output_shape for output_shape in output_shapes):
+            split_sizes = [shape[axis] for shape in output_shapes]
+        else:
+            raise ShapeInferenceError(
+                "Split sizes must be constant when output shapes "
+                "are unavailable"
+            )
     if len(split_sizes) != len(node.outputs):
         raise ShapeInferenceError(
             f"Split expects {len(split_sizes)} outputs, got {len(node.outputs)}"
@@ -133,11 +164,14 @@ def lower_split(graph: Graph, node: Node) -> SplitOp:
         shape = list(input_shape)
         shape[axis] = size
         computed_shape = tuple(shape)
-        if output_shape != computed_shape:
+        if output_shape and output_shape != computed_shape:
             raise ShapeInferenceError(
                 f"Split output shape must be {computed_shape}, got {output_shape}"
             )
         computed_shapes.append(computed_shape)
+    if isinstance(graph, GraphContext):
+        for output_name, shape in zip(node.outputs, computed_shapes):
+            graph.set_shape(output_name, shape)
     return SplitOp(
         input0=input_name,
         outputs=tuple(node.outputs),