emx-onnx-cgen 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
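Every lowering hunk below follows the same pattern: the module decorates a lower_* function with register_lowering("<OpType>"), validates the node's arity, shapes, and dtypes against the graph, and returns an op dataclass consumed by the C emitter. The sketch below only illustrates that pattern for a hypothetical new module placed under emx_onnx_cgen/lowering/; the "Example" op type and the locally defined ExampleOp dataclass are placeholders and not part of the package, while register_lowering, Graph, Node, value_shape, value_dtype, UnsupportedOpError, and ScalarType are the helpers the real modules import.

from __future__ import annotations

from dataclasses import dataclass

from shared.scalar_types import ScalarType

from ..errors import UnsupportedOpError
from ..ir.model import Graph, Node
from .common import value_dtype, value_shape
from .registry import register_lowering


@dataclass(frozen=True)
class ExampleOp:
    # Placeholder for a codegen op dataclass; the real ones live in codegen/c_emitter.py.
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    dtype: ScalarType
    input_dtype: ScalarType


@register_lowering("Example")  # dispatch key is the ONNX op type
def lower_example(graph: Graph, node: Node) -> ExampleOp:
    # Validate arity, then read static shape/dtype metadata from the graph.
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Example must have 1 input and 1 output")
    shape = value_shape(graph, node.inputs[0], node)
    dtype = value_dtype(graph, node.inputs[0], node)
    return ExampleOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        input_shape=shape,
        output_shape=shape,
        dtype=dtype,
        input_dtype=dtype,
    )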
emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import SoftmaxCrossEntropyLossOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import shape_product as _shape_product
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("SoftmaxCrossEntropyLoss")
+def lower_softmax_cross_entropy_loss(
+    graph: Graph, node: Node
+) -> SoftmaxCrossEntropyLossOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) not in {1, 2}:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss must have 2 or 3 inputs and 1 or 2 outputs"
+        )
+    input_name = node.inputs[0]
+    target_name = node.inputs[1]
+    weight_name = node.inputs[2] if len(node.inputs) > 2 else None
+    input_dtype = _value_dtype(graph, input_name, node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss supports float16, float, and double inputs only"
+        )
+    output_name = node.outputs[0]
+    output_dtype = _value_dtype(graph, output_name, node)
+    if output_dtype != input_dtype:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss output dtype must match input dtype"
+        )
+    log_prob_name = node.outputs[1] if len(node.outputs) > 1 else None
+    if log_prob_name is not None:
+        log_prob_dtype = _value_dtype(graph, log_prob_name, node)
+        if log_prob_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "SoftmaxCrossEntropyLoss log_prob dtype must match input dtype"
+            )
+    target_dtype = _value_dtype(graph, target_name, node)
+    if target_dtype not in {ScalarType.I32, ScalarType.I64}:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss target must be int32 or int64"
+        )
+    weight_dtype = None
+    weight_shape: tuple[int, ...] | None = None
+    if weight_name is not None:
+        weight_dtype = _value_dtype(graph, weight_name, node)
+        if weight_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "SoftmaxCrossEntropyLoss weight dtype must match input dtype"
+            )
+    input_shape = _value_shape(graph, input_name, node)
+    target_shape = _value_shape(graph, target_name, node)
+    output_shape = _value_shape(graph, output_name, node)
+    if len(input_shape) < 2:
+        raise ShapeInferenceError("SoftmaxCrossEntropyLoss input must be at least 2D")
+    if len(target_shape) != len(input_shape) - 1:
+        raise ShapeInferenceError(
+            "SoftmaxCrossEntropyLoss target rank must be input rank - 1"
+        )
+    if input_shape[0] != target_shape[0]:
+        raise ShapeInferenceError(
+            "SoftmaxCrossEntropyLoss target batch dimension must match input"
+        )
+    if input_shape[2:] != target_shape[1:]:
+        raise ShapeInferenceError(
+            "SoftmaxCrossEntropyLoss target spatial dimensions must match input"
+        )
+    if weight_name is not None:
+        weight_shape = _value_shape(graph, weight_name, node)
+        if len(weight_shape) != 1 or weight_shape[0] != input_shape[1]:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss weight must have shape (C,)"
+            )
+    if log_prob_name is not None:
+        log_prob_shape = _value_shape(graph, log_prob_name, node)
+        if log_prob_shape != input_shape:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss log_prob output must match input shape"
+            )
+    reduction = node.attrs.get("reduction", "mean")
+    if isinstance(reduction, bytes):
+        reduction = reduction.decode("utf-8")
+    if reduction not in {"none", "mean", "sum"}:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss reduction must be none, mean, or sum"
+        )
+    if reduction == "none":
+        if output_shape != target_shape:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss output must match target shape "
+                "when reduction is none"
+            )
+    else:
+        if output_shape not in {(), (1,)}:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss output must be scalar when reduced"
+            )
+    n = input_shape[0]
+    c = input_shape[1]
+    d = _shape_product(input_shape[2:]) if len(input_shape) > 2 else 1
+    ignore_index = node.attrs.get("ignore_index")
+    if ignore_index is not None:
+        ignore_index = int(ignore_index)
+    return SoftmaxCrossEntropyLossOp(
+        input0=input_name,
+        target=target_name,
+        weight=weight_name,
+        output=output_name,
+        log_prob=log_prob_name,
+        input_shape=input_shape,
+        target_shape=target_shape,
+        output_shape=output_shape,
+        log_prob_shape=input_shape if log_prob_name is not None else None,
+        n=n,
+        c=c,
+        d=d,
+        reduction=reduction,
+        ignore_index=ignore_index,
+        input_dtype=input_dtype,
+        weight_dtype=weight_dtype,
+        weight_shape=weight_shape,
+        dtype=input_dtype,
+        target_dtype=target_dtype,
+    )
emx_onnx_cgen/lowering/split.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import SplitOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import optional_name, value_dtype, value_shape
+from ..validation import normalize_axis
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_split_sizes(graph: Graph, name: str, node: Node) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} split input must be int64 or int32"
+        )
+    if len(initializer.type.shape) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} split input must be a 1D tensor"
+        )
+    values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if values.size == 0:
+        raise ShapeInferenceError(
+            f"{node.op_type} split input cannot be empty"
+        )
+    return [int(value) for value in values]
+
+
+def _validate_static_dims(shape: tuple[int, ...], node: Node) -> None:
+    if any(dim < 0 for dim in shape):
+        raise ShapeInferenceError(
+            f"{node.op_type} does not support dynamic dims"
+        )
+
+
+def _normalize_num_outputs(node: Node, output_count: int) -> int:
+    num_outputs_attr = node.attrs.get("num_outputs")
+    if num_outputs_attr is None:
+        return output_count
+    num_outputs = int(num_outputs_attr)
+    if num_outputs <= 0:
+        raise UnsupportedOpError("Split num_outputs must be positive")
+    if output_count != num_outputs:
+        raise ShapeInferenceError(
+            f"Split expects {num_outputs} outputs, got {output_count}"
+        )
+    return num_outputs
+
+
+@register_lowering("Split")
+def lower_split(graph: Graph, node: Node) -> SplitOp:
+    if len(node.inputs) < 1 or len(node.outputs) < 1:
+        raise UnsupportedOpError("Split must have at least 1 input and 1 output")
+    if len(node.inputs) > 2:
+        raise UnsupportedOpError("Split supports up to 2 inputs")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Split input must be provided")
+    input_shape = value_shape(graph, input_name, node)
+    _validate_static_dims(input_shape, node)
+    axis = normalize_axis(int(node.attrs.get("axis", 0)), input_shape, node)
+    output_shapes = [
+        value_shape(graph, output, node) for output in node.outputs
+    ]
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtypes = {value_dtype(graph, output, node) for output in node.outputs}
+    if output_dtypes != {input_dtype}:
+        dtype_names = ", ".join(
+            dtype.onnx_name for dtype in sorted(output_dtypes, key=str)
+        )
+        raise UnsupportedOpError(
+            f"Split expects matching dtypes, got {dtype_names}"
+        )
+    split_name = optional_name(node.inputs, 1)
+    if split_name is not None and "num_outputs" in node.attrs:
+        raise UnsupportedOpError(
+            "Split cannot specify both split input and num_outputs"
+        )
+    if split_name is not None:
+        split_sizes = _read_split_sizes(graph, split_name, node)
+        if split_sizes is None:
+            split_shape, split_dtype = value_shape(
+                graph, split_name, node
+            ), value_dtype(graph, split_name, node)
+            if split_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    f"{node.op_type} split input must be int64 or int32"
+                )
+            if len(split_shape) != 1:
+                raise UnsupportedOpError(
+                    f"{node.op_type} split input must be a 1D tensor"
+                )
+            if split_shape[0] != len(node.outputs):
+                raise ShapeInferenceError(
+                    f"Split expects {len(node.outputs)} outputs, got {split_shape[0]}"
+                )
+            split_sizes = [shape[axis] for shape in output_shapes]
+        if len(split_sizes) != len(node.outputs):
+            raise ShapeInferenceError(
+                f"Split expects {len(split_sizes)} outputs, got {len(node.outputs)}"
+            )
+        if any(size < 0 for size in split_sizes):
+            raise ShapeInferenceError("Split sizes must be non-negative")
+        if sum(split_sizes) != input_shape[axis]:
+            raise ShapeInferenceError(
+                "Split sizes must sum to the axis dimension"
+            )
+    else:
+        num_outputs = _normalize_num_outputs(node, len(node.outputs))
+        axis_dim = input_shape[axis]
+        base = axis_dim // num_outputs
+        remainder = axis_dim % num_outputs
+        split_sizes = [base + 1] * remainder + [base] * (
+            num_outputs - remainder
+        )
+    computed_shapes: list[tuple[int, ...]] = []
+    for size, output_shape in zip(split_sizes, output_shapes):
+        if size < 0:
+            raise ShapeInferenceError("Split output size must be non-negative")
+        shape = list(input_shape)
+        shape[axis] = size
+        computed_shape = tuple(shape)
+        if output_shape != computed_shape:
+            raise ShapeInferenceError(
+                f"Split output shape must be {computed_shape}, got {output_shape}"
+            )
+        computed_shapes.append(computed_shape)
+    return SplitOp(
+        input0=input_name,
+        outputs=tuple(node.outputs),
+        input_shape=input_shape,
+        output_shapes=tuple(computed_shapes),
+        axis=axis,
+        split_sizes=tuple(split_sizes),
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
emx_onnx_cgen/lowering/squeeze.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ReshapeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .registry import register_lowering
+
+
+def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+    try:
+        return graph.find_value(name).type.shape
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing shape for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+    try:
+        return graph.find_value(name).type.dtype
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing dtype for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _validate_shape(shape: tuple[int, ...], node: Node, label: str) -> None:
+    for dim in shape:
+        if dim < 0:
+            raise ShapeInferenceError(
+                f"{node.op_type} does not support dynamic dims in {label}"
+            )
+
+
+def _normalize_axes(
+    axes: list[int], input_rank: int, node: Node
+) -> tuple[int, ...]:
+    normalized: list[int] = []
+    for axis in axes:
+        if axis < 0:
+            axis += input_rank
+        if axis < 0 or axis >= input_rank:
+            raise ShapeInferenceError(
+                f"{node.op_type} axis {axis} is out of range for rank {input_rank}"
+            )
+        normalized.append(axis)
+    if len(set(normalized)) != len(normalized):
+        raise ShapeInferenceError(f"{node.op_type} axes must be unique")
+    return tuple(sorted(normalized))
+
+
+def _resolve_axes(graph: Graph, node: Node) -> tuple[int, ...] | None:
+    axes_attr = node.attrs.get("axes")
+    axes_values: list[int] | None = None
+    if len(node.inputs) == 2:
+        axes_initializer = _find_initializer(graph, node.inputs[1])
+        if axes_initializer is not None:
+            if axes_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Squeeze axes input must be int64 or int32, "
+                    f"got {axes_initializer.type.dtype.onnx_name}"
+                )
+            axes_values = [int(value) for value in axes_initializer.data.reshape(-1)]
+    elif axes_attr is not None:
+        axes_values = [int(value) for value in axes_attr]
+    if axes_values is None:
+        return None
+    return tuple(axes_values)
+
+
+def _expected_output_shape(
+    input_shape: tuple[int, ...], axes: tuple[int, ...]
+) -> tuple[int, ...]:
+    axis_set = set(axes)
+    return tuple(
+        dim for index, dim in enumerate(input_shape) if index not in axis_set
+    )
+
+
+def _validate_output_shape_for_unknown_axes(
+    input_shape: tuple[int, ...], output_shape: tuple[int, ...], node: Node
+) -> None:
+    output_index = 0
+    for dim in input_shape:
+        if output_index < len(output_shape) and dim == output_shape[output_index]:
+            output_index += 1
+            continue
+        if dim != 1:
+            raise ShapeInferenceError(
+                "Squeeze output shape must remove only dimensions of size 1"
+            )
+    if output_index != len(output_shape):
+        raise ShapeInferenceError(
+            "Squeeze output shape must preserve input order while removing size-1 axes"
+        )
+
+
+@register_lowering("Squeeze")
+def lower_squeeze(graph: Graph, node: Node) -> ReshapeOp:
+    if len(node.outputs) != 1 or len(node.inputs) not in {1, 2}:
+        raise UnsupportedOpError("Squeeze must have 1 or 2 inputs and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    _validate_shape(input_shape, node, "input")
+    _validate_shape(output_shape, node, "output")
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Squeeze expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    axes = _resolve_axes(graph, node)
+    if axes is None:
+        if len(node.inputs) == 2:
+            axes_dtype = _value_dtype(graph, node.inputs[1], node)
+            if axes_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Squeeze axes input must be int64 or int32, "
+                    f"got {axes_dtype.onnx_name}"
+                )
+            _validate_output_shape_for_unknown_axes(input_shape, output_shape, node)
+        else:
+            expected_shape = tuple(dim for dim in input_shape if dim != 1)
+            if expected_shape != output_shape:
+                raise ShapeInferenceError(
+                    "Squeeze output shape must be "
+                    f"{expected_shape}, got {output_shape}"
+                )
+    else:
+        normalized_axes = _normalize_axes(list(axes), len(input_shape), node)
+        for axis in normalized_axes:
+            if input_shape[axis] != 1:
+                raise ShapeInferenceError(
+                    "Squeeze axes must target dimensions of size 1"
+                )
+        expected_shape = _expected_output_shape(input_shape, normalized_axes)
+        if expected_shape != output_shape:
+            raise ShapeInferenceError(
+                "Squeeze output shape must be "
+                f"{expected_shape}, got {output_shape}"
+            )
+    return ReshapeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
emx_onnx_cgen/lowering/tile.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import TileOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_repeats(graph: Graph, name: str, node: Node) -> tuple[int, ...] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError("Tile repeats input must be int64 or int32")
+    if len(initializer.type.shape) != 1:
+        raise UnsupportedOpError("Tile repeats input must be a 1D tensor")
+    values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    return tuple(int(value) for value in values)
+
+
+def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
+    strides: list[int] = []
+    stride = 1
+    for dim in reversed(shape):
+        strides.append(stride)
+        stride *= dim
+    return tuple(reversed(strides))
+
+
+@register_lowering("Tile")
+def lower_tile(graph: Graph, node: Node) -> TileOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Tile must have 2 inputs and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Tile expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    repeats = _read_repeats(graph, node.inputs[1], node)
+    if repeats is None:
+        raise UnsupportedOpError("Tile repeats input must be a constant initializer")
+    if len(repeats) != len(input_shape):
+        raise ShapeInferenceError(
+            "Tile repeats must have the same rank as input shape"
+        )
+    if any(value < 0 for value in repeats):
+        raise UnsupportedOpError("Tile repeats must be non-negative")
+    expected_shape = tuple(
+        int(dim) * int(repeat) for dim, repeat in zip(input_shape, repeats)
+    )
+    if output_shape and output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "Tile output shape mismatch: "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    return TileOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=expected_shape,
+        repeats=repeats,
+        input_strides=_compute_strides(input_shape),
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
emx_onnx_cgen/lowering/transpose.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import TransposeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("Transpose")
+def lower_transpose(graph: Graph, node: Node) -> TransposeOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Transpose must have 1 input and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    perm = node.attrs.get("perm")
+    if perm is None:
+        perm = tuple(reversed(range(len(input_shape))))
+    else:
+        perm = tuple(int(axis) for axis in perm)
+    if len(perm) != len(input_shape):
+        raise ShapeInferenceError(
+            "Transpose perm must match input rank, "
+            f"got perm {perm} for shape {input_shape}"
+        )
+    if set(perm) != set(range(len(input_shape))):
+        raise UnsupportedOpError(
+            f"Transpose perm must be a permutation, got {perm}"
+        )
+    expected_shape = tuple(input_shape[axis] for axis in perm)
+    if output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "Transpose output shape must match permuted input shape, "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    return TransposeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        perm=perm,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+    )