emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic. Click here for more details.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
|
@@ -6,7 +6,13 @@ from shared.scalar_functions import ScalarFunction
|
|
|
6
6
|
from shared.scalar_types import ScalarType
|
|
7
7
|
|
|
8
8
|
from ...ops import COMPARE_FUNCTIONS, OperatorKind, binary_op_symbol
|
|
9
|
-
from
|
|
9
|
+
from ...errors import ShapeInferenceError, UnsupportedOpError
|
|
10
|
+
from ..op_base import (
|
|
11
|
+
BroadcastingOpBase,
|
|
12
|
+
ElementwiseOpBase,
|
|
13
|
+
RenderableOpBase,
|
|
14
|
+
VariadicLikeOpBase,
|
|
15
|
+
)
|
|
10
16
|
from ..op_context import OpContext
|
|
11
17
|
|
|
12
18
|
|
|
@@ -17,11 +23,6 @@ class BinaryOp(ElementwiseOpBase):
|
|
|
17
23
|
output: str
|
|
18
24
|
function: ScalarFunction
|
|
19
25
|
operator_kind: OperatorKind
|
|
20
|
-
input0_shape: tuple[int, ...]
|
|
21
|
-
input1_shape: tuple[int, ...]
|
|
22
|
-
shape: tuple[int, ...]
|
|
23
|
-
dtype: ScalarType
|
|
24
|
-
input_dtype: ScalarType
|
|
25
26
|
|
|
26
27
|
def _elementwise_inputs(self) -> tuple[str, ...]:
|
|
27
28
|
return (self.input0, self.input1)
|
|
@@ -32,6 +33,107 @@ class BinaryOp(ElementwiseOpBase):
|
|
|
32
33
|
def _elementwise_compare(self) -> bool:
|
|
33
34
|
return self.function in COMPARE_FUNCTIONS
|
|
34
35
|
|
|
36
|
+
def infer_shapes(self, ctx: OpContext) -> None:
|
|
37
|
+
if self.function != ScalarFunction.PRELU:
|
|
38
|
+
return super().infer_shapes(ctx)
|
|
39
|
+
input_shape = ctx.shape(self.input0)
|
|
40
|
+
slope_shape = ctx.shape(self.input1)
|
|
41
|
+
output_name = self.output
|
|
42
|
+
if BroadcastingOpBase.unidirectional_broadcastable(
|
|
43
|
+
slope_shape, input_shape
|
|
44
|
+
):
|
|
45
|
+
ctx.set_shape(output_name, input_shape)
|
|
46
|
+
return None
|
|
47
|
+
channel_axis = BroadcastingOpBase.prelu_channel_axis(
|
|
48
|
+
input_shape, slope_shape
|
|
49
|
+
)
|
|
50
|
+
if channel_axis is not None:
|
|
51
|
+
ctx.set_shape(output_name, input_shape)
|
|
52
|
+
ctx.set_derived(self, "prelu_slope_axis", channel_axis)
|
|
53
|
+
return None
|
|
54
|
+
raise ShapeInferenceError(
|
|
55
|
+
"Broadcasting mismatch for shapes: "
|
|
56
|
+
+ ", ".join(str(shape) for shape in (input_shape, slope_shape))
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
_POW_BASE_DTYPES = {
|
|
61
|
+
ScalarType.F16,
|
|
62
|
+
ScalarType.F32,
|
|
63
|
+
ScalarType.F64,
|
|
64
|
+
ScalarType.I32,
|
|
65
|
+
ScalarType.I64,
|
|
66
|
+
}
|
|
67
|
+
_POW_EXPONENT_DTYPES = {
|
|
68
|
+
ScalarType.F16,
|
|
69
|
+
ScalarType.F32,
|
|
70
|
+
ScalarType.F64,
|
|
71
|
+
ScalarType.I8,
|
|
72
|
+
ScalarType.I16,
|
|
73
|
+
ScalarType.I32,
|
|
74
|
+
ScalarType.I64,
|
|
75
|
+
ScalarType.U8,
|
|
76
|
+
ScalarType.U16,
|
|
77
|
+
ScalarType.U32,
|
|
78
|
+
ScalarType.U64,
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass(frozen=True)
|
|
83
|
+
class PowOp(BinaryOp):
|
|
84
|
+
def validate(self, ctx: OpContext) -> None:
|
|
85
|
+
base_dtype = ctx.dtype(self.input0)
|
|
86
|
+
exponent_dtype = ctx.dtype(self.input1)
|
|
87
|
+
if base_dtype not in _POW_BASE_DTYPES:
|
|
88
|
+
raise UnsupportedOpError(
|
|
89
|
+
"Pow base dtype must be one of "
|
|
90
|
+
f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_BASE_DTYPES, key=str))}, "
|
|
91
|
+
f"got {base_dtype.onnx_name}"
|
|
92
|
+
)
|
|
93
|
+
if exponent_dtype not in _POW_EXPONENT_DTYPES:
|
|
94
|
+
raise UnsupportedOpError(
|
|
95
|
+
"Pow exponent dtype must be one of "
|
|
96
|
+
f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_EXPONENT_DTYPES, key=str))}, "
|
|
97
|
+
f"got {exponent_dtype.onnx_name}"
|
|
98
|
+
)
|
|
99
|
+
try:
|
|
100
|
+
output_dtype = ctx.dtype(self.output)
|
|
101
|
+
except ShapeInferenceError:
|
|
102
|
+
return None
|
|
103
|
+
if output_dtype != base_dtype:
|
|
104
|
+
raise UnsupportedOpError(
|
|
105
|
+
"Pow expects output dtype "
|
|
106
|
+
f"{base_dtype.onnx_name}, got {output_dtype.onnx_name}"
|
|
107
|
+
)
|
|
108
|
+
return None
|
|
109
|
+
|
|
110
|
+
def infer_types(self, ctx: OpContext) -> None:
|
|
111
|
+
base_dtype = ctx.dtype(self.input0)
|
|
112
|
+
exponent_dtype = ctx.dtype(self.input1)
|
|
113
|
+
if base_dtype not in _POW_BASE_DTYPES:
|
|
114
|
+
raise UnsupportedOpError(
|
|
115
|
+
"Pow base dtype must be one of "
|
|
116
|
+
f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_BASE_DTYPES, key=str))}, "
|
|
117
|
+
f"got {base_dtype.onnx_name}"
|
|
118
|
+
)
|
|
119
|
+
if exponent_dtype not in _POW_EXPONENT_DTYPES:
|
|
120
|
+
raise UnsupportedOpError(
|
|
121
|
+
"Pow exponent dtype must be one of "
|
|
122
|
+
f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_EXPONENT_DTYPES, key=str))}, "
|
|
123
|
+
f"got {exponent_dtype.onnx_name}"
|
|
124
|
+
)
|
|
125
|
+
try:
|
|
126
|
+
output_dtype = ctx.dtype(self.output)
|
|
127
|
+
except ShapeInferenceError:
|
|
128
|
+
ctx.set_dtype(self.output, base_dtype)
|
|
129
|
+
return None
|
|
130
|
+
if output_dtype != base_dtype:
|
|
131
|
+
raise UnsupportedOpError(
|
|
132
|
+
"Pow expects output dtype "
|
|
133
|
+
f"{base_dtype.onnx_name}, got {output_dtype.onnx_name}"
|
|
134
|
+
)
|
|
135
|
+
return None
|
|
136
|
+
|
|
35
137
|
|
|
36
138
|
@dataclass(frozen=True)
|
|
37
139
|
class VariadicOp(VariadicLikeOpBase):
|
|
@@ -80,11 +182,6 @@ class WhereOp(ElementwiseOpBase):
|
|
|
80
182
|
input_x: str
|
|
81
183
|
input_y: str
|
|
82
184
|
output: str
|
|
83
|
-
condition_shape: tuple[int, ...]
|
|
84
|
-
x_shape: tuple[int, ...]
|
|
85
|
-
y_shape: tuple[int, ...]
|
|
86
|
-
output_shape: tuple[int, ...]
|
|
87
|
-
dtype: ScalarType
|
|
88
185
|
|
|
89
186
|
def _elementwise_inputs(self) -> tuple[str, ...]:
|
|
90
187
|
return (self.condition, self.input_x, self.input_y)
|
|
@@ -101,9 +198,6 @@ class UnaryOp(ElementwiseOpBase):
|
|
|
101
198
|
input0: str
|
|
102
199
|
output: str
|
|
103
200
|
function: ScalarFunction
|
|
104
|
-
shape: tuple[int, ...]
|
|
105
|
-
dtype: ScalarType
|
|
106
|
-
input_dtype: ScalarType
|
|
107
201
|
params: tuple[float, ...] = ()
|
|
108
202
|
|
|
109
203
|
def _elementwise_inputs(self) -> tuple[str, ...]:
|
|
@@ -126,11 +220,8 @@ class ClipOp(ElementwiseOpBase):
|
|
|
126
220
|
input_min: str | None
|
|
127
221
|
input_max: str | None
|
|
128
222
|
output: str
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
max_shape: tuple[int, ...] | None
|
|
132
|
-
output_shape: tuple[int, ...]
|
|
133
|
-
dtype: ScalarType
|
|
223
|
+
min_value: float | None = None
|
|
224
|
+
max_value: float | None = None
|
|
134
225
|
|
|
135
226
|
def _elementwise_inputs(self) -> tuple[str, ...]:
|
|
136
227
|
inputs = [self.input0]
|
|
@@ -152,9 +243,6 @@ class ClipOp(ElementwiseOpBase):
|
|
|
152
243
|
class IdentityOp(ElementwiseOpBase):
|
|
153
244
|
input0: str
|
|
154
245
|
output: str
|
|
155
|
-
shape: tuple[int, ...]
|
|
156
|
-
dtype: ScalarType
|
|
157
|
-
input_dtype: ScalarType
|
|
158
246
|
|
|
159
247
|
def _elementwise_inputs(self) -> tuple[str, ...]:
|
|
160
248
|
return (self.input0,)
|
|
@@ -165,3 +253,31 @@ class IdentityOp(ElementwiseOpBase):
|
|
|
165
253
|
def validate(self, ctx: OpContext) -> None:
|
|
166
254
|
super().validate(ctx)
|
|
167
255
|
return None
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@dataclass(frozen=True)
|
|
259
|
+
class QLinearMulOp(RenderableOpBase):
|
|
260
|
+
input0: str
|
|
261
|
+
input0_scale: str
|
|
262
|
+
input0_zero_point: str
|
|
263
|
+
input1: str
|
|
264
|
+
input1_scale: str
|
|
265
|
+
input1_zero_point: str
|
|
266
|
+
output_scale: str
|
|
267
|
+
output_zero_point: str
|
|
268
|
+
output: str
|
|
269
|
+
input0_shape: tuple[int, ...]
|
|
270
|
+
input1_shape: tuple[int, ...]
|
|
271
|
+
output_shape: tuple[int, ...]
|
|
272
|
+
input0_dtype: ScalarType
|
|
273
|
+
input1_dtype: ScalarType
|
|
274
|
+
dtype: ScalarType
|
|
275
|
+
input0_scale_dtype: ScalarType
|
|
276
|
+
input1_scale_dtype: ScalarType
|
|
277
|
+
output_scale_dtype: ScalarType
|
|
278
|
+
input0_scale_shape: tuple[int, ...]
|
|
279
|
+
input1_scale_shape: tuple[int, ...]
|
|
280
|
+
output_scale_shape: tuple[int, ...]
|
|
281
|
+
input0_zero_shape: tuple[int, ...]
|
|
282
|
+
input1_zero_shape: tuple[int, ...]
|
|
283
|
+
output_zero_shape: tuple[int, ...]
|
emx_onnx_cgen/ir/ops/misc.py
CHANGED
|
@@ -53,6 +53,20 @@ class QuantizeLinearOp(RenderableOpBase):
|
|
|
53
53
|
input_dtype: ScalarType
|
|
54
54
|
scale_dtype: ScalarType
|
|
55
55
|
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True)
|
|
58
|
+
class DequantizeLinearOp(RenderableOpBase):
|
|
59
|
+
input0: str
|
|
60
|
+
scale: str
|
|
61
|
+
zero_point: str | None
|
|
62
|
+
output: str
|
|
63
|
+
input_shape: tuple[int, ...]
|
|
64
|
+
axis: int | None
|
|
65
|
+
block_size: int | None
|
|
66
|
+
dtype: ScalarType
|
|
67
|
+
input_dtype: ScalarType
|
|
68
|
+
scale_dtype: ScalarType
|
|
69
|
+
|
|
56
70
|
@dataclass(frozen=True)
|
|
57
71
|
class ConcatOp(RenderableOpBase):
|
|
58
72
|
inputs: tuple[str, ...]
|
|
@@ -184,6 +198,16 @@ class EyeLikeOp(RenderableOpBase):
|
|
|
184
198
|
dtype: ScalarType
|
|
185
199
|
input_dtype: ScalarType
|
|
186
200
|
|
|
201
|
+
@dataclass(frozen=True)
|
|
202
|
+
class BernoulliOp(RenderableOpBase):
|
|
203
|
+
input0: str
|
|
204
|
+
output: str
|
|
205
|
+
input_shape: tuple[int, ...]
|
|
206
|
+
output_shape: tuple[int, ...]
|
|
207
|
+
input_dtype: ScalarType
|
|
208
|
+
dtype: ScalarType
|
|
209
|
+
seed: int | None
|
|
210
|
+
|
|
187
211
|
@dataclass(frozen=True)
|
|
188
212
|
class TriluOp(RenderableOpBase):
|
|
189
213
|
input0: str
|
|
@@ -355,6 +379,51 @@ class SizeOp(RenderableOpBase):
|
|
|
355
379
|
dtype: ScalarType
|
|
356
380
|
input_dtype: ScalarType
|
|
357
381
|
|
|
382
|
+
|
|
383
|
+
@dataclass(frozen=True)
|
|
384
|
+
class OptionalHasElementOp(RenderableOpBase):
|
|
385
|
+
input0: str
|
|
386
|
+
output: str
|
|
387
|
+
|
|
388
|
+
def validate(self, ctx: OpContext) -> None:
|
|
389
|
+
value = ctx.graph.find_value(self.input0)
|
|
390
|
+
if not value.type.is_optional:
|
|
391
|
+
raise UnsupportedOpError(
|
|
392
|
+
f"{self.kind} expects optional input, got non-optional tensor."
|
|
393
|
+
)
|
|
394
|
+
try:
|
|
395
|
+
output_dtype = ctx.dtype(self.output)
|
|
396
|
+
except ShapeInferenceError:
|
|
397
|
+
return None
|
|
398
|
+
if output_dtype != ScalarType.BOOL:
|
|
399
|
+
raise UnsupportedOpError(
|
|
400
|
+
f"{self.kind} expects bool output, got {output_dtype.onnx_name}"
|
|
401
|
+
)
|
|
402
|
+
return None
|
|
403
|
+
|
|
404
|
+
def infer_types(self, ctx: OpContext) -> None:
|
|
405
|
+
ctx.dtype(self.input0)
|
|
406
|
+
try:
|
|
407
|
+
output_dtype = ctx.dtype(self.output)
|
|
408
|
+
except ShapeInferenceError:
|
|
409
|
+
ctx.set_dtype(self.output, ScalarType.BOOL)
|
|
410
|
+
return None
|
|
411
|
+
if output_dtype != ScalarType.BOOL:
|
|
412
|
+
raise UnsupportedOpError(
|
|
413
|
+
f"{self.kind} expects bool output, got {output_dtype.onnx_name}"
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
def infer_shapes(self, ctx: OpContext) -> None:
|
|
417
|
+
try:
|
|
418
|
+
output_shape = ctx.shape(self.output)
|
|
419
|
+
except ShapeInferenceError:
|
|
420
|
+
ctx.set_shape(self.output, ())
|
|
421
|
+
return None
|
|
422
|
+
if output_shape not in {(), (1,)}:
|
|
423
|
+
raise UnsupportedOpError(
|
|
424
|
+
f"{self.kind} expects scalar output, got shape {output_shape}"
|
|
425
|
+
)
|
|
426
|
+
|
|
358
427
|
@dataclass(frozen=True)
|
|
359
428
|
class NonZeroOp(RenderableOpBase):
|
|
360
429
|
input0: str
|
|
@@ -477,6 +546,15 @@ class RangeOp(RenderableOpBase):
|
|
|
477
546
|
dtype: ScalarType
|
|
478
547
|
input_dtype: ScalarType
|
|
479
548
|
|
|
549
|
+
@dataclass(frozen=True)
|
|
550
|
+
class HammingWindowOp(RenderableOpBase):
|
|
551
|
+
size: str
|
|
552
|
+
output: str
|
|
553
|
+
output_shape: tuple[int, ...]
|
|
554
|
+
periodic: bool
|
|
555
|
+
dtype: ScalarType
|
|
556
|
+
input_dtype: ScalarType
|
|
557
|
+
|
|
480
558
|
@dataclass(frozen=True)
|
|
481
559
|
class OneHotOp(RenderableOpBase):
|
|
482
560
|
indices: str
|
|
@@ -492,6 +570,23 @@ class OneHotOp(RenderableOpBase):
|
|
|
492
570
|
indices_dtype: ScalarType
|
|
493
571
|
depth_dtype: ScalarType
|
|
494
572
|
|
|
573
|
+
@dataclass(frozen=True)
|
|
574
|
+
class TfIdfVectorizerOp(RenderableOpBase):
|
|
575
|
+
input0: str
|
|
576
|
+
output: str
|
|
577
|
+
input_shape: tuple[int, ...]
|
|
578
|
+
output_shape: tuple[int, ...]
|
|
579
|
+
input_dtype: ScalarType
|
|
580
|
+
output_dtype: ScalarType
|
|
581
|
+
min_gram_length: int
|
|
582
|
+
max_gram_length: int
|
|
583
|
+
max_skip_count: int
|
|
584
|
+
mode: str
|
|
585
|
+
ngram_counts: tuple[int, ...]
|
|
586
|
+
ngram_indexes: tuple[int, ...]
|
|
587
|
+
pool_int64s: tuple[int, ...]
|
|
588
|
+
weights: tuple[float, ...] | None
|
|
589
|
+
|
|
495
590
|
@dataclass(frozen=True)
|
|
496
591
|
class SplitOp(RenderableOpBase):
|
|
497
592
|
input0: str
|