emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0

emx_onnx_cgen/lowering/tfidf_vectorizer.py
ADDED
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..ir.ops import TfIdfVectorizerOp
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+_SUPPORTED_INPUT_DTYPES = {ScalarType.I32, ScalarType.I64}
+_SUPPORTED_OUTPUT_DTYPES = {ScalarType.F32}
+_SUPPORTED_MODES = {"TF", "IDF", "TFIDF"}
+
+
+def _decode_mode(value: object) -> str:
+    if isinstance(value, bytes):
+        return value.decode()
+    return str(value)
+
+
+def _ensure_int_list(
+    values: object | None, *, name: str, node: Node
+) -> tuple[int, ...]:
+    if values is None:
+        raise UnsupportedOpError(f"{node.op_type} requires {name} attribute")
+    try:
+        return tuple(int(value) for value in values)  # type: ignore[arg-type]
+    except TypeError as exc:
+        raise UnsupportedOpError(
+            f"{node.op_type} {name} attribute must be a list of integers"
+        ) from exc
+
+
+def _ensure_float_list(
+    values: object | None, *, name: str, node: Node
+) -> tuple[float, ...] | None:
+    if values is None:
+        return None
+    try:
+        return tuple(float(value) for value in values)  # type: ignore[arg-type]
+    except TypeError as exc:
+        raise UnsupportedOpError(
+            f"{node.op_type} {name} attribute must be a list of floats"
+        ) from exc
+
+
+def _validate_output_shape(
+    node: Node,
+    input_shape: tuple[int, ...],
+    output_shape: tuple[int, ...],
+    output_dim: int,
+) -> None:
+    if len(input_shape) == 1:
+        expected = (output_dim,)
+    else:
+        expected = (input_shape[0], output_dim)
+    if output_shape != expected:
+        raise ShapeInferenceError(
+            f"{node.op_type} output shape must be {expected}, got {output_shape}"
+        )
+
+
+@register_lowering("TfIdfVectorizer")
+def lower_tfidf_vectorizer(graph: Graph, node: Node) -> TfIdfVectorizerOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} expects 1 input and 1 output"
+        )
+    input_name = node.inputs[0]
+    output_name = node.outputs[0]
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, output_name, node)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, output_name, node)
+    if input_dtype not in _SUPPORTED_INPUT_DTYPES:
+        raise UnsupportedOpError(
+            f"{node.op_type} input dtype must be int32 or int64, "
+            f"got {input_dtype.onnx_name}"
+        )
+    if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
+        raise UnsupportedOpError(
+            f"{node.op_type} output dtype must be float, "
+            f"got {output_dtype.onnx_name}"
+        )
+    if len(input_shape) not in {1, 2}:
+        raise UnsupportedOpError(
+            f"{node.op_type} input rank must be 1 or 2, got {len(input_shape)}"
+        )
+    mode_value = node.attrs.get("mode")
+    if mode_value is None:
+        raise UnsupportedOpError(
+            f"{node.op_type} requires mode attribute"
+        )
+    mode = _decode_mode(mode_value)
+    if mode not in _SUPPORTED_MODES:
+        raise UnsupportedOpError(
+            f"{node.op_type} mode must be one of {sorted(_SUPPORTED_MODES)}, "
+            f"got {mode}"
+        )
+    min_gram_length = int(node.attrs.get("min_gram_length", 0))
+    max_gram_length = int(node.attrs.get("max_gram_length", 0))
+    max_skip_count = int(node.attrs.get("max_skip_count", 0))
+    if min_gram_length <= 0 or max_gram_length <= 0:
+        raise UnsupportedOpError(
+            f"{node.op_type} requires positive min/max gram lengths"
+        )
+    if min_gram_length > max_gram_length:
+        raise UnsupportedOpError(
+            f"{node.op_type} min_gram_length {min_gram_length} exceeds "
+            f"max_gram_length {max_gram_length}"
+        )
+    if max_skip_count < 0:
+        raise UnsupportedOpError(
+            f"{node.op_type} max_skip_count must be non-negative"
+        )
+    ngram_counts = _ensure_int_list(
+        node.attrs.get("ngram_counts"), name="ngram_counts", node=node
+    )
+    ngram_indexes = _ensure_int_list(
+        node.attrs.get("ngram_indexes"), name="ngram_indexes", node=node
+    )
+    if "pool_strings" in node.attrs:
+        raise UnsupportedOpError(
+            f"{node.op_type} string pools are not supported"
+        )
+    pool_int64s = _ensure_int_list(
+        node.attrs.get("pool_int64s"), name="pool_int64s", node=node
+    )
+    weights = _ensure_float_list(
+        node.attrs.get("weights"), name="weights", node=node
+    )
+    if len(ngram_counts) < max_gram_length:
+        raise UnsupportedOpError(
+            f"{node.op_type} ngram_counts length must be >= max_gram_length"
+        )
+    if ngram_counts and ngram_counts[0] != 0:
+        raise UnsupportedOpError(
+            f"{node.op_type} ngram_counts must start with 0"
+        )
+    if any(value < 0 for value in ngram_counts):
+        raise UnsupportedOpError(
+            f"{node.op_type} ngram_counts must be non-negative"
+        )
+    if any(
+        later < earlier
+        for earlier, later in zip(ngram_counts, ngram_counts[1:])
+    ):
+        raise UnsupportedOpError(
+            f"{node.op_type} ngram_counts must be non-decreasing"
+        )
+    pool_size = len(pool_int64s)
+    if ngram_counts and ngram_counts[-1] > pool_size:
+        raise UnsupportedOpError(
+            f"{node.op_type} ngram_counts exceeds pool_int64s length"
+        )
+    total_ngrams = 0
+    for gram_length in range(1, max_gram_length + 1):
+        start = ngram_counts[gram_length - 1]
+        end = (
+            ngram_counts[gram_length]
+            if gram_length < len(ngram_counts)
+            else pool_size
+        )
+        count = end - start
+        if count < 0 or count % gram_length != 0:
+            raise UnsupportedOpError(
+                f"{node.op_type} pool size for {gram_length}-grams "
+                "must be divisible by gram length"
+            )
+        total_ngrams += count // gram_length
+    if total_ngrams != len(ngram_indexes):
+        raise UnsupportedOpError(
+            f"{node.op_type} ngram_indexes length {len(ngram_indexes)} "
+            f"does not match pool ngram count {total_ngrams}"
+        )
+    if weights is not None and len(weights) != len(ngram_indexes):
+        raise UnsupportedOpError(
+            f"{node.op_type} weights length {len(weights)} does not match "
+            f"ngram_indexes length {len(ngram_indexes)}"
+        )
+    output_dim = max(ngram_indexes, default=-1) + 1
+    _validate_output_shape(node, input_shape, output_shape, output_dim)
+    return TfIdfVectorizerOp(
+        input0=input_name,
+        output=output_name,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        min_gram_length=min_gram_length,
+        max_gram_length=max_gram_length,
+        max_skip_count=max_skip_count,
+        mode=mode,
+        ngram_counts=ngram_counts,
+        ngram_indexes=ngram_indexes,
+        pool_int64s=pool_int64s,
+        weights=weights,
+    )
emx_onnx_cgen/lowering/tile.py
CHANGED
@@ -30,6 +30,37 @@ def _read_repeats(graph: Graph, name: str, node: Node) -> tuple[int, ...] | None
     return tuple(int(value) for value in values)
 
 
+def _infer_repeats_from_shapes(
+    input_shape: tuple[int, ...],
+    output_shape: tuple[int, ...],
+) -> tuple[int, ...]:
+    if len(input_shape) != len(output_shape):
+        raise ShapeInferenceError(
+            "Tile repeats must have the same rank as input shape"
+        )
+    repeats: list[int] = []
+    for input_dim, output_dim in zip(input_shape, output_shape):
+        if input_dim < 0 or output_dim < 0:
+            raise ShapeInferenceError(
+                "Tile repeats input must be constant when shapes are dynamic"
+            )
+        if input_dim == 0:
+            if output_dim != 0:
+                raise ShapeInferenceError(
+                    "Tile output shape mismatch: "
+                    f"expected 0 for dimension, got {output_dim}"
+                )
+            repeats.append(0)
+            continue
+        if output_dim % input_dim != 0:
+            raise ShapeInferenceError(
+                "Tile output shape mismatch: "
+                f"expected multiple of {input_dim}, got {output_dim}"
+            )
+        repeats.append(int(output_dim // input_dim))
+    return tuple(repeats)
+
+
 def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
     strides: list[int] = []
     stride = 1
@@ -54,7 +85,13 @@ def lower_tile(graph: Graph, node: Node) -> TileOp:
     )
     repeats = _read_repeats(graph, node.inputs[1], node)
     if repeats is None:
-
+        repeats_shape = value_shape(graph, node.inputs[1], node)
+        repeats_dtype = value_dtype(graph, node.inputs[1], node)
+        if repeats_dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError("Tile repeats input must be int64 or int32")
+        if len(repeats_shape) != 1:
+            raise UnsupportedOpError("Tile repeats input must be a 1D tensor")
+        repeats = _infer_repeats_from_shapes(input_shape, output_shape)
     if len(repeats) != len(input_shape):
         raise ShapeInferenceError(
             "Tile repeats must have the same rank as input shape"
emx_onnx_cgen/lowering/topk.py
CHANGED
@@ -117,17 +117,13 @@ def lower_topk(graph: Graph, node: Node) -> TopKOp:
     sorted_output = bool(int(node.attrs.get("sorted", 1)))
     return TopKOp(
         input0=input_name,
+        k_input=k_name,
         output_values=output_values,
         output_indices=output_indices,
-        input_shape=input_shape,
-        output_shape=output_shape,
         axis=axis,
         k=k,
         largest=largest,
         sorted=sorted_output,
-        input_dtype=input_dtype,
-        output_values_dtype=values_dtype,
-        output_indices_dtype=indices_dtype,
     )
 
 
emx_onnx_cgen/lowering/transpose.py
CHANGED
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
-from ..ir.ops import TransposeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.ops import TransposeOp
 from .common import node_dtype as _node_dtype
+from .common import value_has_dim_params as _value_has_dim_params
 from .common import value_shape as _value_shape
 from .registry import register_lowering
 
@@ -14,6 +16,8 @@ def lower_transpose(graph: Graph, node: Node) -> TransposeOp:
         raise UnsupportedOpError("Transpose must have 1 input and 1 output")
     input_shape = _value_shape(graph, node.inputs[0], node)
     output_shape = _value_shape(graph, node.outputs[0], node)
+    if _value_has_dim_params(graph, node.outputs[0]) or not output_shape:
+        output_shape = ()
     perm = node.attrs.get("perm")
     if perm is None:
         perm = tuple(reversed(range(len(input_shape))))
@@ -29,18 +33,20 @@ def lower_transpose(graph: Graph, node: Node) -> TransposeOp:
             f"Transpose perm must be a permutation, got {perm}"
         )
     expected_shape = tuple(input_shape[axis] for axis in perm)
-    if output_shape != expected_shape:
+    if output_shape and output_shape != expected_shape:
         raise ShapeInferenceError(
             "Transpose output shape must match permuted input shape, "
             f"expected {expected_shape}, got {output_shape}"
         )
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], expected_shape)
     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
     return TransposeOp(
         input0=node.inputs[0],
         output=node.outputs[0],
         perm=perm,
         input_shape=input_shape,
-        output_shape=
+        output_shape=expected_shape,
         dtype=op_dtype,
         input_dtype=op_dtype,
     )
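
For reference, the permuted-shape check above gathers input dimensions by perm; a tiny worked example with made-up shapes:

    # Illustrative values, not taken from the package.
    input_shape = (2, 3, 5)
    perm = (2, 0, 1)
    expected_shape = tuple(input_shape[axis] for axis in perm)
    assert expected_shape == (5, 2, 3)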
emx_onnx_cgen/lowering/unsqueeze.py
CHANGED
@@ -2,30 +2,20 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
+from ..ir.ops import ReshapeOp
+from ..lowering.common import value_dtype, value_has_dim_params, value_shape
 from .registry import register_lowering
 
 
 def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
-    try:
-        return graph.find_value(name).type.shape
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing shape for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
+    return value_shape(graph, name, node)
 
 
 def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
-    try:
-        return graph.find_value(name).type.dtype
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing dtype for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
+    return value_dtype(graph, name, node)
 
 
 def _find_initializer(graph: Graph, name: str) -> Initializer | None:
@@ -105,6 +95,8 @@ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
         raise UnsupportedOpError("Unsqueeze must have 1 or 2 inputs and 1 output")
     input_shape = _value_shape(graph, node.inputs[0], node)
    output_shape = _value_shape(graph, node.outputs[0], node)
+    if value_has_dim_params(graph, node.outputs[0]):
+        output_shape = ()
     _validate_shape(input_shape, node, "input")
     _validate_shape(output_shape, node, "output")
     input_dtype = _value_dtype(graph, node.inputs[0], node)
@@ -142,11 +134,14 @@ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
         )
     else:
         expected_shape = _expected_output_shape(input_shape, axes, node)
-        if expected_shape != output_shape:
+        if output_shape and expected_shape != output_shape:
             raise ShapeInferenceError(
                 "Unsqueeze output shape must be "
                 f"{expected_shape}, got {output_shape}"
             )
+        output_shape = expected_shape
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], output_shape)
     return ReshapeOp(
         input0=node.inputs[0],
         output=node.outputs[0],
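
For orientation, ONNX Unsqueeze inserts a size-1 dimension at each requested axis, so the expected output shape can be derived from the input shape and axes alone. A generic sketch of that semantics (not the package's _expected_output_shape helper):

    # Generic ONNX Unsqueeze shape rule; axes are interpreted relative to the output rank.
    def unsqueeze_shape(shape, axes):
        out_rank = len(shape) + len(axes)
        normalized = sorted(axis % out_rank for axis in axes)
        result = list(shape)
        for axis in normalized:
            result.insert(axis, 1)   # insert a size-1 dimension at each axis
        return tuple(result)

    assert unsqueeze_shape((3, 4), (0, 3)) == (1, 3, 4, 1)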
emx_onnx_cgen/lowering/upsample.py
ADDED
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..ir.ops import ResizeOp
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+_SUPPORTED_MODES = {"nearest", "linear"}
+
+
+def _decode_attr(value: object, default: str) -> str:
+    if value is None:
+        return default
+    if isinstance(value, bytes):
+        return value.decode("utf-8", errors="ignore")
+    if isinstance(value, str):
+        return value
+    return str(value)
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _load_initializer_values(
+    graph: Graph, name: str, node: Node
+) -> tuple[float | int, ...] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {
+        ScalarType.F16,
+        ScalarType.F32,
+        ScalarType.F64,
+    }:
+        raise UnsupportedOpError(
+            "Upsample scales initializer must be float16/float32/float64"
+        )
+    data = initializer.data.reshape(-1)
+    return tuple(data.tolist())
+
+
+def _validate_output_shape(
+    expected: tuple[int, ...],
+    actual: tuple[int, ...],
+) -> None:
+    if expected != actual:
+        raise ShapeInferenceError(
+            f"Upsample output shape must be {expected}, got {actual}"
+        )
+    if any(dim < 0 for dim in actual):
+        raise ShapeInferenceError("Upsample output shape must be non-negative")
+
+
+@register_lowering("Upsample")
+def lower_upsample(graph: Graph, node: Node) -> ResizeOp:
+    if len(node.outputs) != 1:
+        raise UnsupportedOpError("Upsample expects one output")
+    if len(node.inputs) not in {1, 2}:
+        raise UnsupportedOpError("Upsample expects 1 or 2 inputs")
+    mode = _decode_attr(node.attrs.get("mode"), "nearest")
+    if mode not in _SUPPORTED_MODES:
+        raise UnsupportedOpError(f"Upsample mode {mode!r} is not supported")
+    input_name = node.inputs[0]
+    output_name = node.outputs[0]
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, output_name, node)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, output_name, node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Upsample expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    rank = len(input_shape)
+    axes = tuple(range(rank))
+    scales_input = None
+    scales_shape = None
+    scales_dtype = None
+    scales_axes = None
+    scales: tuple[float, ...]
+    if len(node.inputs) == 2 and node.inputs[1]:
+        scales_input = node.inputs[1]
+        scales_shape = value_shape(graph, scales_input, node)
+        if len(scales_shape) != 1:
+            raise UnsupportedOpError("Upsample expects scales to be 1D")
+        if scales_shape[0] != rank:
+            raise UnsupportedOpError("Upsample scales length mismatch")
+        scales_dtype = value_dtype(graph, scales_input, node)
+        if scales_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+            raise UnsupportedOpError(
+                "Upsample expects scales input to be float16/float32/float64"
+            )
+        values = _load_initializer_values(graph, scales_input, node)
+        if values is None:
+            scales = tuple(
+                output_shape[axis] / input_shape[axis]
+                for axis in range(rank)
+            )
+        else:
+            scales = tuple(float(value) for value in values)
+            expected = tuple(
+                int(input_shape[axis] * scales[axis]) for axis in range(rank)
+            )
+            _validate_output_shape(expected, output_shape)
+    else:
+        scales_attr = node.attrs.get("scales")
+        if scales_attr is None:
+            raise UnsupportedOpError("Upsample requires scales attribute or input")
+        scales = tuple(float(value) for value in scales_attr)
+        if len(scales) != rank:
+            raise UnsupportedOpError("Upsample scales length mismatch")
+        expected = tuple(
+            int(input_shape[axis] * scales[axis]) for axis in range(rank)
+        )
+        _validate_output_shape(expected, output_shape)
+    return ResizeOp(
+        input0=input_name,
+        output=output_name,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        scales=scales,
+        scales_input=scales_input,
+        sizes_input=None,
+        roi_input=None,
+        axes=axes,
+        scales_shape=scales_shape,
+        sizes_shape=None,
+        roi_shape=None,
+        scales_dtype=scales_dtype,
+        sizes_dtype=None,
+        roi_dtype=None,
+        scales_axes=scales_axes,
+        sizes_axes=None,
+        roi_axes=None,
+        mode=mode,
+        coordinate_transformation_mode="asymmetric",
+        nearest_mode="floor",
+        cubic_coeff_a=-0.75,
+        exclude_outside=False,
+        extrapolation_value=0.0,
+        antialias=False,
+        keep_aspect_ratio_policy="stretch",
+        dtype=input_dtype,
+    )
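
To make the shape arithmetic concrete: the expected output dimension is int(input_dim * scale), truncating toward zero. A worked example with hypothetical values:

    # Hypothetical Upsample of a 1x3x4x4 tensor with scales (1, 1, 2, 2).
    input_shape = (1, 3, 4, 4)
    scales = (1.0, 1.0, 2.0, 2.0)
    expected = tuple(int(dim * scale) for dim, scale in zip(input_shape, scales))
    assert expected == (1, 3, 8, 8)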
emx_onnx_cgen/lowering/variadic.py
CHANGED
@@ -53,7 +53,7 @@ def _lower_variadic(graph: Graph, node: Node) -> MultiInputBinaryOp:
         output=node.outputs[0],
         function=VARIADIC_OP_FUNCTIONS[node.op_type],
         operator_kind=VARIADIC_OP_OPERATOR_KINDS[node.op_type],
-        min_inputs=2,
+        min_inputs=1 if node.op_type not in BINARY_ONLY_OPS else 2,
         max_inputs=2 if node.op_type in BINARY_ONLY_OPS else None,
     )
 
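
This one-line change lets truly variadic ONNX ops (Sum, Min, Max, Mean all accept a single input) pass arity validation, while binary-only ops keep the old lower bound. A quick sketch of how the expression resolves, using an assumed BINARY_ONLY_OPS membership for illustration:

    # Assumed example set for illustration; the real BINARY_ONLY_OPS lives in variadic.py.
    BINARY_ONLY_OPS = {"Pow"}
    for op_type in ("Sum", "Pow"):
        min_inputs = 1 if op_type not in BINARY_ONLY_OPS else 2
        max_inputs = 2 if op_type in BINARY_ONLY_OPS else None
        print(op_type, min_inputs, max_inputs)  # Sum 1 None, then Pow 2 2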
emx_onnx_cgen/lowering/where.py
CHANGED