emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +340 -59
- emx_onnx_cgen/codegen/c_emitter.py +2369 -111
- emx_onnx_cgen/compiler.py +188 -5
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/lowering/common.py +379 -2
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/gather_elements.py +1 -3
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +6 -5
- emx_onnx_cgen/lowering/logsoftmax.py +5 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/matmul.py +6 -7
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/reduce.py +5 -6
- emx_onnx_cgen/lowering/reshape.py +223 -51
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/softmax.py +5 -1
- emx_onnx_cgen/lowering/squeeze.py +5 -5
- emx_onnx_cgen/lowering/topk.py +116 -0
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +5 -5
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +460 -42
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +61 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
- shared/scalar_functions.py +49 -17
- shared/ulp.py +48 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/one_hot.py
ADDED

@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import OneHotOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_scalar_initializer(
+    graph: Graph, name: str, node: Node, label: str
+) -> int | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    data = np.array(initializer.data)
+    if data.size != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be a scalar"
+        )
+    return int(data.reshape(-1)[0])
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+def _normalize_onehot_axis(axis: int, rank: int, node: Node) -> int:
+    if axis < 0:
+        axis += rank + 1
+    if axis < 0 or axis > rank:
+        raise ShapeInferenceError(
+            f"{node.op_type} axis {axis} is out of range for rank {rank}"
+        )
+    return axis
+
+
+@register_lowering("OneHot")
+def lower_onehot(graph: Graph, node: Node) -> OneHotOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("OneHot must have 3 inputs and 1 output")
+    indices_name, depth_name, values_name = node.inputs
+    indices_shape = value_shape(graph, indices_name, node)
+    depth_shape = value_shape(graph, depth_name, node)
+    values_shape = value_shape(graph, values_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if not _is_scalar_shape(depth_shape):
+        raise UnsupportedOpError("OneHot depth input must be a scalar")
+    if len(values_shape) != 1 or values_shape[0] != 2:
+        raise UnsupportedOpError(
+            "OneHot values input must be a 1D tensor of size 2"
+        )
+    output_rank = len(indices_shape) + 1
+    if len(output_shape) != output_rank:
+        raise ShapeInferenceError(
+            f"OneHot output rank must be {output_rank}, got {len(output_shape)}"
+        )
+    axis = _normalize_onehot_axis(
+        int(node.attrs.get("axis", -1)), len(indices_shape), node
+    )
+    depth_value = _read_scalar_initializer(graph, depth_name, node, "depth")
+    if depth_value is not None:
+        if depth_value < 0:
+            raise ShapeInferenceError("OneHot depth must be non-negative")
+        if output_shape[axis] != depth_value:
+            raise ShapeInferenceError(
+                "OneHot output depth must be "
+                f"{depth_value}, got {output_shape[axis]}"
+            )
+        depth_dim = depth_value
+    else:
+        depth_dim = output_shape[axis]
+        if depth_dim < 0:
+            raise ShapeInferenceError("OneHot output depth must be non-negative")
+    expected_output_shape = (
+        indices_shape[:axis] + (depth_dim,) + indices_shape[axis:]
+    )
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "OneHot output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    indices_dtype = value_dtype(graph, indices_name, node)
+    depth_dtype = value_dtype(graph, depth_name, node)
+    values_dtype = value_dtype(graph, values_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if indices_dtype.is_bool:
+        raise UnsupportedOpError("OneHot indices must be numeric")
+    if depth_dtype.is_bool:
+        raise UnsupportedOpError("OneHot depth must be numeric")
+    if values_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "OneHot values dtype must match output dtype, "
+            f"got {values_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    return OneHotOp(
+        indices=indices_name,
+        depth=depth_name,
+        values=values_name,
+        output=node.outputs[0],
+        axis=axis,
+        indices_shape=indices_shape,
+        values_shape=values_shape,
+        output_shape=output_shape,
+        depth_dim=depth_dim,
+        dtype=values_dtype,
+        indices_dtype=indices_dtype,
+        depth_dtype=depth_dtype,
+    )
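For context on what this new lowering validates: ONNX OneHot inserts a depth-sized axis into the indices shape at `axis` and fills it from the two-element values tensor [off_value, on_value], so the output shape must equal indices_shape[:axis] + (depth,) + indices_shape[axis:]. A minimal NumPy sketch of those semantics (the function name is illustrative, not part of the package):

import numpy as np

def one_hot_reference(indices, depth, values, axis=-1):
    # Output rank is indices rank + 1; a negative axis counts from the end.
    rank = indices.ndim
    if axis < 0:
        axis += rank + 1
    off_value, on_value = values
    # Negative indices wrap around by depth, per the ONNX spec.
    idx = np.where(indices < 0, indices + depth, indices)
    out = np.full(indices.shape[:axis] + (depth,) + indices.shape[axis:],
                  off_value, dtype=values.dtype)
    # Compare each position along the new axis against the index value.
    positions = np.arange(depth).reshape(
        (1,) * axis + (depth,) + (1,) * (rank - axis)
    )
    out[positions == np.expand_dims(idx, axis)] = on_value
    return out

print(one_hot_reference(np.array([0, 2]), 3, np.array([0.0, 1.0])))
# [[1. 0. 0.]
#  [0. 0. 1.]]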
emx_onnx_cgen/lowering/quantize_linear.py
ADDED

@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..dtypes import scalar_type_from_onnx
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import normalize_axis
+from .common import optional_name, value_dtype as _value_dtype, value_shape as _value_shape
+from .registry import register_lowering
+from ..codegen.c_emitter import QuantizeLinearOp
+
+
+@dataclass(frozen=True)
+class QuantizeSpec:
+    input_shape: tuple[int, ...]
+    scale_shape: tuple[int, ...]
+    axis: int | None
+    output_dtype: ScalarType
+
+
+def _resolve_output_dtype(
+    graph: Graph, node: Node, zero_point_name: str | None
+) -> ScalarType:
+    output_attr = int(node.attrs.get("output_dtype", 0))
+    if output_attr:
+        output_dtype = _value_dtype(graph, node.outputs[0], node)
+        attr_dtype = scalar_type_from_onnx(output_attr)
+        if attr_dtype is None:
+            raise UnsupportedOpError(
+                "QuantizeLinear output_dtype must map to a supported scalar type"
+            )
+        if output_dtype != attr_dtype:
+            raise UnsupportedOpError(
+                "QuantizeLinear output_dtype must match output tensor dtype"
+            )
+        return output_dtype
+    if zero_point_name is None:
+        return ScalarType.U8
+    return _value_dtype(graph, zero_point_name, node)
+
+
+def resolve_quantize_spec(graph: Graph, node: Node) -> QuantizeSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "QuantizeLinear must have 2 or 3 inputs and 1 output"
+        )
+    supported_attrs = {"axis", "block_size", "output_dtype", "precision", "saturate"}
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("QuantizeLinear has unsupported attributes")
+    block_size = int(node.attrs.get("block_size", 0))
+    if block_size != 0:
+        raise UnsupportedOpError("QuantizeLinear block_size is not supported")
+    precision = int(node.attrs.get("precision", 0))
+    if precision != 0:
+        raise UnsupportedOpError("QuantizeLinear precision is not supported")
+    saturate = int(node.attrs.get("saturate", 1))
+    if saturate != 1:
+        raise UnsupportedOpError("QuantizeLinear saturate must be 1")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    scale_shape = _value_shape(graph, node.inputs[1], node)
+    zero_point_name = optional_name(node.inputs, 2)
+    output_dtype = _resolve_output_dtype(graph, node, zero_point_name)
+    if output_dtype not in {
+        ScalarType.U8,
+        ScalarType.I8,
+        ScalarType.U16,
+        ScalarType.I16,
+    }:
+        raise UnsupportedOpError(
+            "QuantizeLinear supports int8/uint8/int16/uint16 outputs only"
+        )
+    if zero_point_name is not None:
+        zero_point_dtype = _value_dtype(graph, zero_point_name, node)
+        if zero_point_dtype != output_dtype:
+            raise UnsupportedOpError(
+                "QuantizeLinear zero_point dtype must match output dtype"
+            )
+        zero_point_shape = _value_shape(graph, zero_point_name, node)
+        if zero_point_shape != scale_shape:
+            raise ShapeInferenceError(
+                "QuantizeLinear zero_point shape must match scale shape"
+            )
+    if scale_shape not in {(), (1,)}:
+        if len(scale_shape) != 1:
+            raise UnsupportedOpError(
+                "QuantizeLinear supports per-tensor and per-axis scales only"
+            )
+        axis = int(node.attrs.get("axis", 1))
+        axis = normalize_axis(axis, input_shape, node)
+        if scale_shape[0] != input_shape[axis]:
+            raise ShapeInferenceError(
+                "QuantizeLinear scale length must match input axis size"
+            )
+    else:
+        axis = None
+    return QuantizeSpec(
+        input_shape=input_shape,
+        scale_shape=scale_shape,
+        axis=axis,
+        output_dtype=output_dtype,
+    )
+
+
+@register_lowering("QuantizeLinear")
+def lower_quantize_linear(graph: Graph, node: Node) -> QuantizeLinearOp:
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    scale_dtype = _value_dtype(graph, node.inputs[1], node)
+    if not op_dtype.is_float or not scale_dtype.is_float:
+        raise UnsupportedOpError(
+            "QuantizeLinear supports float16/float/double inputs only"
+        )
+    spec = resolve_quantize_spec(graph, node)
+    return QuantizeLinearOp(
+        input0=node.inputs[0],
+        scale=node.inputs[1],
+        zero_point=optional_name(node.inputs, 2),
+        output=node.outputs[0],
+        input_shape=spec.input_shape,
+        axis=spec.axis,
+        dtype=spec.output_dtype,
+        input_dtype=op_dtype,
+        scale_dtype=scale_dtype,
+    )
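The spec resolution above mirrors the ONNX QuantizeLinear definition, y = saturate(round(x / scale) + zero_point) with round-half-to-even, restricted here to per-tensor or per-axis scales and 8/16-bit integer outputs. A rough NumPy reference for that formula (the helper name and defaults are illustrative, not part of the package):

import numpy as np

def quantize_linear_reference(x, scale, zero_point, axis=None, out_dtype=np.uint8):
    if axis is not None:
        # Broadcast a 1-D per-axis scale/zero_point along `axis`.
        shape = [1] * x.ndim
        shape[axis] = -1
        scale = scale.reshape(shape)
        zero_point = zero_point.reshape(shape)
    info = np.iinfo(out_dtype)
    q = np.rint(x / scale) + zero_point  # np.rint rounds half to even
    return np.clip(q, info.min, info.max).astype(out_dtype)

x = np.array([[-1.0, 0.0, 254.0]])
print(quantize_linear_reference(x, np.array(2.0), np.array(1)))
# [[  1   1 128]]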
emx_onnx_cgen/lowering/reduce.py
CHANGED
@@ -261,13 +261,12 @@ def _infer_axes_from_shapes(
             if out_dim == in_dim:
                 if in_dim == 1:
                     return None
-
-            if out_dim == 1 and in_dim != 1:
+            elif out_dim == 1 and in_dim != 1:
                 axes.append(axis)
-
-
-
-
+            else:
+                raise ShapeInferenceError(
+                    f"{node.op_type} output shape does not match input shape"
+                )
         return tuple(axes)
     if len(output_shape) > len(input_shape):
         return None
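The change above makes `_infer_axes_from_shapes` strict: each dimension must either be kept (out_dim == in_dim) or reduced to 1, and any other mismatch now raises instead of being silently skipped. Roughly, as a standalone sketch of the keepdims-style loop (a hypothetical helper, not the package function):

def infer_reduced_axes(input_shape, output_shape):
    axes = []
    for axis, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)):
        if out_dim == in_dim:
            if in_dim == 1:
                return None  # ambiguous: a kept size-1 dim may or may not be reduced
        elif out_dim == 1 and in_dim != 1:
            axes.append(axis)  # this axis was reduced
        else:
            raise ValueError("output shape does not match input shape")
    return tuple(axes)

print(infer_reduced_axes((2, 3, 4), (2, 1, 4)))  # (1,)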
emx_onnx_cgen/lowering/reshape.py
CHANGED

@@ -5,6 +5,7 @@ from shared.scalar_types import ScalarType
 from ..codegen.c_emitter import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
+from .common import value_shape as resolved_value_shape
 from .registry import register_lowering
 
 
@@ -37,6 +38,21 @@ def _shape_product(shape: tuple[int, ...]) -> int:
     return product
 
 
+def _reshape_mismatch_error(
+    node: Node,
+    input_shape: tuple[int, ...],
+    output_shape: tuple[int, ...],
+) -> ShapeInferenceError:
+    node_name = node.name or "<unnamed>"
+    return ShapeInferenceError(
+        "Reshape input/output element counts must match for op "
+        f"{node.op_type} (node '{node_name}'): input shape {input_shape}, "
+        f"output shape {output_shape}. "
+        "Hint: ensure the reshape target has the same number of elements as "
+        "the input."
+    )
+
+
 def _find_initializer(graph: Graph, name: str) -> Initializer | None:
     for initializer in graph.initializers:
         if initializer.name == name:
@@ -52,15 +68,190 @@ def _find_node_by_output(graph: Graph, name: str) -> Node | None:
 
 
 def _shape_values_from_shape_node(
-    graph: Graph,
-) -> list[int]:
-    shape_node = _find_node_by_output(graph, name)
-    if shape_node is None or shape_node.op_type != "Shape":
-        return None
+    graph: Graph, shape_node: Node, node: Node
+) -> list[int]:
     if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
     source_shape = _value_shape(graph, shape_node.inputs[0], node)
-
+    start = int(shape_node.attrs.get("start", 0))
+    end = int(shape_node.attrs.get("end", len(source_shape)))
+    if start < 0:
+        start += len(source_shape)
+    if end < 0:
+        end += len(source_shape)
+    start = max(start, 0)
+    end = min(end, len(source_shape))
+    if start > end:
+        return []
+    return list(source_shape[start:end])
+
+
+def _shape_values_from_initializer(
+    graph: Graph,
+    name: str,
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "Reshape expects int64 or int32 shape input, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    return [int(value) for value in initializer.data.reshape(-1)]
+
+
+def _shape_values_from_input(
+    graph: Graph,
+    name: str,
+    node: Node,
+    *,
+    _visited: set[str] | None = None,
+) -> list[int] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        shape_values = _shape_values_from_initializer(graph, name)
+        if shape_values is not None:
+            return shape_values
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Shape":
+            return _shape_values_from_shape_node(graph, source_node, node)
+        if source_node.op_type == "Concat":
+            axis = int(source_node.attrs.get("axis", 0))
+            if axis != 0:
+                raise UnsupportedOpError("Reshape shape concat must use axis 0")
+            values: list[int] = []
+            for input_name in source_node.inputs:
+                input_values = _shape_values_from_input(
+                    graph,
+                    input_name,
+                    node,
+                    _visited=_visited,
+                )
+                if input_values is None:
+                    return None
+                values.extend(input_values)
+            return values
+        if source_node.op_type == "Cast":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Cast must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Unsqueeze":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Identity":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Identity must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            if len(left) == 1 and len(right) != 1:
+                left = left * len(right)
+            if len(right) == 1 and len(left) != 1:
+                right = right * len(left)
+            if len(left) != len(right):
+                return None
+            if source_node.op_type == "Equal":
+                return [1 if l == r else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "And":
+                return [1 if (l and r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Or":
+                return [1 if (l or r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Div":
+                return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Mod":
+                return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+        if source_node.op_type == "Not":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Not must have 1 input and 1 output")
+            values = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if values is None:
+                return None
+            return [0 if value else 1 for value in values]
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            condition = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if condition is None:
+                return None
+            on_true = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _shape_values_from_input(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            if len(condition) == 1:
+                condition = condition * max(len(on_true), len(on_false))
+            if len(on_true) == 1 and len(condition) != 1:
+                on_true = on_true * len(condition)
+            if len(on_false) == 1 and len(condition) != 1:
+                on_false = on_false * len(condition)
+            if not (len(condition) == len(on_true) == len(on_false)):
+                return None
+            return [
+                t if cond else f
+                for cond, t, f in zip(condition, on_true, on_false)
+            ]
+        return None
+    finally:
+        _visited.remove(name)
 
 
 def _resolve_target_shape(
@@ -82,19 +273,19 @@ def _resolve_target_shape(
                 raise ShapeInferenceError("Reshape allows only one -1 dimension")
             unknown_index = index
             output_dims.append(-1)
-
-
-
-
-
-
-
-
-
-
-
-
-
+        else:
+            if dim == 0:
+                contains_zero = True
+                if allowzero == 0:
+                    if index >= len(input_shape):
+                        raise ShapeInferenceError(
+                            "Reshape zero dim must index into input shape"
+                        )
+                    dim = input_shape[index]
+            if dim < 0:
+                raise ShapeInferenceError("Reshape dims must be >= -1")
+            output_dims.append(dim)
+            known_product *= dim
     if allowzero == 1 and contains_zero and unknown_index is not None:
         raise ShapeInferenceError(
             "Reshape allowzero cannot combine zero and -1 dimensions"
@@ -115,9 +306,7 @@
         output_dims[unknown_index] = input_product // known_product
     output_shape = tuple(output_dims)
     if _shape_product(output_shape) != input_product:
-        raise ShapeInferenceError(
-            "Reshape input and output element counts must match"
-        )
+        raise _reshape_mismatch_error(node, input_shape, output_shape)
     return output_shape
 
 
@@ -125,7 +314,7 @@
 def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
     if len(node.inputs) != 2 or len(node.outputs) != 1:
         raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
-    input_shape = _value_shape(graph, node.inputs[0], node)
+    input_shape = resolved_value_shape(graph, node.inputs[0], node)
     input_dtype = _value_dtype(graph, node.inputs[0], node)
     output_dtype = _value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:
@@ -133,46 +322,29 @@ def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
             "Reshape expects matching input/output dtypes, "
             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
         )
-
+    output_value = graph.find_value(node.outputs[0])
+    output_shape = resolved_value_shape(graph, node.outputs[0], node)
+    output_dim_params = output_value.type.dim_params
     allowzero = int(node.attrs.get("allowzero", 0))
-    shape_initializer = _find_initializer(graph, node.inputs[1])
     resolved_shape: tuple[int, ...] | None = None
-
-
-            graph, node.inputs[1], node
-        )
-        if shape_values is not None:
-            resolved_shape = _resolve_target_shape(
-                input_shape,
-                shape_values,
-                allowzero=allowzero,
-                node=node,
-            )
-        else:
-            if _shape_product(output_shape) != _shape_product(input_shape):
-                raise ShapeInferenceError(
-                    "Reshape input and output element counts must match"
-                )
-    else:
-        if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
-            raise UnsupportedOpError(
-                "Reshape expects int64 or int32 shape input, "
-                f"got {shape_initializer.type.dtype.onnx_name}"
-            )
-        if len(shape_initializer.type.shape) != 1:
-            raise UnsupportedOpError("Reshape expects a 1D shape input")
-        shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
+    shape_values = _shape_values_from_input(graph, node.inputs[1], node)
+    if shape_values is not None:
         resolved_shape = _resolve_target_shape(
             input_shape,
            shape_values,
            allowzero=allowzero,
            node=node,
        )
-        if output_shape and resolved_shape != output_shape:
+        if output_shape and resolved_shape != output_shape and not any(
+            output_dim_params
+        ):
             raise ShapeInferenceError(
                 "Reshape output shape must be "
                 f"{resolved_shape}, got {output_shape}"
             )
+    else:
+        if _shape_product(output_shape) != _shape_product(input_shape):
+            raise _reshape_mismatch_error(node, input_shape, output_shape)
     if resolved_shape is not None:
         output_shape = resolved_shape
     for dim in output_shape:
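Taken together, the reshape changes let the lowering fold shape values computed by upstream Shape/Concat/Cast/Where-style subgraphs, then resolve them with the standard ONNX Reshape rules: a 0 copies the corresponding input dimension (unless allowzero=1) and a single -1 is inferred from the remaining element count. A condensed sketch of that resolution step (a hypothetical helper, omitting the allowzero and mismatch checks shown above):

def resolve_target_shape(input_shape, target, allowzero=0):
    dims, unknown, known = [], None, 1
    for i, d in enumerate(target):
        if d == -1:
            if unknown is not None:
                raise ValueError("only one -1 dimension is allowed")
            unknown = i
            dims.append(-1)
            continue
        if d == 0 and allowzero == 0:
            d = input_shape[i]  # 0 means "copy the input dimension"
        dims.append(d)
        known *= d
    total = 1
    for d in input_shape:
        total *= d
    if unknown is not None:
        dims[unknown] = total // known  # infer the -1 dimension
    return tuple(dims)

print(resolve_target_shape((2, 3, 4), (0, -1)))  # (2, 12)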
emx_onnx_cgen/lowering/scatter_nd.py
ADDED

@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ScatterNDOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+_ALLOWED_REDUCTIONS = {"none", "add", "mul", "min", "max"}
+
+
+@register_lowering("ScatterND")
+def lower_scatternd(graph: Graph, node: Node) -> ScatterNDOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("ScatterND must have 3 inputs and 1 output")
+    data_name, indices_name, updates_name = node.inputs
+    output_name = node.outputs[0]
+    data_shape = value_shape(graph, data_name, node)
+    indices_shape = value_shape(graph, indices_name, node)
+    updates_shape = value_shape(graph, updates_name, node)
+    output_shape = value_shape(graph, output_name, node)
+    if output_shape != data_shape:
+        raise ShapeInferenceError(
+            "ScatterND output shape must match data shape, "
+            f"got {output_shape} vs {data_shape}"
+        )
+    if len(indices_shape) < 1:
+        raise ShapeInferenceError("ScatterND indices must have rank >= 1")
+    index_depth = indices_shape[-1]
+    if index_depth <= 0:
+        raise ShapeInferenceError(
+            "ScatterND indices final dimension must be >= 1"
+        )
+    if index_depth > len(data_shape):
+        raise ShapeInferenceError(
+            "ScatterND indices final dimension must be <= data rank, "
+            f"got {index_depth} vs {len(data_shape)}"
+        )
+    expected_updates_shape = indices_shape[:-1] + data_shape[index_depth:]
+    if updates_shape != expected_updates_shape:
+        raise ShapeInferenceError(
+            "ScatterND updates shape must be "
+            f"{expected_updates_shape}, got {updates_shape}"
+        )
+    data_dtype = value_dtype(graph, data_name, node)
+    updates_dtype = value_dtype(graph, updates_name, node)
+    if updates_dtype != data_dtype:
+        raise UnsupportedOpError(
+            "ScatterND updates dtype must match data dtype, "
+            f"got {updates_dtype.onnx_name} vs {data_dtype.onnx_name}"
+        )
+    indices_dtype = value_dtype(graph, indices_name, node)
+    if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "ScatterND indices must be int32 or int64, "
+            f"got {indices_dtype.onnx_name}"
+        )
+    reduction_attr = node.attrs.get("reduction", "none")
+    if isinstance(reduction_attr, bytes):
+        reduction = reduction_attr.decode()
+    else:
+        reduction = str(reduction_attr)
+    if reduction not in _ALLOWED_REDUCTIONS:
+        raise UnsupportedOpError(
+            "ScatterND reduction must be one of "
+            f"{sorted(_ALLOWED_REDUCTIONS)}, got {reduction}"
+        )
+    return ScatterNDOp(
+        data=data_name,
+        indices=indices_name,
+        updates=updates_name,
+        output=output_name,
+        data_shape=data_shape,
+        indices_shape=indices_shape,
+        updates_shape=updates_shape,
+        output_shape=output_shape,
+        reduction=reduction,
+        dtype=data_dtype,
+        indices_dtype=indices_dtype,
+    )
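For reference, ScatterND copies data and then writes one update slice per index tuple; the trailing indices dimension says how many leading data axes each tuple addresses, which is exactly what the updates-shape check above enforces. A minimal NumPy sketch with reduction="none" (the function name is illustrative, not part of the package):

import numpy as np

def scatter_nd_reference(data, indices, updates):
    out = np.copy(data)
    k = indices.shape[-1]  # leading axes of data addressed by each index tuple
    assert updates.shape == indices.shape[:-1] + data.shape[k:]
    for idx in np.ndindex(indices.shape[:-1]):
        out[tuple(indices[idx])] = updates[idx]
    return out

data = np.zeros((4, 3))
indices = np.array([[1], [3]])
updates = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(scatter_nd_reference(data, indices, updates))
# rows 1 and 3 are replaced; rows 0 and 2 stay zero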