emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +50 -23
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +30 -387
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +36 -18
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +1 -1
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +1 -1
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +1 -1
- emx_onnx_cgen/lowering/gather_nd.py +1 -1
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +1 -1
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +1 -1
- emx_onnx_cgen/lowering/identity.py +1 -1
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +1 -1
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +1 -1
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +1 -1
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +1 -1
- emx_onnx_cgen/lowering/one_hot.py +1 -1
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +1 -1
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +1 -1
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +1 -1
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +1 -1
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +1 -1
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +1 -1
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +25 -7
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +1 -1
- emx_onnx_cgen/lowering/unsqueeze.py +1 -1
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/runtime/evaluator.py +325 -1
- emx_onnx_cgen/verification.py +9 -39
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
- emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +11 -0
- shared/ulp.py +17 -0
- emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from shared.scalar_types import ScalarType
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import GatherElementsOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from ..validation import normalize_axis
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from shared.scalar_types import ScalarType
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import GatherNDOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from .common import value_dtype as _value_dtype
|
emx_onnx_cgen/lowering/gemm.py
CHANGED
|
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|
|
4
4
|
|
|
5
5
|
from shared.scalar_types import ScalarType
|
|
6
6
|
|
|
7
|
-
from ..
|
|
7
|
+
from ..ir.ops import GemmOp
|
|
8
8
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
9
9
|
from ..ir.model import Graph, Node
|
|
10
10
|
from .common import node_dtype as _node_dtype
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from shared.scalar_types import ScalarType
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import ReduceOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from .common import value_dtype as _value_dtype
|
|
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|
|
4
4
|
|
|
5
5
|
from shared.scalar_types import ScalarType
|
|
6
6
|
|
|
7
|
-
from ..
|
|
7
|
+
from ..ir.ops import GridSampleOp
|
|
8
8
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
9
9
|
from ..ir.model import Graph, Node
|
|
10
10
|
from .common import value_dtype, value_shape
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from ..
|
|
3
|
+
from ..ir.ops import GroupNormalizationOp
|
|
4
4
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
5
5
|
from ..ir.model import Graph, Node
|
|
6
6
|
from ..validation import ensure_output_shape_matches_input
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from shared.scalar_types import ScalarType
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import HardmaxOp
|
|
6
6
|
from ..errors import UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from .common import node_dtype as _node_dtype
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from ..
|
|
3
|
+
from ..ir.ops import InstanceNormalizationOp
|
|
4
4
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
5
5
|
from ..ir.model import Graph, Node
|
|
6
6
|
from ..validation import ensure_output_shape_matches_input
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from ..
|
|
3
|
+
from ..ir.ops import LayerNormalizationOp
|
|
4
4
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
5
5
|
from ..ir.model import Graph, Node
|
|
6
6
|
from ..validation import ensure_output_shape_matches_input
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from ..
|
|
3
|
+
from ..ir.ops import LpNormalizationOp
|
|
4
4
|
from ..errors import UnsupportedOpError
|
|
5
5
|
from ..ir.model import Graph, Node
|
|
6
6
|
from ..validation import ensure_output_shape_matches_input
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import LpPoolOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from .registry import register_lowering
|
emx_onnx_cgen/lowering/lrn.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import LrnOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from .registry import register_lowering
|
emx_onnx_cgen/lowering/lstm.py
CHANGED
|
@@ -323,7 +323,7 @@ def resolve_lstm_spec(graph: Graph, node: Node) -> LstmSpec:
|
|
|
323
323
|
|
|
324
324
|
@register_lowering("LSTM")
|
|
325
325
|
def lower_lstm(graph: Graph, node: Node) -> "LstmOp":
|
|
326
|
-
from ..
|
|
326
|
+
from ..ir.ops import LstmOp
|
|
327
327
|
|
|
328
328
|
spec = resolve_lstm_spec(graph, node)
|
|
329
329
|
return LstmOp(
|
emx_onnx_cgen/lowering/matmul.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import MatMulOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from .common import node_dtype as _node_dtype
|
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|
|
5
5
|
|
|
6
6
|
from shared.scalar_types import ScalarType
|
|
7
7
|
|
|
8
|
-
from ..
|
|
8
|
+
from ..ir.ops import MaxPoolOp
|
|
9
9
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
10
10
|
from ..ir.model import Graph, Node
|
|
11
11
|
from .common import node_dtype as _node_dtype
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from ..
|
|
3
|
+
from ..ir.ops import MeanVarianceNormalizationOp
|
|
4
4
|
from ..errors import UnsupportedOpError
|
|
5
5
|
from ..ir.model import Graph, Node
|
|
6
6
|
from ..validation import ensure_output_shape_matches_input
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from shared.scalar_types import ScalarType
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import NegativeLogLikelihoodLossOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Initializer, Node
|
|
8
8
|
from .common import shape_product as _shape_product
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from shared.scalar_types import ScalarType
|
|
4
|
+
|
|
5
|
+
from ..ir.ops import NonMaxSuppressionOp
|
|
6
|
+
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
|
+
from ..ir.model import Graph, Node
|
|
8
|
+
from ..lowering.common import optional_name, shape_product, value_dtype, value_shape
|
|
9
|
+
from .registry import register_lowering
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _validate_scalar_input(
    graph: Graph,
    name: str,
    node: Node,
    *,
    allowed_dtypes: set[ScalarType],
    label: str,
) -> tuple[ScalarType, tuple[int, ...]]:
    """Check that an optional NonMaxSuppression input is a single-element tensor.

    Args:
        graph: Graph that provides dtype/shape information for ``name``.
        name: Name of the value to validate.
        node: Node consuming the value (used for lookups and error messages).
        allowed_dtypes: Scalar types accepted for this input.
        label: Human-readable description of the input used in error messages.

    Returns:
        The value's ``(dtype, shape)`` pair.

    Raises:
        UnsupportedOpError: If the dtype is not one of ``allowed_dtypes``.
        ShapeInferenceError: If the tensor does not hold exactly one element.
    """
    dtype = value_dtype(graph, name, node)
    if dtype not in allowed_dtypes:
        accepted = ", ".join(sorted(t.onnx_name for t in allowed_dtypes))
        raise UnsupportedOpError(
            f"{node.op_type} {label} must be {accepted}, got {dtype.onnx_name}"
        )

    shape = value_shape(graph, name, node)
    # () and (1,) are accepted directly; any other shape is still fine as long
    # as it holds exactly one element (e.g. (1, 1)).
    single_element = shape in {(), (1,)} or shape_product(shape) == 1
    if not single_element:
        raise ShapeInferenceError(
            f"{node.op_type} {label} must be a scalar tensor, got shape {shape}"
        )
    return dtype, shape
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@register_lowering("NonMaxSuppression")
def lower_non_max_suppression(graph: Graph, node: Node) -> NonMaxSuppressionOp:
    """Validate a NonMaxSuppression node and build its NonMaxSuppressionOp.

    The node takes ``boxes`` and ``scores`` plus up to three optional
    single-element inputs (``max_output_boxes_per_class``, ``iou_threshold``,
    ``score_threshold``) and must have exactly one output.

    Args:
        graph: Graph providing shape/dtype information for the node's values.
        node: The NonMaxSuppression node to lower.

    Returns:
        A NonMaxSuppressionOp carrying the value names, the validated shapes
        and dtypes, and the ``center_point_box`` attribute.

    Raises:
        UnsupportedOpError: On a wrong op type, input/output count, dtype, or
            ``center_point_box`` value.
        ShapeInferenceError: When the boxes/scores/output shapes are
            inconsistent or an optional input is not single-element.
    """
    op_type = node.op_type
    if op_type != "NonMaxSuppression":
        raise UnsupportedOpError(f"Unsupported op {op_type}")
    n_outputs = len(node.outputs)
    if n_outputs != 1:
        raise UnsupportedOpError(f"{op_type} must have 1 output, got {n_outputs}")
    n_inputs = len(node.inputs)
    if not 2 <= n_inputs <= 5:
        raise UnsupportedOpError(
            f"{op_type} must have 2 to 5 inputs, got {n_inputs}"
        )

    boxes, scores = node.inputs[0], node.inputs[1]
    max_output_boxes_per_class, iou_threshold, score_threshold = (
        optional_name(node.inputs, position) for position in (2, 3, 4)
    )
    output = node.outputs[0]

    # Expected layouts: boxes [num_batches, num_boxes, 4],
    # scores [num_batches, num_classes, num_boxes].
    boxes_shape = value_shape(graph, boxes, node)
    scores_shape = value_shape(graph, scores, node)
    if len(boxes_shape) != 3 or boxes_shape[2] != 4:
        raise ShapeInferenceError(
            f"{op_type} boxes input must have shape "
            f"[num_batches, num_boxes, 4], got {boxes_shape}"
        )
    if len(scores_shape) != 3:
        raise ShapeInferenceError(
            f"{op_type} scores input must have shape "
            f"[num_batches, num_classes, num_boxes], got {scores_shape}"
        )
    box_batches, box_count, _ = boxes_shape
    score_batches, _, score_box_count = scores_shape
    if box_batches != score_batches:
        raise ShapeInferenceError(
            f"{op_type} boxes/scores batch dims must match, "
            f"got {box_batches} and {score_batches}"
        )
    if box_count != score_box_count:
        raise ShapeInferenceError(
            f"{op_type} boxes num_boxes dim {box_count} "
            f"must match scores num_boxes dim {score_box_count}"
        )

    boxes_dtype = value_dtype(graph, boxes, node)
    scores_dtype = value_dtype(graph, scores, node)
    if boxes_dtype != scores_dtype or not boxes_dtype.is_float:
        raise UnsupportedOpError(
            f"{op_type} boxes and scores must be the same float dtype, "
            f"got {boxes_dtype.onnx_name} and {scores_dtype.onnx_name}"
        )

    # Optional single-element inputs, validated in input order; an absent
    # input contributes (None, None) for its (dtype, shape).
    int_types = {ScalarType.I32, ScalarType.I64}
    float_types = {ScalarType.F32, ScalarType.F64}
    optional_checks = (
        (max_output_boxes_per_class, int_types, "max_output_boxes_per_class input"),
        (iou_threshold, float_types, "iou_threshold input"),
        (score_threshold, float_types, "score_threshold input"),
    )
    optional_meta = [
        (None, None)
        if input_name is None
        else _validate_scalar_input(
            graph, input_name, node, allowed_dtypes=accepted, label=label
        )
        for input_name, accepted, label in optional_checks
    ]
    (
        (max_output_dtype, max_output_shape),
        (iou_threshold_dtype, iou_threshold_shape),
        (score_threshold_dtype, score_threshold_shape),
    ) = optional_meta

    output_shape = value_shape(graph, output, node)
    if len(output_shape) != 2 or output_shape[1] != 3:
        raise ShapeInferenceError(
            f"{op_type} output must have shape [num_selected, 3], "
            f"got {output_shape}"
        )
    output_dtype = value_dtype(graph, output, node)
    if output_dtype != ScalarType.I64:
        raise UnsupportedOpError(f"{op_type} output dtype must be int64")

    center_point_box = int(node.attrs.get("center_point_box", 0))
    if center_point_box not in (0, 1):
        raise UnsupportedOpError(
            f"{op_type} center_point_box must be 0 or 1, got {center_point_box}"
        )

    return NonMaxSuppressionOp(
        boxes=boxes,
        scores=scores,
        max_output_boxes_per_class=max_output_boxes_per_class,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        output=output,
        boxes_shape=boxes_shape,
        scores_shape=scores_shape,
        output_shape=output_shape,
        center_point_box=center_point_box,
        boxes_dtype=boxes_dtype,
        output_dtype=output_dtype,
        max_output_dtype=max_output_dtype,
        max_output_shape=max_output_shape,
        iou_threshold_dtype=iou_threshold_dtype,
        iou_threshold_shape=iou_threshold_shape,
        score_threshold_dtype=score_threshold_dtype,
        score_threshold_shape=score_threshold_shape,
    )
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from shared.scalar_types import ScalarType
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..ir.ops import NonZeroOp
|
|
6
6
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
7
7
|
from ..ir.model import Graph, Node
|
|
8
8
|
from .common import value_dtype, value_shape
|
|
@@ -4,7 +4,7 @@ import numpy as np
|
|
|
4
4
|
|
|
5
5
|
from shared.scalar_types import ScalarType
|
|
6
6
|
|
|
7
|
-
from ..
|
|
7
|
+
from ..ir.ops import OneHotOp
|
|
8
8
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
9
9
|
from ..ir.model import Graph, Initializer, Node
|
|
10
10
|
from ..lowering.common import value_dtype, value_shape
|
emx_onnx_cgen/lowering/pad.py
CHANGED
|
@@ -4,7 +4,7 @@ import numpy as np
|
|
|
4
4
|
|
|
5
5
|
from shared.scalar_types import ScalarType
|
|
6
6
|
|
|
7
|
-
from ..
|
|
7
|
+
from ..ir.ops import PadOp
|
|
8
8
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
9
9
|
from ..ir.model import Graph, Initializer, Node
|
|
10
10
|
from ..lowering.common import optional_name, value_dtype, value_shape
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from shared.scalar_types import ScalarType
|
|
6
|
+
|
|
7
|
+
from ..ir.ops import QLinearMatMulOp
|
|
8
|
+
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
9
|
+
from ..ir.model import Graph, Node
|
|
10
|
+
from .common import value_dtype as _value_dtype
|
|
11
|
+
from .common import value_shape as _value_shape
|
|
12
|
+
from .registry import register_lowering
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class QLinearMatMulSpec:
    """Immutable shape summary for a QLinearMatMul node.

    Built by ``resolve_qlinear_matmul_spec``. Holds the declared operand and
    output shapes, the broadcast batch shapes, and the matrix sizes after
    1-D operands have been promoted to matrices.
    """

    # Shapes as declared on input 0 (a), input 3 (b) and the output.
    input0_shape: tuple[int, ...]
    input1_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    # Broadcast batch shape, plus each operand's batch dims left-padded with
    # 1s to the same rank.
    batch_shape: tuple[int, ...]
    input0_batch_shape: tuple[int, ...]
    input1_batch_shape: tuple[int, ...]
    # Matrix sizes: a is (m, k) and b is (k, n) after vector promotion.
    m: int
    n: int
    k: int
    # True when the corresponding operand was 1-D and promoted to a matrix.
    left_vector: bool
    right_vector: bool
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def resolve_qlinear_matmul_spec(graph: Graph, node: Node) -> QLinearMatMulSpec:
    """Derive the shape information for a QLinearMatMul node.

    Operand ``a`` is input 0 and operand ``b`` is input 3. A 1-D operand is
    promoted to a matrix (left: ``(1, K)``, right: ``(K, 1)``) and the
    promoted axis is left out of the result shape. Leading batch dimensions
    are broadcast with ``_broadcast_batch_shapes``.

    Raises:
        UnsupportedOpError: Wrong input/output count, or a 0-D operand.
        ShapeInferenceError: Mismatched inner dimensions, non-broadcastable
            batch dimensions, or a declared output shape that disagrees with
            the computed one.
    """
    if len(node.inputs) != 8 or len(node.outputs) != 1:
        raise UnsupportedOpError(
            "QLinearMatMul must have 8 inputs and 1 output"
        )

    a_shape = _value_shape(graph, node.inputs[0], node)
    b_shape = _value_shape(graph, node.inputs[3], node)
    if min(len(a_shape), len(b_shape)) < 1:
        raise UnsupportedOpError(
            "QLinearMatMul inputs must be at least 1D, "
            f"got {a_shape} x {b_shape}"
        )

    a_is_vector = len(a_shape) == 1
    b_is_vector = len(b_shape) == 1
    a_matrix = (1, a_shape[0]) if a_is_vector else a_shape
    b_matrix = (b_shape[0], 1) if b_is_vector else b_shape
    m, k = a_matrix[-2], a_matrix[-1]
    b_rows, n = b_matrix[-2], b_matrix[-1]
    if k != b_rows:
        raise ShapeInferenceError(
            "QLinearMatMul inner dimensions must match, "
            f"got {k} and {b_rows}"
        )

    batch_shape, a_batch_shape, b_batch_shape = _broadcast_batch_shapes(
        a_matrix[:-2], b_matrix[:-2], node
    )
    # A promoted vector operand contributes no axis to the result.
    output_shape = (
        batch_shape
        + (() if a_is_vector else (m,))
        + (() if b_is_vector else (n,))
    )

    declared_shape = _value_shape(graph, node.outputs[0], node)
    if declared_shape != output_shape:
        raise ShapeInferenceError(
            "QLinearMatMul output shape must be "
            f"{output_shape}, got {declared_shape}"
        )

    return QLinearMatMulSpec(
        input0_shape=a_shape,
        input1_shape=b_shape,
        output_shape=output_shape,
        batch_shape=batch_shape,
        input0_batch_shape=a_batch_shape,
        input1_batch_shape=b_batch_shape,
        m=m,
        n=n,
        k=k,
        left_vector=a_is_vector,
        right_vector=b_is_vector,
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _broadcast_batch_shapes(
|
|
88
|
+
left: tuple[int, ...], right: tuple[int, ...], node: Node
|
|
89
|
+
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
|
90
|
+
max_rank = max(len(left), len(right))
|
|
91
|
+
left_padded = (1,) * (max_rank - len(left)) + left
|
|
92
|
+
right_padded = (1,) * (max_rank - len(right)) + right
|
|
93
|
+
broadcast_shape = []
|
|
94
|
+
for left_dim, right_dim in zip(left_padded, right_padded):
|
|
95
|
+
if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
|
|
96
|
+
raise ShapeInferenceError(
|
|
97
|
+
"QLinearMatMul batch dimensions must be broadcastable, "
|
|
98
|
+
f"got {left} x {right}"
|
|
99
|
+
)
|
|
100
|
+
broadcast_shape.append(max(left_dim, right_dim))
|
|
101
|
+
return tuple(broadcast_shape), left_padded, right_padded
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _ensure_scalar_input(
    graph: Graph, name: str, node: Node, label: str
) -> tuple[int, ...]:
    """Return the shape of a QLinearMatMul scale/zero-point input.

    Only the shapes ``()`` and ``(1,)`` are accepted.

    Raises:
        UnsupportedOpError: If the input has any other shape.
    """
    shape = _value_shape(graph, name, node)
    if shape in {(), (1,)}:
        return shape
    raise UnsupportedOpError(
        f"QLinearMatMul {label} must be scalar, got shape {shape}"
    )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _ensure_scale_dtype(dtype: ScalarType, label: str) -> None:
|
|
116
|
+
if not dtype.is_float:
|
|
117
|
+
raise UnsupportedOpError(
|
|
118
|
+
f"QLinearMatMul {label} must be float16/float/double"
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@register_lowering("QLinearMatMul")
def lower_qlinear_matmul(graph: Graph, node: Node) -> QLinearMatMulOp:
    """Validate a QLinearMatMul node and build its QLinearMatMulOp.

    Input layout: 0=a, 1=a_scale, 2=a_zero_point, 3=b, 4=b_scale,
    5=b_zero_point, 6=y_scale, 7=y_zero_point; output 0=y.

    Raises:
        UnsupportedOpError: If a, b or y is not uint8/int8, a scale is not
            floating point, a zero point's dtype differs from its tensor, or a
            scale/zero point is not scalar-shaped (also any shape error raised
            by ``resolve_qlinear_matmul_spec``).
        ShapeInferenceError: Propagated from ``resolve_qlinear_matmul_spec``.
    """
    spec = resolve_qlinear_matmul_spec(graph, node)
    inputs = node.inputs

    a_dtype = _value_dtype(graph, inputs[0], node)
    b_dtype = _value_dtype(graph, inputs[3], node)
    y_dtype = _value_dtype(graph, node.outputs[0], node)
    quantized_types = {ScalarType.U8, ScalarType.I8}
    for tensor_dtype, message in (
        (a_dtype, "QLinearMatMul supports uint8/int8 inputs only"),
        (b_dtype, "QLinearMatMul supports uint8/int8 inputs only"),
        (y_dtype, "QLinearMatMul supports uint8/int8 outputs only"),
    ):
        if tensor_dtype not in quantized_types:
            raise UnsupportedOpError(message)

    a_scale_dtype, b_scale_dtype, y_scale_dtype = (
        _value_dtype(graph, inputs[index], node) for index in (1, 4, 6)
    )
    for scale_dtype, label in (
        (a_scale_dtype, "a_scale"),
        (b_scale_dtype, "b_scale"),
        (y_scale_dtype, "y_scale"),
    ):
        _ensure_scale_dtype(scale_dtype, label)

    # Each zero point must share the dtype of the tensor it offsets.
    a_zero_dtype, b_zero_dtype, y_zero_dtype = (
        _value_dtype(graph, inputs[index], node) for index in (2, 5, 7)
    )
    for zero_dtype, tensor_dtype, message in (
        (a_zero_dtype, a_dtype, "QLinearMatMul a_zero_point dtype must match a"),
        (b_zero_dtype, b_dtype, "QLinearMatMul b_zero_point dtype must match b"),
        (y_zero_dtype, y_dtype, "QLinearMatMul y_zero_point dtype must match y"),
    ):
        if zero_dtype != tensor_dtype:
            raise UnsupportedOpError(message)

    # Validated in the same order as before: scales first, then zero points.
    scalar_shapes = {
        label: _ensure_scalar_input(graph, inputs[index], node, label)
        for index, label in (
            (1, "a_scale"),
            (4, "b_scale"),
            (6, "y_scale"),
            (2, "a_zero_point"),
            (5, "b_zero_point"),
            (7, "y_zero_point"),
        )
    }

    return QLinearMatMulOp(
        input0=inputs[0],
        input0_scale=inputs[1],
        input0_zero_point=inputs[2],
        input1=inputs[3],
        input1_scale=inputs[4],
        input1_zero_point=inputs[5],
        output_scale=inputs[6],
        output_zero_point=inputs[7],
        output=node.outputs[0],
        input0_shape=spec.input0_shape,
        input1_shape=spec.input1_shape,
        output_shape=spec.output_shape,
        batch_shape=spec.batch_shape,
        input0_batch_shape=spec.input0_batch_shape,
        input1_batch_shape=spec.input1_batch_shape,
        m=spec.m,
        n=spec.n,
        k=spec.k,
        left_vector=spec.left_vector,
        right_vector=spec.right_vector,
        input0_dtype=a_dtype,
        input1_dtype=b_dtype,
        dtype=y_dtype,
        input0_scale_dtype=a_scale_dtype,
        input1_scale_dtype=b_scale_dtype,
        output_scale_dtype=y_scale_dtype,
        input0_scale_shape=scalar_shapes["a_scale"],
        input1_scale_shape=scalar_shapes["b_scale"],
        output_scale_shape=scalar_shapes["y_scale"],
        input0_zero_shape=scalar_shapes["a_zero_point"],
        input1_zero_shape=scalar_shapes["b_zero_point"],
        output_zero_shape=scalar_shapes["y_zero_point"],
    )
|
|
@@ -10,7 +10,7 @@ from ..ir.model import Graph, Node
|
|
|
10
10
|
from ..validation import normalize_axis
|
|
11
11
|
from .common import optional_name, value_dtype as _value_dtype, value_shape as _value_shape
|
|
12
12
|
from .registry import register_lowering
|
|
13
|
-
from ..
|
|
13
|
+
from ..ir.ops import QuantizeLinearOp
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
@dataclass(frozen=True)
|
emx_onnx_cgen/lowering/range.py
CHANGED
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
|
|
7
7
|
from shared.scalar_types import ScalarType
|
|
8
8
|
|
|
9
|
-
from ..
|
|
9
|
+
from ..ir.ops import RangeOp
|
|
10
10
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
11
11
|
from ..ir.model import Graph, Initializer, Node
|
|
12
12
|
from ..lowering.common import node_dtype, value_shape
|
emx_onnx_cgen/lowering/reduce.py
CHANGED
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
|
|
7
7
|
from shared.scalar_types import ScalarType
|
|
8
8
|
|
|
9
|
-
from ..
|
|
9
|
+
from ..ir.ops import ReduceOp, ReshapeOp
|
|
10
10
|
from ..dtypes import scalar_type_from_onnx
|
|
11
11
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
12
12
|
from ..ir.model import Graph, Initializer, Node
|
|
@@ -3,32 +3,51 @@ from __future__ import annotations
|
|
|
3
3
|
from collections.abc import Callable, Mapping
|
|
4
4
|
from typing import TypeVar
|
|
5
5
|
|
|
6
|
+
from ..ir.context import GraphContext
|
|
6
7
|
from ..ir.model import Graph, Node
|
|
8
|
+
from ..ir.op_base import OpBase
|
|
7
9
|
from ..errors import UnsupportedOpError
|
|
8
10
|
|
|
9
11
|
# Type variable for the op object a lowering function returns.
LoweredOp = TypeVar("LoweredOp")
# NOTE(review): not referenced in the visible part of this module; confirm
# before removing.
Handler = TypeVar("Handler")

# Maps an op_type string (e.g. "LSTM") to the function that lowers a node of
# that type into an op object.
_LOWERING_REGISTRY: dict[
    str,
    Callable[[Graph | GraphContext, Node], OpBase],
] = {}
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
def register_lowering(
    op_type: str,
) -> Callable[
    [Callable[[Graph | GraphContext, Node], LoweredOp]],
    Callable[[Graph | GraphContext, Node], LoweredOp],
]:
    """Return a decorator that registers ``func`` as the lowering for ``op_type``.

    A lowering already registered for ``op_type`` is overwritten. The
    decorated function is returned unchanged.

    The outer return annotation now uses ``Graph | GraphContext``, matching
    the inner decorator and ``register_lowering_if_missing``. It previously
    said ``Graph`` only, which contradicted the decorator actually returned.
    """

    def decorator(
        func: Callable[[Graph | GraphContext, Node], LoweredOp],
    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
        _LOWERING_REGISTRY[op_type] = func
        return func

    return decorator
|
|
25
27
|
|
|
26
28
|
|
|
27
|
-
def
|
|
29
|
+
def register_lowering_if_missing(
    op_type: str,
) -> Callable[[Callable[[Graph | GraphContext, Node], LoweredOp]], Callable[[Graph | GraphContext, Node], LoweredOp]]:
    """Return a decorator that registers ``func`` only if ``op_type`` is unclaimed.

    If a lowering is already registered for ``op_type``, it is kept and the
    new one is ignored. Either way the decorated function is returned
    unchanged.
    """

    def decorator(
        func: Callable[[Graph | GraphContext, Node], LoweredOp],
    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
        # setdefault leaves an existing entry untouched.
        _LOWERING_REGISTRY.setdefault(op_type, func)
        return func

    return decorator
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_lowering(
    op_type: str,
) -> Callable[[Graph | GraphContext, Node], OpBase] | None:
    """Return the lowering registered for ``op_type``, or ``None`` if absent."""
    lowering = _LOWERING_REGISTRY.get(op_type)
    return lowering
|
|
29
46
|
|
|
30
47
|
|
|
31
|
-
def get_lowering_registry() -> Mapping[
    str, Callable[[Graph | GraphContext, Node], OpBase]
]:
    """Return the op_type -> lowering mapping.

    This is the live module-level dict, not a copy. It is annotated as a
    read-only ``Mapping``.
    """
    registry = _LOWERING_REGISTRY
    return registry
|
|
33
52
|
|
|
34
53
|
|