emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/dequantize_linear.py ADDED
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import normalize_axis
+from .common import (
+    optional_name,
+    value_dtype as _value_dtype,
+    value_shape as _value_shape,
+)
+from .registry import register_lowering
+from ..ir.ops import DequantizeLinearOp
+
+
+@dataclass(frozen=True)
+class DequantizeSpec:
+    input_shape: tuple[int, ...]
+    scale_shape: tuple[int, ...]
+    axis: int | None
+    block_size: int | None
+
+
+def resolve_dequantize_spec(graph: Graph, node: Node) -> DequantizeSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "DequantizeLinear must have 2 or 3 inputs and 1 output"
+        )
+    supported_attrs = {"axis", "block_size"}
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("DequantizeLinear has unsupported attributes")
+    block_size = int(node.attrs.get("block_size", 0))
+    if block_size < 0:
+        raise UnsupportedOpError("DequantizeLinear block_size must be >= 0")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    scale_shape = _value_shape(graph, node.inputs[1], node)
+    zero_point_name = optional_name(node.inputs, 2)
+    if zero_point_name is not None:
+        zero_point_shape = _value_shape(graph, zero_point_name, node)
+        if zero_point_shape != scale_shape:
+            raise ShapeInferenceError(
+                "DequantizeLinear zero_point shape must match scale shape"
+            )
+    if scale_shape not in {(), (1,)}:
+        axis = int(node.attrs.get("axis", 1))
+        axis = normalize_axis(axis, input_shape, node)
+        if block_size > 0:
+            if len(scale_shape) != len(input_shape):
+                raise UnsupportedOpError(
+                    "DequantizeLinear blocked scales must match input rank"
+                )
+            if input_shape[axis] % block_size != 0:
+                raise ShapeInferenceError(
+                    "DequantizeLinear block_size must evenly divide axis length"
+                )
+            expected = list(input_shape)
+            expected[axis] = input_shape[axis] // block_size
+            if scale_shape != tuple(expected):
+                raise ShapeInferenceError(
+                    "DequantizeLinear blocked scale shape must match "
+                    "input shape with a reduced axis"
+                )
+        else:
+            if len(scale_shape) != 1:
+                raise UnsupportedOpError(
+                    "DequantizeLinear supports per-tensor, per-axis, "
+                    "and blocked scales only"
+                )
+            if scale_shape[0] != input_shape[axis]:
+                raise ShapeInferenceError(
+                    "DequantizeLinear scale length must match input axis size"
+                )
+    else:
+        axis = None
+        block_size = 0
+    return DequantizeSpec(
+        input_shape=input_shape,
+        scale_shape=scale_shape,
+        axis=axis,
+        block_size=block_size or None,
+    )
+
+
+@register_lowering("DequantizeLinear")
+def lower_dequantize_linear(graph: Graph, node: Node) -> DequantizeLinearOp:
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    scale_dtype = _value_dtype(graph, node.inputs[1], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype not in {
+        ScalarType.U8,
+        ScalarType.I8,
+        ScalarType.U16,
+        ScalarType.I16,
+    }:
+        raise UnsupportedOpError(
+            "DequantizeLinear supports int8/uint8/int16/uint16 inputs only"
+        )
+    if not scale_dtype.is_float or not output_dtype.is_float:
+        raise UnsupportedOpError(
+            "DequantizeLinear supports float16/float/double scales and outputs only"
+        )
+    if output_dtype != scale_dtype:
+        raise UnsupportedOpError(
+            "DequantizeLinear output dtype must match scale dtype"
+        )
+    zero_point_name = optional_name(node.inputs, 2)
+    if zero_point_name is not None:
+        zero_point_dtype = _value_dtype(graph, zero_point_name, node)
+        if zero_point_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "DequantizeLinear zero_point dtype must match input dtype"
+            )
+    spec = resolve_dequantize_spec(graph, node)
+    return DequantizeLinearOp(
+        input0=node.inputs[0],
+        scale=node.inputs[1],
+        zero_point=zero_point_name,
+        output=node.outputs[0],
+        input_shape=spec.input_shape,
+        axis=spec.axis,
+        block_size=spec.block_size,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+        scale_dtype=scale_dtype,
+    )
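The new lowering above only validates shapes and dtypes; the arithmetic it ultimately targets is the standard ONNX DequantizeLinear definition, y = (x - zero_point) * scale, with the scale applied per tensor, per axis, or per block. A minimal NumPy sketch of that definition (illustrative only, not code from the package; the function name and defaults are assumptions):

    import numpy as np

    def dequantize_linear(x, scale, zero_point=None, axis=1, block_size=0):
        # Standard ONNX semantics: y = (x - zero_point) * scale, computed in float.
        x = x.astype(np.int64)
        zp = np.zeros_like(scale, dtype=np.int64) if zero_point is None else zero_point.astype(np.int64)
        if scale.ndim == 0 or scale.shape == (1,):
            return (x - zp) * scale                            # per-tensor
        if block_size:
            scale = np.repeat(scale, block_size, axis=axis)    # blocked: expand scales back to x's shape
            zp = np.repeat(zp, block_size, axis=axis)
            return (x - zp) * scale
        bcast = [1] * x.ndim                                   # per-axis: broadcast 1-D scale along `axis`
        bcast[axis] = -1
        return (x - zp.reshape(bcast)) * scale.reshape(bcast)

    x = np.array([[0, 2, 4, 6]], dtype=np.uint8)
    print(dequantize_linear(x, np.array(0.5, dtype=np.float32), np.array(1, dtype=np.uint8)))
    # [[-0.5  0.5  1.5  2.5]]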
emx_onnx_cgen/lowering/elementwise.py CHANGED
@@ -3,11 +3,18 @@ from __future__ import annotations
 from shared.scalar_functions import ScalarFunction, ScalarFunctionError
 from shared.scalar_types import ScalarType

-from ..ir.
+from ..ir.op_base import BroadcastingOpBase
+from ..ir.ops import BinaryOp, ClipOp, PowOp, UnaryOp
 from ..errors import UnsupportedOpError
 from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
-from ..lowering.common import
+from ..lowering.common import (
+    node_dtype,
+    onnx_opset_version,
+    optional_name,
+    value_dtype,
+    value_shape,
+)
 from ..lowering.registry import register_lowering, register_lowering_if_missing
 from ..ops import (
     BINARY_OP_TYPES,
@@ -29,6 +36,24 @@ def lower_clip(graph: Graph, node: Node) -> ClipOp:
         raise UnsupportedOpError("Clip input must be provided")
     min_name = optional_name(node.inputs, 1)
     max_name = optional_name(node.inputs, 2)
+    min_value = None
+    max_value = None
+    opset_version = onnx_opset_version(graph)
+    if opset_version is None or opset_version < 11:
+        if min_name is None and "min" in node.attrs:
+            try:
+                min_value = float(node.attrs["min"])
+            except (TypeError, ValueError) as exc:
+                raise UnsupportedOpError(
+                    "Clip min attribute must be numeric"
+                ) from exc
+        if max_name is None and "max" in node.attrs:
+            try:
+                max_value = float(node.attrs["max"])
+            except (TypeError, ValueError) as exc:
+                raise UnsupportedOpError(
+                    "Clip max attribute must be numeric"
+                ) from exc
     input_dtype = value_dtype(graph, input_name, node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:
@@ -61,11 +86,8 @@ def lower_clip(graph: Graph, node: Node) -> ClipOp:
         input_min=min_name,
         input_max=max_name,
         output=node.outputs[0],
-
-
-        max_shape=max_shape,
-        output_shape=output_shape,
-        dtype=input_dtype,
+        min_value=min_value,
+        max_value=max_value,
     )

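The new block above handles the ONNX opset split for Clip: before opset 11 the bounds are the min/max node attributes, from opset 11 onward they are optional inputs. A small sketch of the same dispatch (hypothetical helper, standing in for the package's node structures):

    def resolve_clip_bounds(attrs, min_input, max_input, opset_version):
        # Pre-opset-11 Clip: bounds live in attributes; opset >= 11: bounds are optional inputs.
        min_value = max_value = None
        if opset_version is None or opset_version < 11:
            if min_input is None and "min" in attrs:
                min_value = float(attrs["min"])
            if max_input is None and "max" in attrs:
                max_value = float(attrs["max"])
        return min_value, max_value

    print(resolve_clip_bounds({"min": -1.0, "max": 1.0}, None, None, 9))    # (-1.0, 1.0)
    print(resolve_clip_bounds({}, "min_tensor", None, 13))                  # (None, None)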
@@ -82,9 +104,54 @@ def lower_celu(graph: Graph, node: Node) -> UnaryOp:
         input0=node.inputs[0],
         output=node.outputs[0],
         function=ScalarFunction.CELU,
-
-
-
+        params=(alpha,),
+    )
+
+
+@register_lowering("Elu")
+def lower_elu(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Elu must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("Elu only supports floating-point inputs")
+    for key in node.attrs:
+        if key != "alpha":
+            raise UnsupportedOpError(f"Elu does not support attribute {key}")
+    try:
+        alpha = float(node.attrs.get("alpha", 1.0))
+    except (TypeError, ValueError) as exc:
+        raise UnsupportedOpError("Elu alpha must be numeric") from exc
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.ELU,
+        params=(alpha,),
+    )
+
+
+@register_lowering("LeakyRelu")
+def lower_leaky_relu(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("LeakyRelu must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("LeakyRelu only supports floating-point inputs")
+    for key in node.attrs:
+        if key != "alpha":
+            raise UnsupportedOpError(
+                f"LeakyRelu does not support attribute {key}"
+            )
+    try:
+        alpha = float(node.attrs.get("alpha", 0.01))
+    except (TypeError, ValueError) as exc:
+        raise UnsupportedOpError("LeakyRelu alpha must be numeric") from exc
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.LEAKY_RELU,
         params=(alpha,),
     )

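Elu and LeakyRelu are lowered to the generic UnaryOp with alpha carried in params. For reference, their scalar definitions (standard ONNX semantics, not the generated C):

    import math

    def elu(x, alpha=1.0):
        # Elu: x for x >= 0, alpha * (exp(x) - 1) otherwise.
        return x if x >= 0 else alpha * (math.exp(x) - 1.0)

    def leaky_relu(x, alpha=0.01):
        # LeakyRelu: x for x >= 0, alpha * x otherwise.
        return x if x >= 0 else alpha * x

    print(elu(-1.0))         # -0.6321205588285577
    print(leaky_relu(-1.0))  # -0.01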
@@ -102,9 +169,6 @@ def lower_swish(graph: Graph, node: Node) -> UnaryOp:
         input0=node.inputs[0],
         output=node.outputs[0],
         function=ScalarFunction.SWISH,
-        shape=output_shape,
-        dtype=dtype,
-        input_dtype=dtype,
         params=(alpha,),
     )

@@ -123,13 +187,50 @@ def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
         input0=node.inputs[0],
         output=node.outputs[0],
         function=ScalarFunction.SHRINK,
-        shape=output_shape,
-        dtype=dtype,
-        input_dtype=dtype,
         params=(bias, lambd),
     )


+@register_lowering("Pow")
+def lower_pow(graph: Graph, node: Node) -> PowOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Pow must have 2 inputs and 1 output")
+    op_dtype = value_dtype(graph, node.inputs[0], node)
+    op_spec = binary_op_symbol(ScalarFunction.POW, node.attrs, dtype=op_dtype)
+    if op_spec is None:
+        raise UnsupportedOpError("Unsupported op Pow")
+    return PowOp(
+        input0=node.inputs[0],
+        input1=node.inputs[1],
+        output=node.outputs[0],
+        function=ScalarFunction.POW,
+        operator_kind=op_spec.kind,
+    )
+
+
+def _infer_binary_output_shape(
+    *,
+    function: ScalarFunction,
+    input0_shape: tuple[int, ...],
+    input1_shape: tuple[int, ...],
+) -> tuple[int, ...]:
+    if function != ScalarFunction.PRELU:
+        return BroadcastingOpBase.broadcast_shapes(input0_shape, input1_shape)
+    if BroadcastingOpBase.unidirectional_broadcastable(
+        input1_shape, input0_shape
+    ):
+        return input0_shape
+    channel_axis = BroadcastingOpBase.prelu_channel_axis(
+        input0_shape, input1_shape
+    )
+    if channel_axis is None:
+        raise ShapeInferenceError(
+            "Broadcasting mismatch for shapes: "
+            + ", ".join(str(shape) for shape in (input0_shape, input1_shape))
+        )
+    return input0_shape
+
+
 def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
     if node.op_type == "BitShift":
         if len(node.inputs) != 2 or len(node.outputs) != 1:
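_infer_binary_output_shape applies two different rules: ordinary binary ops use multidirectional (NumPy-style) broadcasting, while PRelu first tries unidirectional broadcasting of the slope onto the input and only then falls back to the per-channel case. A self-contained sketch of those two rules (illustrative; the real implementations are the BroadcastingOpBase helpers referenced above):

    import itertools

    def broadcast_shapes(a, b):
        # Multidirectional broadcasting: align right, dims must match or be 1.
        out = []
        for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=1):
            if x != y and 1 not in (x, y):
                raise ValueError(f"cannot broadcast {a} and {b}")
            out.append(max(x, y))
        return tuple(reversed(out))

    def unidirectional_broadcastable(src, dst):
        # True if `src` broadcasts to `dst` without changing `dst` (the PRelu slope rule).
        if len(src) > len(dst):
            return False
        return all(s in (1, d) for s, d in zip(reversed(src), reversed(dst)))

    print(broadcast_shapes((2, 1, 4), (3, 1)))                # (2, 3, 4)
    print(unidirectional_broadcastable((3, 1), (2, 3, 4)))    # True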
@@ -163,11 +264,6 @@ def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
             output=node.outputs[0],
             function=function,
             operator_kind=op_spec.kind,
-            input0_shape=input0_shape,
-            input1_shape=input1_shape,
-            shape=output_shape,
-            dtype=op_dtype,
-            input_dtype=op_dtype,
         )
     if node.op_type == "Mod":
         fmod = int(node.attrs.get("fmod", 0))
@@ -201,18 +297,21 @@ def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
         input0_shape = value_shape(graph, node.inputs[0], node)
         input1_shape = value_shape(graph, node.inputs[1], node)
         output_shape = value_shape(graph, node.outputs[0], node)
-
+        op = BinaryOp(
             input0=node.inputs[0],
             input1=node.inputs[1],
             output=node.outputs[0],
             function=function,
             operator_kind=op_spec.kind,
-            input0_shape=input0_shape,
-            input1_shape=input1_shape,
-            shape=output_shape,
-            dtype=output_dtype,
-            input_dtype=input_dtype,
         )
+        if isinstance(graph, GraphContext):
+            inferred_shape = _infer_binary_output_shape(
+                function=function,
+                input0_shape=input0_shape,
+                input1_shape=input1_shape,
+            )
+            graph.set_shape(node.outputs[0], inferred_shape)
+        return op
     op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
     op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
     unary_symbol = unary_op_symbol(function, dtype=op_dtype)
@@ -226,32 +325,36 @@ def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
         input0_shape = value_shape(graph, node.inputs[0], node)
         input1_shape = value_shape(graph, node.inputs[1], node)
         output_shape = value_shape(graph, node.outputs[0], node)
-
+        op = BinaryOp(
             input0=node.inputs[0],
             input1=node.inputs[1],
             output=node.outputs[0],
             function=function,
             operator_kind=op_spec.kind,
-            input0_shape=input0_shape,
-            input1_shape=input1_shape,
-            shape=output_shape,
-            dtype=op_dtype,
-            input_dtype=op_dtype,
         )
+        if isinstance(graph, GraphContext):
+            inferred_shape = _infer_binary_output_shape(
+                function=function,
+                input0_shape=input0_shape,
+                input1_shape=input1_shape,
+            )
+            graph.set_shape(node.outputs[0], inferred_shape)
+        return op
     if len(node.inputs) != 1 or len(node.outputs) != 1:
         raise UnsupportedOpError(
             f"{node.op_type} must have 1 input and 1 output"
         )
     output_shape = value_shape(graph, node.outputs[0], node)
-
+    op = UnaryOp(
         input0=node.inputs[0],
         output=node.outputs[0],
         function=function,
-        shape=output_shape,
-        dtype=op_dtype,
-        input_dtype=op_dtype,
         params=(),
     )
+    if isinstance(graph, GraphContext):
+        inferred_shape = value_shape(graph, node.inputs[0], node)
+        graph.set_shape(node.outputs[0], inferred_shape)
+    return op


 _DEFAULT_ELEMENTWISE_TYPES = (
@@ -283,9 +386,6 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
         input0=node.inputs[0],
         output=node.outputs[0],
         function=ScalarFunction.ISINF,
-        shape=output_shape,
-        dtype=output_dtype,
-        input_dtype=input_dtype,
         params=(float(detect_negative), float(detect_positive)),
     )

@@ -305,8 +405,5 @@
         input0=node.inputs[0],
         output=node.outputs[0],
         function=ScalarFunction.ISNAN,
-        shape=output_shape,
-        dtype=output_dtype,
-        input_dtype=input_dtype,
         params=(),
     )
emx_onnx_cgen/lowering/gather.py CHANGED
@@ -1,8 +1,11 @@
 from __future__ import annotations

-from ..ir.ops import GatherOp
 from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.ops import GatherOp
+from ..lowering.common import value_shape
+from ..validation import normalize_axis
 from .registry import register_lowering


@@ -11,9 +14,15 @@ def lower_gather(graph: Graph, node: Node) -> GatherOp:
     if len(node.inputs) != 2 or len(node.outputs) != 1:
         raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
     data_name, indices_name = node.inputs
+    data_shape = value_shape(graph, data_name, node)
+    indices_shape = value_shape(graph, indices_name, node)
+    axis = normalize_axis(int(node.attrs.get("axis", 0)), data_shape, node)
+    output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1 :]
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], output_shape)
     return GatherOp(
         data=data_name,
         indices=indices_name,
         output=node.outputs[0],
-        axis=
+        axis=axis,
     )
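The inferred shape follows the standard ONNX Gather rule used above: the axis dimension of the data shape is replaced by the whole indices shape. A quick NumPy check of that rule (illustrative):

    import numpy as np

    data = np.zeros((5, 4, 3))
    indices = np.zeros((2, 6), dtype=np.int64)
    axis = 1

    expected = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
    actual = np.take(data, indices, axis=axis).shape
    print(expected, actual)   # (5, 2, 6, 3) (5, 2, 6, 3)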
emx_onnx_cgen/lowering/gemm.py CHANGED
@@ -1,139 +1,22 @@
 from __future__ import annotations

-from dataclasses import dataclass
-
-from shared.scalar_types import ScalarType
-
 from ..ir.ops import GemmOp
-from ..errors import
+from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from .common import node_dtype as _node_dtype
-from .common import value_shape as _value_shape
 from .registry import register_lowering


-@dataclass(frozen=True)
-class GemmSpec:
-    m: int
-    n: int
-    k: int
-    alpha: float | int
-    beta: float | int
-    trans_a: bool
-    trans_b: bool
-    c_shape: tuple[int, ...] | None
-
-
-def resolve_gemm_spec(graph: Graph, node: Node, dtype: ScalarType) -> GemmSpec:
-    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
-        raise UnsupportedOpError("Gemm must have 2 or 3 inputs and 1 output")
-    alpha, beta, trans_a, trans_b = _resolve_gemm_attrs(node, dtype)
-    input0_shape = _value_shape(graph, node.inputs[0], node)
-    input1_shape = _value_shape(graph, node.inputs[1], node)
-    if len(input0_shape) != 2 or len(input1_shape) != 2:
-        raise UnsupportedOpError(
-            "Gemm supports 2D inputs only, "
-            f"got {input0_shape} x {input1_shape}"
-        )
-    if trans_a:
-        m, k_left = input0_shape[1], input0_shape[0]
-    else:
-        m, k_left = input0_shape
-    if trans_b:
-        n, k_right = input1_shape[0], input1_shape[1]
-    else:
-        k_right, n = input1_shape
-    if k_left != k_right:
-        raise ShapeInferenceError(
-            f"Gemm inner dimensions must match, got {k_left} and {k_right}"
-        )
-    output_shape = _value_shape(graph, node.outputs[0], node)
-    if output_shape != (m, n):
-        raise ShapeInferenceError(
-            f"Gemm output shape must be {(m, n)}, got {output_shape}"
-        )
-    c_shape = None
-    if len(node.inputs) == 3:
-        bias_shape = _value_shape(graph, node.inputs[2], node)
-        c_shape = validate_gemm_bias_shape((m, n), bias_shape, node)
-    return GemmSpec(
-        m=m,
-        n=n,
-        k=k_left,
-        alpha=alpha,
-        beta=beta,
-        trans_a=trans_a,
-        trans_b=trans_b,
-        c_shape=c_shape,
-    )
-
-
-def _resolve_gemm_attrs(
-    node: Node, dtype: ScalarType
-) -> tuple[float | int, float | int, bool, bool]:
-    alpha = float(node.attrs.get("alpha", 1.0))
-    beta = float(node.attrs.get("beta", 1.0))
-    trans_a = int(node.attrs.get("transA", 0))
-    trans_b = int(node.attrs.get("transB", 0))
-    if trans_a not in {0, 1} or trans_b not in {0, 1}:
-        raise UnsupportedOpError(
-            "Gemm only supports transA/transB values of 0 or 1"
-        )
-    if dtype == ScalarType.BOOL:
-        raise UnsupportedOpError("Gemm supports numeric inputs only")
-    if not dtype.is_float:
-        alpha_int = int(alpha)
-        beta_int = int(beta)
-        if alpha != alpha_int or beta != beta_int:
-            raise UnsupportedOpError(
-                "Gemm alpha and beta must be integers for non-float inputs"
-            )
-        alpha = alpha_int
-        beta = beta_int
-    return alpha, beta, bool(trans_a), bool(trans_b)
-
-
-def validate_gemm_bias_shape(
-    output_shape: tuple[int, int], bias_shape: tuple[int, ...], node: Node
-) -> tuple[int, ...]:
-    if len(bias_shape) == 0:
-        return bias_shape
-    if len(bias_shape) == 1:
-        if bias_shape[0] not in {1, output_shape[1]}:
-            raise ShapeInferenceError(
-                "Gemm bias input must be broadcastable to output shape, "
-                f"got {bias_shape} vs {output_shape}"
-            )
-        return bias_shape
-    if len(bias_shape) == 2:
-        m, n = output_shape
-        if bias_shape[0] not in {1, m} or bias_shape[1] not in {1, n}:
-            raise ShapeInferenceError(
-                "Gemm bias input must be broadcastable to output shape, "
-                f"got {bias_shape} vs {output_shape}"
-            )
-        return bias_shape
-    raise ShapeInferenceError(
-        f"Gemm bias input must be rank 1 or 2, got {bias_shape}"
-    )
-
-
 @register_lowering("Gemm")
 def lower_gemm(graph: Graph, node: Node) -> GemmOp:
-
-
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Gemm must have 2 or 3 inputs and 1 output")
     return GemmOp(
         input_a=node.inputs[0],
         input_b=node.inputs[1],
         input_c=node.inputs[2] if len(node.inputs) == 3 else None,
         output=node.outputs[0],
-
-
-
-
-        trans_b=spec.trans_b,
-        alpha=spec.alpha,
-        beta=spec.beta,
-        c_shape=spec.c_shape,
-        dtype=op_dtype,
+        alpha=float(node.attrs.get("alpha", 1.0)),
+        beta=float(node.attrs.get("beta", 1.0)),
+        trans_a=int(node.attrs.get("transA", 0)),
+        trans_b=int(node.attrs.get("transB", 0)),
     )
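The lowering now forwards the raw alpha, beta, transA, and transB attributes to GemmOp instead of pre-validating them through GemmSpec. For reference, ONNX Gemm computes Y = alpha * A' @ B' + beta * C, where A' and B' are the optionally transposed inputs and C broadcasts to (M, N); a NumPy sketch (not the package's implementation):

    import numpy as np

    def gemm(a, b, c=None, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
        # ONNX Gemm reference: Y = alpha * A' @ B' + beta * C.
        a = a.T if trans_a else a
        b = b.T if trans_b else b
        y = alpha * (a @ b)
        if c is not None:
            y = y + beta * c    # C is broadcast to (M, N)
        return y

    a = np.arange(6.0).reshape(2, 3)
    b = np.arange(12.0).reshape(3, 4)
    print(gemm(a, b, c=np.ones(4), alpha=2.0, beta=0.5).shape)   # (2, 4)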
emx_onnx_cgen/lowering/global_max_pool.py CHANGED
@@ -45,15 +45,10 @@ def lower_global_max_pool(graph: Graph, node: Node) -> ReduceOp:
     return ReduceOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-        input_shape=input_shape,
-        output_shape=output_shape,
         axes=axes,
         axes_input=None,
-        axes_input_shape=None,
-        axes_input_dtype=None,
         keepdims=True,
         noop_with_empty_axes=False,
         reduce_kind="max",
         reduce_count=None,
-        dtype=op_dtype,
     )