emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0

emx_onnx_cgen/lowering/non_max_suppression.py
ADDED

@@ -0,0 +1,157 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import NonMaxSuppressionOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..lowering.common import optional_name, shape_product, value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _validate_scalar_input(
+    graph: Graph,
+    name: str,
+    node: Node,
+    *,
+    allowed_dtypes: set[ScalarType],
+    label: str,
+) -> tuple[ScalarType, tuple[int, ...]]:
+    dtype = value_dtype(graph, name, node)
+    if dtype not in allowed_dtypes:
+        allowed = ", ".join(sorted(d.onnx_name for d in allowed_dtypes))
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} must be {allowed}, got {dtype.onnx_name}"
+        )
+    shape = value_shape(graph, name, node)
+    if shape not in {(), (1,)}:
+        total = shape_product(shape)
+        if total != 1:
+            raise ShapeInferenceError(
+                f"{node.op_type} {label} must be a scalar tensor, got shape {shape}"
+            )
+    return dtype, shape
+
+
+@register_lowering("NonMaxSuppression")
+def lower_non_max_suppression(graph: Graph, node: Node) -> NonMaxSuppressionOp:
+    if node.op_type != "NonMaxSuppression":
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    if len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 1 output, got {len(node.outputs)}"
+        )
+    if len(node.inputs) < 2 or len(node.inputs) > 5:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 2 to 5 inputs, got {len(node.inputs)}"
+        )
+
+    boxes = node.inputs[0]
+    scores = node.inputs[1]
+    max_output_boxes_per_class = optional_name(node.inputs, 2)
+    iou_threshold = optional_name(node.inputs, 3)
+    score_threshold = optional_name(node.inputs, 4)
+    output = node.outputs[0]
+
+    boxes_shape = value_shape(graph, boxes, node)
+    scores_shape = value_shape(graph, scores, node)
+    if len(boxes_shape) != 3 or boxes_shape[2] != 4:
+        raise ShapeInferenceError(
+            f"{node.op_type} boxes input must have shape "
+            f"[num_batches, num_boxes, 4], got {boxes_shape}"
+        )
+    if len(scores_shape) != 3:
+        raise ShapeInferenceError(
+            f"{node.op_type} scores input must have shape "
+            f"[num_batches, num_classes, num_boxes], got {scores_shape}"
+        )
+    if boxes_shape[0] != scores_shape[0]:
+        raise ShapeInferenceError(
+            f"{node.op_type} boxes/scores batch dims must match, "
+            f"got {boxes_shape[0]} and {scores_shape[0]}"
+        )
+    if boxes_shape[1] != scores_shape[2]:
+        raise ShapeInferenceError(
+            f"{node.op_type} boxes num_boxes dim {boxes_shape[1]} "
+            f"must match scores num_boxes dim {scores_shape[2]}"
+        )
+
+    boxes_dtype = value_dtype(graph, boxes, node)
+    scores_dtype = value_dtype(graph, scores, node)
+    if boxes_dtype != scores_dtype or not boxes_dtype.is_float:
+        raise UnsupportedOpError(
+            f"{node.op_type} boxes and scores must be the same float dtype, "
+            f"got {boxes_dtype.onnx_name} and {scores_dtype.onnx_name}"
+        )
+
+    max_output_dtype = None
+    max_output_shape = None
+    if max_output_boxes_per_class is not None:
+        max_output_dtype, max_output_shape = _validate_scalar_input(
+            graph,
+            max_output_boxes_per_class,
+            node,
+            allowed_dtypes={ScalarType.I32, ScalarType.I64},
+            label="max_output_boxes_per_class input",
+        )
+
+    iou_threshold_dtype = None
+    iou_threshold_shape = None
+    if iou_threshold is not None:
+        iou_threshold_dtype, iou_threshold_shape = _validate_scalar_input(
+            graph,
+            iou_threshold,
+            node,
+            allowed_dtypes={ScalarType.F32, ScalarType.F64},
+            label="iou_threshold input",
+        )
+
+    score_threshold_dtype = None
+    score_threshold_shape = None
+    if score_threshold is not None:
+        score_threshold_dtype, score_threshold_shape = _validate_scalar_input(
+            graph,
+            score_threshold,
+            node,
+            allowed_dtypes={ScalarType.F32, ScalarType.F64},
+            label="score_threshold input",
+        )
+
+    output_shape = value_shape(graph, output, node)
+    if len(output_shape) != 2 or output_shape[1] != 3:
+        raise ShapeInferenceError(
+            f"{node.op_type} output must have shape [num_selected, 3], "
+            f"got {output_shape}"
+        )
+    output_dtype = value_dtype(graph, output, node)
+    if output_dtype != ScalarType.I64:
+        raise UnsupportedOpError(
+            f"{node.op_type} output dtype must be int64"
+        )
+
+    center_point_box = int(node.attrs.get("center_point_box", 0))
+    if center_point_box not in {0, 1}:
+        raise UnsupportedOpError(
+            f"{node.op_type} center_point_box must be 0 or 1, got {center_point_box}"
+        )
+
+    return NonMaxSuppressionOp(
+        boxes=boxes,
+        scores=scores,
+        max_output_boxes_per_class=max_output_boxes_per_class,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        output=output,
+        boxes_shape=boxes_shape,
+        scores_shape=scores_shape,
+        output_shape=output_shape,
+        center_point_box=center_point_box,
+        boxes_dtype=boxes_dtype,
+        output_dtype=output_dtype,
+        max_output_dtype=max_output_dtype,
+        max_output_shape=max_output_shape,
+        iou_threshold_dtype=iou_threshold_dtype,
+        iou_threshold_shape=iou_threshold_shape,
+        score_threshold_dtype=score_threshold_dtype,
+        score_threshold_shape=score_threshold_shape,
+    )

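For context, the checks above mirror the ONNX NonMaxSuppression signature: boxes is [num_batches, num_boxes, 4], scores is [num_batches, num_classes, num_boxes], the three optional inputs are scalar tensors, and the selected_indices output is int64 with shape [num_selected, 3]. A minimal sketch of a conforming model built with the public onnx helper API (illustrative only, not part of this package; the fixed [3, 3] output shape is an assumption, since the lowering expects a static output shape):

import numpy as np
from onnx import TensorProto, helper, numpy_helper

nms = helper.make_node(
    "NonMaxSuppression",
    ["boxes", "scores", "max_out", "iou_th", "score_th"],
    ["selected"],
    center_point_box=0,
)
graph = helper.make_graph(
    [nms],
    "nms_example",
    inputs=[
        helper.make_tensor_value_info("boxes", TensorProto.FLOAT, [1, 6, 4]),
        helper.make_tensor_value_info("scores", TensorProto.FLOAT, [1, 1, 6]),
    ],
    outputs=[helper.make_tensor_value_info("selected", TensorProto.INT64, [3, 3])],
    initializer=[
        # scalar optional inputs, matching _validate_scalar_input above
        numpy_helper.from_array(np.array(3, dtype=np.int64), name="max_out"),
        numpy_helper.from_array(np.array(0.5, dtype=np.float32), name="iou_th"),
        numpy_helper.from_array(np.array(0.0, dtype=np.float32), name="score_th"),
    ],
)
model = helper.make_model(graph)
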
emx_onnx_cgen/lowering/nonzero.py
ADDED

@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import NonZeroOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+@register_lowering("NonZero")
+def lower_nonzero(graph: Graph, node: Node) -> NonZeroOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("NonZero must have 1 input and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    if len(input_shape) == 0:
+        raise UnsupportedOpError("NonZero does not support scalar inputs")
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(output_shape) != 2:
+        raise ShapeInferenceError("NonZero output must be 2D")
+    if output_shape[0] != len(input_shape):
+        raise ShapeInferenceError(
+            "NonZero output shape must be "
+            f"({len(input_shape)}, N), got {output_shape}"
+        )
+    if output_shape[0] < 0 or output_shape[1] < 0:
+        raise ShapeInferenceError(
+            "NonZero output shape must be non-negative"
+        )
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if output_dtype != ScalarType.I64:
+        raise UnsupportedOpError("NonZero output dtype must be int64")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    return NonZeroOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )

emx_onnx_cgen/lowering/one_hot.py
ADDED

@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import OneHotOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_scalar_initializer(
+    graph: Graph, name: str, node: Node, label: str
+) -> int | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    data = np.array(initializer.data)
+    if data.size != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be a scalar"
+        )
+    return int(data.reshape(-1)[0])
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+def _normalize_onehot_axis(axis: int, rank: int, node: Node) -> int:
+    if axis < 0:
+        axis += rank + 1
+    if axis < 0 or axis > rank:
+        raise ShapeInferenceError(
+            f"{node.op_type} axis {axis} is out of range for rank {rank}"
+        )
+    return axis
+
+
+@register_lowering("OneHot")
+def lower_onehot(graph: Graph, node: Node) -> OneHotOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("OneHot must have 3 inputs and 1 output")
+    indices_name, depth_name, values_name = node.inputs
+    indices_shape = value_shape(graph, indices_name, node)
+    depth_shape = value_shape(graph, depth_name, node)
+    values_shape = value_shape(graph, values_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if not _is_scalar_shape(depth_shape):
+        raise UnsupportedOpError("OneHot depth input must be a scalar")
+    if len(values_shape) != 1 or values_shape[0] != 2:
+        raise UnsupportedOpError(
+            "OneHot values input must be a 1D tensor of size 2"
+        )
+    output_rank = len(indices_shape) + 1
+    if len(output_shape) != output_rank:
+        raise ShapeInferenceError(
+            f"OneHot output rank must be {output_rank}, got {len(output_shape)}"
+        )
+    axis = _normalize_onehot_axis(
+        int(node.attrs.get("axis", -1)), len(indices_shape), node
+    )
+    depth_value = _read_scalar_initializer(graph, depth_name, node, "depth")
+    if depth_value is not None:
+        if depth_value < 0:
+            raise ShapeInferenceError("OneHot depth must be non-negative")
+        if output_shape[axis] != depth_value:
+            raise ShapeInferenceError(
+                "OneHot output depth must be "
+                f"{depth_value}, got {output_shape[axis]}"
+            )
+        depth_dim = depth_value
+    else:
+        depth_dim = output_shape[axis]
+        if depth_dim < 0:
+            raise ShapeInferenceError("OneHot output depth must be non-negative")
+    expected_output_shape = (
+        indices_shape[:axis] + (depth_dim,) + indices_shape[axis:]
+    )
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "OneHot output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    indices_dtype = value_dtype(graph, indices_name, node)
+    depth_dtype = value_dtype(graph, depth_name, node)
+    values_dtype = value_dtype(graph, values_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if indices_dtype.is_bool:
+        raise UnsupportedOpError("OneHot indices must be numeric")
+    if depth_dtype.is_bool:
+        raise UnsupportedOpError("OneHot depth must be numeric")
+    if values_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "OneHot values dtype must match output dtype, "
+            f"got {values_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    return OneHotOp(
+        indices=indices_name,
+        depth=depth_name,
+        values=values_name,
+        output=node.outputs[0],
+        axis=axis,
+        indices_shape=indices_shape,
+        values_shape=values_shape,
+        output_shape=output_shape,
+        depth_dim=depth_dim,
+        dtype=values_dtype,
+        indices_dtype=indices_dtype,
+        depth_dtype=depth_dtype,
+    )

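The shape rule enforced above splices the depth dimension into the indices shape at the normalized axis. A quick illustration of that computation (plain Python, not package code):

indices_shape = (2, 3)
depth = 5
axis = -1
if axis < 0:
    axis += len(indices_shape) + 1          # -1 normalizes to 2 for rank-2 indices
expected = indices_shape[:axis] + (depth,) + indices_shape[axis:]
assert expected == (2, 3, 5)
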
emx_onnx_cgen/lowering/pad.py
CHANGED

@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import PadOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import optional_name, value_dtype, value_shape

emx_onnx_cgen/lowering/qlinear_matmul.py
ADDED

@@ -0,0 +1,212 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import QLinearMatMulOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class QLinearMatMulSpec:
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    batch_shape: tuple[int, ...]
+    input0_batch_shape: tuple[int, ...]
+    input1_batch_shape: tuple[int, ...]
+    m: int
+    n: int
+    k: int
+    left_vector: bool
+    right_vector: bool
+
+
+def resolve_qlinear_matmul_spec(graph: Graph, node: Node) -> QLinearMatMulSpec:
+    if len(node.inputs) != 8 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "QLinearMatMul must have 8 inputs and 1 output"
+        )
+    input0_shape = _value_shape(graph, node.inputs[0], node)
+    input1_shape = _value_shape(graph, node.inputs[3], node)
+    if len(input0_shape) < 1 or len(input1_shape) < 1:
+        raise UnsupportedOpError(
+            "QLinearMatMul inputs must be at least 1D, "
+            f"got {input0_shape} x {input1_shape}"
+        )
+    left_vector = len(input0_shape) == 1
+    right_vector = len(input1_shape) == 1
+    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
+    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
+    m, k_left = input0_effective[-2], input0_effective[-1]
+    k_right, n = input1_effective[-2], input1_effective[-1]
+    if k_left != k_right:
+        raise ShapeInferenceError(
+            "QLinearMatMul inner dimensions must match, "
+            f"got {k_left} and {k_right}"
+        )
+    batch_shape, input0_batch_shape, input1_batch_shape = (
+        _broadcast_batch_shapes(
+            input0_effective[:-2], input1_effective[:-2], node
+        )
+    )
+    if left_vector and right_vector:
+        output_shape = batch_shape
+    elif left_vector:
+        output_shape = batch_shape + (n,)
+    elif right_vector:
+        output_shape = batch_shape + (m,)
+    else:
+        output_shape = batch_shape + (m, n)
+    expected_output_shape = _value_shape(graph, node.outputs[0], node)
+    if expected_output_shape != output_shape:
+        raise ShapeInferenceError(
+            "QLinearMatMul output shape must be "
+            f"{output_shape}, got {expected_output_shape}"
+        )
+    return QLinearMatMulSpec(
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
+        output_shape=output_shape,
+        batch_shape=batch_shape,
+        input0_batch_shape=input0_batch_shape,
+        input1_batch_shape=input1_batch_shape,
+        m=m,
+        n=n,
+        k=k_left,
+        left_vector=left_vector,
+        right_vector=right_vector,
+    )
+
+
+def _broadcast_batch_shapes(
+    left: tuple[int, ...], right: tuple[int, ...], node: Node
+) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+    max_rank = max(len(left), len(right))
+    left_padded = (1,) * (max_rank - len(left)) + left
+    right_padded = (1,) * (max_rank - len(right)) + right
+    broadcast_shape = []
+    for left_dim, right_dim in zip(left_padded, right_padded):
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            raise ShapeInferenceError(
+                "QLinearMatMul batch dimensions must be broadcastable, "
+                f"got {left} x {right}"
+            )
+        broadcast_shape.append(max(left_dim, right_dim))
+    return tuple(broadcast_shape), left_padded, right_padded
+
+
+def _ensure_scalar_input(
+    graph: Graph, name: str, node: Node, label: str
+) -> tuple[int, ...]:
+    shape = _value_shape(graph, name, node)
+    if shape not in {(), (1,)}:
+        raise UnsupportedOpError(
+            f"QLinearMatMul {label} must be scalar, got shape {shape}"
+        )
+    return shape
+
+
+def _ensure_scale_dtype(dtype: ScalarType, label: str) -> None:
+    if not dtype.is_float:
+        raise UnsupportedOpError(
+            f"QLinearMatMul {label} must be float16/float/double"
+        )
+
+
+@register_lowering("QLinearMatMul")
+def lower_qlinear_matmul(graph: Graph, node: Node) -> QLinearMatMulOp:
+    spec = resolve_qlinear_matmul_spec(graph, node)
+    input0_dtype = _value_dtype(graph, node.inputs[0], node)
+    input1_dtype = _value_dtype(graph, node.inputs[3], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input0_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMatMul supports uint8/int8 inputs only"
+        )
+    if input1_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMatMul supports uint8/int8 inputs only"
+        )
+    if output_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMatMul supports uint8/int8 outputs only"
+        )
+    input0_scale_dtype = _value_dtype(graph, node.inputs[1], node)
+    input1_scale_dtype = _value_dtype(graph, node.inputs[4], node)
+    output_scale_dtype = _value_dtype(graph, node.inputs[6], node)
+    _ensure_scale_dtype(input0_scale_dtype, "a_scale")
+    _ensure_scale_dtype(input1_scale_dtype, "b_scale")
+    _ensure_scale_dtype(output_scale_dtype, "y_scale")
+    input0_zero_dtype = _value_dtype(graph, node.inputs[2], node)
+    input1_zero_dtype = _value_dtype(graph, node.inputs[5], node)
+    output_zero_dtype = _value_dtype(graph, node.inputs[7], node)
+    if input0_zero_dtype != input0_dtype:
+        raise UnsupportedOpError(
+            "QLinearMatMul a_zero_point dtype must match a"
+        )
+    if input1_zero_dtype != input1_dtype:
+        raise UnsupportedOpError(
+            "QLinearMatMul b_zero_point dtype must match b"
+        )
+    if output_zero_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "QLinearMatMul y_zero_point dtype must match y"
+        )
+    input0_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[1], node, "a_scale"
+    )
+    input1_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[4], node, "b_scale"
+    )
+    output_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[6], node, "y_scale"
+    )
+    input0_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[2], node, "a_zero_point"
+    )
+    input1_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[5], node, "b_zero_point"
+    )
+    output_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[7], node, "y_zero_point"
+    )
+    return QLinearMatMulOp(
+        input0=node.inputs[0],
+        input0_scale=node.inputs[1],
+        input0_zero_point=node.inputs[2],
+        input1=node.inputs[3],
+        input1_scale=node.inputs[4],
+        input1_zero_point=node.inputs[5],
+        output_scale=node.inputs[6],
+        output_zero_point=node.inputs[7],
+        output=node.outputs[0],
+        input0_shape=spec.input0_shape,
+        input1_shape=spec.input1_shape,
+        output_shape=spec.output_shape,
+        batch_shape=spec.batch_shape,
+        input0_batch_shape=spec.input0_batch_shape,
+        input1_batch_shape=spec.input1_batch_shape,
+        m=spec.m,
+        n=spec.n,
+        k=spec.k,
+        left_vector=spec.left_vector,
+        right_vector=spec.right_vector,
+        input0_dtype=input0_dtype,
+        input1_dtype=input1_dtype,
+        dtype=output_dtype,
+        input0_scale_dtype=input0_scale_dtype,
+        input1_scale_dtype=input1_scale_dtype,
+        output_scale_dtype=output_scale_dtype,
+        input0_scale_shape=input0_scale_shape,
+        input1_scale_shape=input1_scale_shape,
+        output_scale_shape=output_scale_shape,
+        input0_zero_shape=input0_zero_shape,
+        input1_zero_shape=input1_zero_shape,
+        output_zero_shape=output_zero_shape,
+    )

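The eight inputs are indexed in the ONNX QLinearMatMul order (a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point), and batch dimensions broadcast numpy-style via _broadcast_batch_shapes. A small worked example of the shape resolution (plain Python, not package code):

# a: (2, 3, 4) -> batch (2,), m=3, k=4; b: (4, 5) -> batch (), k=4, n=5
a_batch, b_batch = (2,), ()
rank = max(len(a_batch), len(b_batch))
a_pad = (1,) * (rank - len(a_batch)) + a_batch
b_pad = (1,) * (rank - len(b_batch)) + b_batch
batch = tuple(max(x, y) for x, y in zip(a_pad, b_pad))  # dims must be equal or 1
assert batch == (2,)  # so the output shape is batch + (m, n) == (2, 3, 5)
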
emx_onnx_cgen/lowering/quantize_linear.py
ADDED

@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..dtypes import scalar_type_from_onnx
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import normalize_axis
+from .common import optional_name, value_dtype as _value_dtype, value_shape as _value_shape
+from .registry import register_lowering
+from ..ir.ops import QuantizeLinearOp
+
+
+@dataclass(frozen=True)
+class QuantizeSpec:
+    input_shape: tuple[int, ...]
+    scale_shape: tuple[int, ...]
+    axis: int | None
+    output_dtype: ScalarType
+
+
+def _resolve_output_dtype(
+    graph: Graph, node: Node, zero_point_name: str | None
+) -> ScalarType:
+    output_attr = int(node.attrs.get("output_dtype", 0))
+    if output_attr:
+        output_dtype = _value_dtype(graph, node.outputs[0], node)
+        attr_dtype = scalar_type_from_onnx(output_attr)
+        if attr_dtype is None:
+            raise UnsupportedOpError(
+                "QuantizeLinear output_dtype must map to a supported scalar type"
+            )
+        if output_dtype != attr_dtype:
+            raise UnsupportedOpError(
+                "QuantizeLinear output_dtype must match output tensor dtype"
+            )
+        return output_dtype
+    if zero_point_name is None:
+        return ScalarType.U8
+    return _value_dtype(graph, zero_point_name, node)
+
+
+def resolve_quantize_spec(graph: Graph, node: Node) -> QuantizeSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "QuantizeLinear must have 2 or 3 inputs and 1 output"
+        )
+    supported_attrs = {"axis", "block_size", "output_dtype", "precision", "saturate"}
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("QuantizeLinear has unsupported attributes")
+    block_size = int(node.attrs.get("block_size", 0))
+    if block_size != 0:
+        raise UnsupportedOpError("QuantizeLinear block_size is not supported")
+    precision = int(node.attrs.get("precision", 0))
+    if precision != 0:
+        raise UnsupportedOpError("QuantizeLinear precision is not supported")
+    saturate = int(node.attrs.get("saturate", 1))
+    if saturate != 1:
+        raise UnsupportedOpError("QuantizeLinear saturate must be 1")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    scale_shape = _value_shape(graph, node.inputs[1], node)
+    zero_point_name = optional_name(node.inputs, 2)
+    output_dtype = _resolve_output_dtype(graph, node, zero_point_name)
+    if output_dtype not in {
+        ScalarType.U8,
+        ScalarType.I8,
+        ScalarType.U16,
+        ScalarType.I16,
+    }:
+        raise UnsupportedOpError(
+            "QuantizeLinear supports int8/uint8/int16/uint16 outputs only"
+        )
+    if zero_point_name is not None:
+        zero_point_dtype = _value_dtype(graph, zero_point_name, node)
+        if zero_point_dtype != output_dtype:
+            raise UnsupportedOpError(
+                "QuantizeLinear zero_point dtype must match output dtype"
+            )
+        zero_point_shape = _value_shape(graph, zero_point_name, node)
+        if zero_point_shape != scale_shape:
+            raise ShapeInferenceError(
+                "QuantizeLinear zero_point shape must match scale shape"
+            )
+    if scale_shape not in {(), (1,)}:
+        if len(scale_shape) != 1:
+            raise UnsupportedOpError(
+                "QuantizeLinear supports per-tensor and per-axis scales only"
+            )
+        axis = int(node.attrs.get("axis", 1))
+        axis = normalize_axis(axis, input_shape, node)
+        if scale_shape[0] != input_shape[axis]:
+            raise ShapeInferenceError(
+                "QuantizeLinear scale length must match input axis size"
+            )
+    else:
+        axis = None
+    return QuantizeSpec(
+        input_shape=input_shape,
+        scale_shape=scale_shape,
+        axis=axis,
+        output_dtype=output_dtype,
+    )
+
+
+@register_lowering("QuantizeLinear")
+def lower_quantize_linear(graph: Graph, node: Node) -> QuantizeLinearOp:
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    scale_dtype = _value_dtype(graph, node.inputs[1], node)
+    if not op_dtype.is_float or not scale_dtype.is_float:
+        raise UnsupportedOpError(
+            "QuantizeLinear supports float16/float/double inputs only"
+        )
+    spec = resolve_quantize_spec(graph, node)
+    return QuantizeLinearOp(
+        input0=node.inputs[0],
+        scale=node.inputs[1],
+        zero_point=optional_name(node.inputs, 2),
+        output=node.outputs[0],
+        input_shape=spec.input_shape,
+        axis=spec.axis,
+        dtype=spec.output_dtype,
+        input_dtype=op_dtype,
+        scale_dtype=scale_dtype,
+    )

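As a concrete per-axis case accepted by the spec above: a float input of shape (1, 4, 8) quantized along axis 1 with a length-4 scale and a matching uint8 zero point. A minimal sketch using the public onnx helper API (illustrative names, not part of this package):

import numpy as np
from onnx import TensorProto, helper, numpy_helper

q = helper.make_node("QuantizeLinear", ["x", "y_scale", "y_zero_point"], ["y"], axis=1)
graph = helper.make_graph(
    [q],
    "quantize_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4, 8])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.UINT8, [1, 4, 8])],
    initializer=[
        # one scale and zero point per channel on axis 1
        numpy_helper.from_array(np.full((4,), 0.1, dtype=np.float32), name="y_scale"),
        numpy_helper.from_array(np.zeros((4,), dtype=np.uint8), name="y_zero_point"),
    ],
)
model = helper.make_model(graph)
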
emx_onnx_cgen/lowering/range.py
CHANGED

@@ -6,7 +6,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import RangeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import node_dtype, value_shape