emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/ir/ops/nn.py
CHANGED
@@ -6,7 +6,7 @@ from enum import Enum
 from shared.scalar_functions import ScalarFunction
 from shared.scalar_types import ScalarType
 
-from ...errors import ShapeInferenceError
+from ...errors import ShapeInferenceError, UnsupportedOpError
 from ..op_base import ConvLikeOpBase, GemmLikeOpBase, MatMulLikeOpBase, RenderableOpBase
 from ..op_context import OpContext
 
@@ -29,23 +29,117 @@ def _shape_product(shape: tuple[int, ...]) -> int:
     return product
 
 
+def _broadcast_batch_shapes(
+    left: tuple[int, ...], right: tuple[int, ...]
+) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+    max_rank = max(len(left), len(right))
+    left_padded = (1,) * (max_rank - len(left)) + left
+    right_padded = (1,) * (max_rank - len(right)) + right
+    broadcast_shape: list[int] = []
+    for left_dim, right_dim in zip(left_padded, right_padded):
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            raise ShapeInferenceError(
+                "MatMul batch dimensions must be broadcastable, "
+                f"got {left} x {right}"
+            )
+        broadcast_shape.append(max(left_dim, right_dim))
+    return tuple(broadcast_shape), left_padded, right_padded
+
+
+def _resolve_matmul_spec(
+    ctx: OpContext, input0: str, input1: str
+) -> dict[str, object]:
+    input0_shape = ctx.shape(input0)
+    input1_shape = ctx.shape(input1)
+    if len(input0_shape) < 1 or len(input1_shape) < 1:
+        raise UnsupportedOpError(
+            "MatMul inputs must be at least 1D, "
+            f"got {input0_shape} x {input1_shape}"
+        )
+    left_vector = len(input0_shape) == 1
+    right_vector = len(input1_shape) == 1
+    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
+    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
+    m, k_left = input0_effective[-2], input0_effective[-1]
+    k_right, n = input1_effective[-2], input1_effective[-1]
+    if k_left != k_right:
+        raise ShapeInferenceError(
+            f"MatMul inner dimensions must match, got {k_left} and {k_right}"
+        )
+    batch_shape, input0_batch_shape, input1_batch_shape = (
+        _broadcast_batch_shapes(
+            input0_effective[:-2],
+            input1_effective[:-2],
+        )
+    )
+    if left_vector and right_vector:
+        output_shape = batch_shape
+    elif left_vector:
+        output_shape = batch_shape + (n,)
+    elif right_vector:
+        output_shape = batch_shape + (m,)
+    else:
+        output_shape = batch_shape + (m, n)
+    return {
+        "input0_shape": input0_shape,
+        "input1_shape": input1_shape,
+        "output_shape": output_shape,
+        "batch_shape": batch_shape,
+        "input0_batch_shape": input0_batch_shape,
+        "input1_batch_shape": input1_batch_shape,
+        "m": m,
+        "n": n,
+        "k": k_left,
+        "left_vector": left_vector,
+        "right_vector": right_vector,
+    }
+
+
 @dataclass(frozen=True)
 class MatMulOp(MatMulLikeOpBase):
     input0: str
     input1: str
     output: str
-
-
-
-
-
-
-
-
-
-
-
-
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input0_dtype = ctx.dtype(self.input0)
+        input1_dtype = ctx.dtype(self.input1)
+        if input0_dtype != input1_dtype:
+            raise UnsupportedOpError(
+                "MatMul expects matching input dtypes, "
+                f"got {input0_dtype.onnx_name} and {input1_dtype.onnx_name}"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input0_dtype)
+            output_dtype = input0_dtype
+        if output_dtype != input0_dtype:
+            raise UnsupportedOpError(
+                "MatMul expects output dtype to match inputs, "
+                f"got {output_dtype.onnx_name} and {input0_dtype.onnx_name}"
+            )
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        spec = _resolve_matmul_spec(ctx, self.input0, self.input1)
+        output_shape = spec["output_shape"]
+        try:
+            expected = ctx.shape(self.output)
+        except ShapeInferenceError:
+            expected = None
+        if expected is not None and expected != output_shape:
+            raise ShapeInferenceError(
+                f"MatMul output shape must be {output_shape}, got {expected}"
+            )
+        ctx.set_shape(self.output, output_shape)
+        ctx.set_derived(self, "batch_shape", spec["batch_shape"])
+        ctx.set_derived(self, "input0_batch_shape", spec["input0_batch_shape"])
+        ctx.set_derived(self, "input1_batch_shape", spec["input1_batch_shape"])
+        ctx.set_derived(self, "m", spec["m"])
+        ctx.set_derived(self, "n", spec["n"])
+        ctx.set_derived(self, "k", spec["k"])
+        ctx.set_derived(self, "left_vector", spec["left_vector"])
+        ctx.set_derived(self, "right_vector", spec["right_vector"])
 
 @dataclass(frozen=True)
 class QLinearMatMulOp(MatMulLikeOpBase):
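The _resolve_matmul_spec helper added above follows numpy.matmul semantics, which is what ONNX MatMul specifies: batch dimensions broadcast NumPy-style, and 1-D inputs are promoted to matrices with the temporary dimension dropped from the result. A quick illustration of those rules using NumPy itself (the shapes below are made up for the example):

import numpy as np

a = np.zeros((2, 5, 3, 4))  # batch (2, 5), m=3, k=4
b = np.zeros((5, 4, 6))     # batch (5,),  k=4, n=6
assert np.matmul(a, b).shape == (2, 5, 3, 6)  # batch dims broadcast to (2, 5)

v = np.zeros(4)                         # 1-D left input: promoted to (1, 4),
assert np.matmul(v, b).shape == (5, 6)  # then the prepended dim is dropped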
@@ -98,15 +192,139 @@ class GemmOp(GemmLikeOpBase):
     input_b: str
     input_c: str | None
     output: str
-
-
-    k: int
-    trans_a: bool
-    trans_b: bool
+    trans_a: int
+    trans_b: int
     alpha: float | int
     beta: float | int
-
-
+
+    @staticmethod
+    def _normalize_attrs(
+        dtype: ScalarType,
+        *,
+        alpha: float | int,
+        beta: float | int,
+        trans_a: int,
+        trans_b: int,
+    ) -> tuple[float | int, float | int, bool, bool]:
+        if trans_a not in {0, 1} or trans_b not in {0, 1}:
+            raise UnsupportedOpError(
+                "Gemm only supports transA/transB values of 0 or 1"
+            )
+        if dtype == ScalarType.BOOL:
+            raise UnsupportedOpError("Gemm supports numeric inputs only")
+        if not dtype.is_float:
+            alpha_int = int(alpha)
+            beta_int = int(beta)
+            if alpha != alpha_int or beta != beta_int:
+                raise UnsupportedOpError(
+                    "Gemm alpha and beta must be integers for non-float inputs"
+                )
+            alpha = alpha_int
+            beta = beta_int
+        return alpha, beta, bool(trans_a), bool(trans_b)
+
+    @staticmethod
+    def _validate_bias_shape(
+        output_shape: tuple[int, int], bias_shape: tuple[int, ...]
+    ) -> tuple[int, ...]:
+        if len(bias_shape) == 0:
+            return bias_shape
+        if len(bias_shape) == 1:
+            if bias_shape[0] not in {1, output_shape[1]}:
+                raise ShapeInferenceError(
+                    "Gemm bias input must be broadcastable to output shape, "
+                    f"got {bias_shape} vs {output_shape}"
+                )
+            return bias_shape
+        if len(bias_shape) == 2:
+            m, n = output_shape
+            if bias_shape[0] not in {1, m} or bias_shape[1] not in {1, n}:
+                raise ShapeInferenceError(
+                    "Gemm bias input must be broadcastable to output shape, "
+                    f"got {bias_shape} vs {output_shape}"
+                )
+            return bias_shape
+        raise ShapeInferenceError(
+            f"Gemm bias input must be rank 1 or 2, got {bias_shape}"
+        )
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_a_dtype = ctx.dtype(self.input_a)
+        input_b_dtype = ctx.dtype(self.input_b)
+        if input_a_dtype != input_b_dtype:
+            raise UnsupportedOpError(
+                "Gemm expects matching input dtypes, "
+                f"got {input_a_dtype.onnx_name} and {input_b_dtype.onnx_name}"
+            )
+        if self.input_c is not None:
+            input_c_dtype = ctx.dtype(self.input_c)
+            if input_c_dtype != input_a_dtype:
+                raise UnsupportedOpError(
+                    "Gemm expects bias dtype to match inputs, "
+                    f"got {input_c_dtype.onnx_name} and {input_a_dtype.onnx_name}"
+                )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_a_dtype)
+            output_dtype = input_a_dtype
+        if output_dtype != input_a_dtype:
+            raise UnsupportedOpError(
+                "Gemm expects output dtype to match inputs, "
+                f"got {output_dtype.onnx_name} and {input_a_dtype.onnx_name}"
+            )
+        alpha, beta, trans_a, trans_b = self._normalize_attrs(
+            output_dtype,
+            alpha=self.alpha,
+            beta=self.beta,
+            trans_a=self.trans_a,
+            trans_b=self.trans_b,
+        )
+        ctx.set_derived(self, "alpha", alpha)
+        ctx.set_derived(self, "beta", beta)
+        ctx.set_derived(self, "trans_a", trans_a)
+        ctx.set_derived(self, "trans_b", trans_b)
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        trans_a = ctx.require_derived(self, "trans_a")
+        trans_b = ctx.require_derived(self, "trans_b")
+        input_a_shape = ctx.shape(self.input_a)
+        input_b_shape = ctx.shape(self.input_b)
+        if len(input_a_shape) != 2 or len(input_b_shape) != 2:
+            raise UnsupportedOpError(
+                "Gemm supports 2D inputs only, "
+                f"got {input_a_shape} x {input_b_shape}"
+            )
+        if trans_a:
+            m, k_left = input_a_shape[1], input_a_shape[0]
+        else:
+            m, k_left = input_a_shape
+        if trans_b:
+            n, k_right = input_b_shape[0], input_b_shape[1]
+        else:
+            k_right, n = input_b_shape
+        if k_left != k_right:
+            raise ShapeInferenceError(
+                f"Gemm inner dimensions must match, got {k_left} and {k_right}"
+            )
+        output_shape = (m, n)
+        try:
+            expected = ctx.shape(self.output)
+        except ShapeInferenceError:
+            expected = None
+        if expected is not None and expected != output_shape:
+            raise ShapeInferenceError(
+                f"Gemm output shape must be {output_shape}, got {expected}"
+            )
+        ctx.set_shape(self.output, output_shape)
+        c_shape = None
+        if self.input_c is not None:
+            bias_shape = ctx.shape(self.input_c)
+            c_shape = self._validate_bias_shape(output_shape, bias_shape)
+        ctx.set_derived(self, "m", m)
+        ctx.set_derived(self, "n", n)
+        ctx.set_derived(self, "k", k_left)
+        ctx.set_derived(self, "c_shape", c_shape)
 
 @dataclass(frozen=True)
 class AttentionOp(RenderableOpBase):
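The normalized attributes feed the standard Gemm computation Y = alpha * op(A) @ op(B) + beta * C, where op() transposes its argument when transA or transB is 1. A worked shape example consistent with the checks above (the values are illustrative):

import numpy as np

a = np.ones((4, 3))   # transA=1 -> op(A) is (3, 4): m=3, k=4
b = np.ones((5, 4))   # transB=1 -> op(B) is (4, 5): k=4, n=5
c = np.ones(5)        # rank-1 bias: allowed when its length is 1 or n
y = 0.5 * (a.T @ b.T) + 2.0 * c   # alpha=0.5, beta=2.0
assert y.shape == (3, 5)          # a (2, 5) bias would fail _validate_bias_shape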
@@ -205,6 +423,31 @@ class ConvOp(ConvLikeOpBase):
             raise ValueError("Conv output width is undefined for spatial_rank < 2")
         return self.out_spatial[1]
 
+@dataclass(frozen=True)
+class ConvIntegerOp(ConvLikeOpBase):
+    input0: str
+    weights: str
+    x_zero_point: str | None
+    w_zero_point: str | None
+    output: str
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    group: int
+    input_dtype: ScalarType
+    weight_dtype: ScalarType
+    dtype: ScalarType
+    x_zero_point_shape: tuple[int, ...] | None
+    w_zero_point_shape: tuple[int, ...] | None
+    w_zero_point_per_channel: bool
+
 @dataclass(frozen=True)
 class ConvTransposeOp(ConvLikeOpBase):
     input0: str
@@ -237,6 +480,8 @@ class AveragePoolOp(RenderableOpBase):
     out_w: int
     kernel_h: int
     kernel_w: int
+    dilation_h: int
+    dilation_w: int
     stride_h: int
     stride_w: int
     pad_top: int
@@ -245,6 +490,14 @@ class AveragePoolOp(RenderableOpBase):
     pad_right: int
     count_include_pad: bool
     dtype: ScalarType
+    spatial_rank: int = 2
+    in_d: int = 1
+    out_d: int = 1
+    kernel_d: int = 1
+    dilation_d: int = 1
+    stride_d: int = 1
+    pad_front: int = 0
+    pad_back: int = 0
 
 @dataclass(frozen=True)
 class LpPoolOp(RenderableOpBase):
@@ -258,6 +511,8 @@ class LpPoolOp(RenderableOpBase):
     out_w: int
     kernel_h: int
     kernel_w: int
+    dilation_h: int
+    dilation_w: int
     stride_h: int
     stride_w: int
     pad_top: int
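The new dilation fields enter the usual floor-mode pooling arithmetic, in which a dilated kernel spans dilation * (kernel - 1) + 1 input positions. A generic sketch of that formula (the helper name is invented, not the package's code):

def pooled_extent(in_size: int, kernel: int, stride: int,
                  pad_begin: int = 0, pad_end: int = 0, dilation: int = 1) -> int:
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_begin + pad_end - effective_kernel) // stride + 1

assert pooled_extent(7, kernel=3, stride=2) == 3               # dilation=1
assert pooled_extent(7, kernel=3, stride=2, dilation=2) == 2   # kernel spans 5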
@@ -271,16 +526,29 @@ class LpPoolOp(RenderableOpBase):
 class SoftmaxOp(RenderableOpBase):
     input0: str
     output: str
-
-
-
-
-
-
+    axis: int | None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_dtype = ctx.dtype(self.input0)
+        if not input_dtype.is_float:
+            raise UnsupportedOpError(
+                "Softmax supports float16, float, and double inputs only"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_dtype)
+            return None
+        if output_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Softmax expects output dtype to match input dtype"
+            )
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_shape = ctx.shape(self.input0)
         axis = self.axis
+        if axis is None:
+            axis = -1
         if axis < 0:
             axis += len(input_shape)
         if axis < 0 or axis >= len(input_shape):
@@ -295,6 +563,7 @@ class SoftmaxOp(RenderableOpBase):
             else 1
         )
         ctx.set_shape(self.output, input_shape)
+        ctx.set_derived(self, "axis", axis)
         ctx.set_derived(self, "outer", outer)
         ctx.set_derived(self, "axis_size", axis_size)
         ctx.set_derived(self, "inner", inner)
@@ -303,16 +572,29 @@ class SoftmaxOp(RenderableOpBase):
 class LogSoftmaxOp(RenderableOpBase):
     input0: str
     output: str
-
-
-
-
-
-
+    axis: int | None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_dtype = ctx.dtype(self.input0)
+        if not input_dtype.is_float:
+            raise UnsupportedOpError(
+                "LogSoftmax supports float16, float, and double inputs only"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_dtype)
+            return None
+        if output_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "LogSoftmax expects output dtype to match input dtype"
+            )
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_shape = ctx.shape(self.input0)
         axis = self.axis
+        if axis is None:
+            axis = -1
         if axis < 0:
             axis += len(input_shape)
         if axis < 0 or axis >= len(input_shape):
@@ -327,6 +609,7 @@ class LogSoftmaxOp(RenderableOpBase):
             else 1
         )
         ctx.set_shape(self.output, input_shape)
+        ctx.set_derived(self, "axis", axis)
         ctx.set_derived(self, "outer", outer)
         ctx.set_derived(self, "axis_size", axis_size)
         ctx.set_derived(self, "inner", inner)
@@ -335,16 +618,30 @@ class LogSoftmaxOp(RenderableOpBase):
 class HardmaxOp(RenderableOpBase):
     input0: str
     output: str
-
-
-
-
-
-
+    axis: int | None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_dtype = ctx.dtype(self.input0)
+        if input_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+            raise UnsupportedOpError(
+                "Hardmax supports float16, float, and double inputs only"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_dtype)
+            return None
+        if output_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Hardmax expects output dtype to match input dtype"
+            )
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_shape = ctx.shape(self.input0)
         axis = self.axis
+        if axis is None:
+            opset_version = ctx.opset_version()
+            axis = 1 if opset_version is not None and opset_version < 13 else -1
         if axis < 0:
             axis += len(input_shape)
         if axis < 0 or axis >= len(input_shape):
@@ -359,6 +656,7 @@ class HardmaxOp(RenderableOpBase):
             else 1
         )
         ctx.set_shape(self.output, input_shape)
+        ctx.set_derived(self, "axis", axis)
         ctx.set_derived(self, "outer", outer)
         ctx.set_derived(self, "axis_size", axis_size)
         ctx.set_derived(self, "inner", inner)
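Condensing the axis handling added above: Softmax and LogSoftmax now default a missing axis attribute to -1, while Hardmax defaults to 1 before opset 13 and to -1 from opset 13 on. A standalone restatement of that logic (the function is illustrative, not part of the package):

def resolve_axis(axis: int | None, rank: int, *,
                 opset: int | None = None, is_hardmax: bool = False) -> int:
    if axis is None:
        axis = 1 if is_hardmax and opset is not None and opset < 13 else -1
    if axis < 0:
        axis += rank
    if not 0 <= axis < rank:
        raise ValueError(f"axis out of range for rank {rank}")
    return axis

assert resolve_axis(None, rank=3) == 2                            # Softmax/LogSoftmax
assert resolve_axis(None, rank=3, opset=11, is_hardmax=True) == 1
assert resolve_axis(-2, rank=3) == 1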
@@ -512,6 +810,31 @@ class LrnOp(RenderableOpBase):
     bias: float
     dtype: ScalarType
 
+@dataclass(frozen=True)
+class GruOp(RenderableOpBase):
+    input_x: str
+    input_w: str
+    input_r: str
+    input_b: str | None
+    input_sequence_lens: str | None
+    input_initial_h: str | None
+    output_y: str | None
+    output_y_h: str | None
+    seq_length: int
+    batch_size: int
+    input_size: int
+    hidden_size: int
+    num_directions: int
+    direction: str
+    layout: int
+    linear_before_reset: int
+    clip: float | None
+    activation_kinds: tuple[int, ...]
+    activation_alphas: tuple[float, ...]
+    activation_betas: tuple[float, ...]
+    dtype: ScalarType
+    sequence_lens_dtype: ScalarType | None
+
 @dataclass(frozen=True)
 class LstmOp(RenderableOpBase):
     input_x: str
emx_onnx_cgen/ir/ops/reduce.py
CHANGED
@@ -2,8 +2,6 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from shared.scalar_types import ScalarType
-
 from ..op_base import ReduceOpBase
 from ..op_context import OpContext
 
@@ -12,17 +10,12 @@ from ..op_context import OpContext
 class ReduceOp(ReduceOpBase):
     input0: str
     output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
     axes: tuple[int, ...]
     axes_input: str | None
-    axes_input_shape: tuple[int, ...] | None
-    axes_input_dtype: ScalarType | None
     keepdims: bool
     noop_with_empty_axes: bool
     reduce_kind: str
     reduce_count: int | None
-    dtype: ScalarType
 
     def infer_types(self, ctx: OpContext) -> None:
         ctx.dtype(self.output)
@@ -45,14 +38,10 @@
 class ArgReduceOp(ReduceOpBase):
     input0: str
     output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
     axis: int
     keepdims: bool
     select_last_index: bool
     reduce_kind: str
-    input_dtype: ScalarType
-    output_dtype: ScalarType
 
     def infer_types(self, ctx: OpContext) -> None:
         ctx.dtype(self.input0)
@@ -71,17 +60,13 @@
 @dataclass(frozen=True)
 class TopKOp(ReduceOpBase):
     input0: str
+    k_input: str
     output_values: str
     output_indices: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
     axis: int
     k: int
     largest: bool
     sorted: bool
-    input_dtype: ScalarType
-    output_values_dtype: ScalarType
-    output_indices_dtype: ScalarType
 
     def infer_types(self, ctx: OpContext) -> None:
         ctx.dtype(self.input0)
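The pattern behind these removals: op dataclasses keep only tensor names and attributes, and shapes and dtypes are resolved through the OpContext during infer_types/infer_shapes (as nn.py above does with ctx.set_derived). A minimal sketch of that design with invented names, showing why the constructors shrink:

from dataclasses import dataclass

@dataclass(frozen=True)
class Ctx:  # stand-in for OpContext, illustrative only
    shapes: dict[str, tuple[int, ...]]

    def shape(self, name: str) -> tuple[int, ...]:
        return self.shapes[name]

@dataclass(frozen=True)
class ArgReduce:  # stores names and attributes, no cached shapes/dtypes
    input0: str
    output: str
    axis: int
    keepdims: bool

    def infer_output_shape(self, ctx: Ctx) -> tuple[int, ...]:
        shape = list(ctx.shape(self.input0))
        if self.keepdims:
            shape[self.axis] = 1
        else:
            del shape[self.axis]
        return tuple(shape)

ctx = Ctx({"x": (2, 3, 4)})
assert ArgReduce("x", "y", axis=1, keepdims=False).infer_output_shape(ctx) == (2, 4)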
emx_onnx_cgen/lowering/__init__.py
CHANGED
@@ -10,13 +10,16 @@ _LOWERING_MODULES = [
     "attention",
     "average_pool",
     "batch_normalization",
+    "bernoulli",
     "cast",
     "concat",
     "constant_of_shape",
     "conv",
+    "conv_integer",
     "conv_transpose",
     "cumsum",
     "depth_space",
+    "dequantize_linear",
     "dropout",
     "einsum",
     "elementwise",
@@ -29,8 +32,10 @@ _LOWERING_MODULES = [
     "gemm",
     "global_max_pool",
     "grid_sample",
+    "gru",
     "group_normalization",
     "hardmax",
+    "hamming_window",
     "identity",
     "instance_normalization",
     "layer_normalization",
@@ -45,9 +50,11 @@ _LOWERING_MODULES = [
     "negative_log_likelihood_loss",
     "non_max_suppression",
     "nonzero",
+    "optional_has_element",
     "one_hot",
     "pad",
     "qlinear_matmul",
+    "qlinear_mul",
     "quantize_linear",
     "range",
     "reduce",
@@ -64,11 +71,13 @@ _LOWERING_MODULES = [
     "split",
     "squeeze",
     "tensor_scatter",
+    "tfidf_vectorizer",
     "tile",
     "topk",
     "transpose",
     "trilu",
     "unsqueeze",
+    "upsample",
     "variadic",
     "where",
 ]
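The string entries name submodules of emx_onnx_cgen.lowering, and each new entry matches a new lowering file in this release. A registry like this is typically consumed by importing every module for its registration side effects; a plausible sketch of that mechanism, offered as an assumption rather than the package's actual code:

import importlib

_LOWERING_MODULES = ["bernoulli", "conv_integer", "gru", "upsample"]  # abbreviated

def import_lowerings(package: str = "emx_onnx_cgen.lowering") -> None:
    for name in _LOWERING_MODULES:
        # Importing runs each module's top-level registration hooks.
        importlib.import_module(f"{package}.{name}")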
emx_onnx_cgen/lowering/arg_reduce.py
CHANGED
@@ -84,14 +84,10 @@ def lower_arg_reduce(graph: Graph, node: Node) -> ArgReduceOp:
     return ArgReduceOp(
         input0=input_name,
         output=output_name,
-        input_shape=input_shape,
-        output_shape=output_shape,
         axis=axis,
         keepdims=keepdims,
         select_last_index=select_last_index,
         reduce_kind=ARG_REDUCE_KIND_BY_OP[node.op_type],
-        input_dtype=input_dtype,
-        output_dtype=output_dtype,
     )
 
 