emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +50 -23
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +30 -387
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +36 -18
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +1 -1
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +1 -1
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +1 -1
- emx_onnx_cgen/lowering/gather_nd.py +1 -1
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +1 -1
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +1 -1
- emx_onnx_cgen/lowering/identity.py +1 -1
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +1 -1
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +1 -1
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +1 -1
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +1 -1
- emx_onnx_cgen/lowering/one_hot.py +1 -1
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +1 -1
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +1 -1
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +1 -1
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +1 -1
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +1 -1
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +1 -1
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +25 -7
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +1 -1
- emx_onnx_cgen/lowering/unsqueeze.py +1 -1
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/runtime/evaluator.py +325 -1
- emx_onnx_cgen/verification.py +9 -39
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
- emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +11 -0
- shared/ulp.py +17 -0
- emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/ir/op_base.py (new file)
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Protocol
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from .op_context import OpContext
+
+
+class Emitter(Protocol):
+    def render_op(self, op: "OpBase", ctx: "EmitContext") -> str:
+        ...
+
+
+@dataclass(frozen=True)
+class EmitContext:
+    op_index: int
+
+
+class OpBase(ABC):
+    """Ops should not mutate themselves; store derived values in OpContext."""
+    inputs: tuple[str, ...]
+    outputs: tuple[str, ...]
+
+    def __getattr__(self, name: str) -> str:
+        if name == "kind":
+            return self.__class__.__name__
+        raise AttributeError(
+            f"'{self.__class__.__name__}' object has no attribute '{name}'"
+        )
+
+    def validate(self, ctx: OpContext) -> None:
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        return None
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        return None
+
+    @abstractmethod
+    def emit(self, emitter: Emitter, ctx: EmitContext) -> str:
+        raise NotImplementedError
+
+
+class RenderableOpBase(OpBase):
+    def emit(self, emitter: Emitter, ctx: EmitContext) -> str:
+        return emitter.render_op(self, ctx)
+
+
+class ElementwiseOpBase(RenderableOpBase):
+    """Elementwise ops should validate against OpContext and store no derived state."""
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        raise NotImplementedError
+
+    def _elementwise_output(self) -> str:
+        raise NotImplementedError
+
+    def _elementwise_condition_inputs(self) -> tuple[str, ...]:
+        return ()
+
+    def _elementwise_compare(self) -> bool:
+        return False
+
+    def _elementwise_data_inputs(self) -> tuple[str, ...]:
+        inputs = self._elementwise_inputs()
+        condition_inputs = set(self._elementwise_condition_inputs())
+        return tuple(name for name in inputs if name not in condition_inputs)
+
+    def validate(self, ctx: OpContext) -> None:
+        condition_inputs = self._elementwise_condition_inputs()
+        for name in condition_inputs:
+            dtype = ctx.dtype(name)
+            if dtype != ScalarType.BOOL:
+                raise UnsupportedOpError(
+                    f"{self.kind} expects bool condition, got {dtype.onnx_name}"
+                )
+        data_inputs = self._elementwise_data_inputs()
+        if not data_inputs:
+            return None
+        data_dtypes = tuple(ctx.dtype(name) for name in data_inputs)
+        if any(dtype != data_dtypes[0] for dtype in data_dtypes[1:]):
+            dtype_names = ", ".join(dtype.onnx_name for dtype in data_dtypes)
+            raise UnsupportedOpError(
+                f"{self.kind} expects matching input dtypes, got {dtype_names}"
+            )
+        output_dtype = ctx.dtype(self._elementwise_output())
+        if self._elementwise_compare():
+            if output_dtype != ScalarType.BOOL:
+                raise UnsupportedOpError(
+                    f"{self.kind} expects bool output, got {output_dtype.onnx_name}"
+                )
+            return None
+        if output_dtype != data_dtypes[0]:
+            raise UnsupportedOpError(
+                f"{self.kind} expects output dtype {data_dtypes[0].onnx_name}, "
+                f"got {output_dtype.onnx_name}"
+            )
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_names = self._elementwise_inputs()
+        output_name = self._elementwise_output()
+        for name in input_names:
+            ctx.dtype(name)
+        ctx.dtype(output_name)
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        input_names = self._elementwise_inputs()
+        output_name = self._elementwise_output()
+        input_shapes = tuple(ctx.shape(name) for name in input_names)
+        if len(input_shapes) == 1:
+            output_shape = input_shapes[0]
+        else:
+            output_shape = BroadcastingOpBase.broadcast_shapes(*input_shapes)
+        ctx.set_shape(output_name, output_shape)
+        return None
+
+
+class ReduceOpBase(RenderableOpBase):
+    @staticmethod
+    def normalize_axes(
+        axes: tuple[int, ...] | None, rank: int
+    ) -> tuple[int, ...]:
+        if axes is None:
+            axes = tuple(range(rank))
+        normalized: list[int] = []
+        for axis in axes:
+            if axis < 0:
+                axis += rank
+            if axis < 0 or axis >= rank:
+                raise ShapeInferenceError(
+                    f"Reduce axis {axis} is out of bounds for rank {rank}"
+                )
+            normalized.append(axis)
+        return tuple(dict.fromkeys(normalized))
+
+    @staticmethod
+    def reduced_shape(
+        input_shape: tuple[int, ...],
+        axes: tuple[int, ...] | None,
+        *,
+        keepdims: bool,
+    ) -> tuple[int, ...]:
+        rank = len(input_shape)
+        normalized_axes = ReduceOpBase.normalize_axes(axes, rank)
+        if keepdims:
+            return tuple(
+                1 if axis in normalized_axes else dim
+                for axis, dim in enumerate(input_shape)
+            )
+        return tuple(
+            dim for axis, dim in enumerate(input_shape) if axis not in normalized_axes
+        )
+
+
+class BroadcastingOpBase(RenderableOpBase):
+    @staticmethod
+    def broadcast_shapes(
+        *shapes: tuple[int, ...],
+    ) -> tuple[int, ...]:
+        if not shapes:
+            return ()
+        max_rank = max(len(shape) for shape in shapes)
+        padded_shapes = [
+            (1,) * (max_rank - len(shape)) + shape for shape in shapes
+        ]
+        result: list[int] = []
+        for dims in zip(*padded_shapes):
+            dim = max(dims)
+            if any(d not in {1, dim} for d in dims):
+                raise ShapeInferenceError(
+                    "Broadcasting mismatch for shapes: "
+                    + ", ".join(str(shape) for shape in shapes)
+                )
+            result.append(dim)
+        return tuple(result)
+
+
+class MatMulLikeOpBase(RenderableOpBase):
+    pass
+
+
+class GemmLikeOpBase(RenderableOpBase):
+    pass
+
+
+class ConvLikeOpBase(RenderableOpBase):
+    pass
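The static helpers above (ReduceOpBase.normalize_axes / reduced_shape and BroadcastingOpBase.broadcast_shapes) carry the shape logic shared by the new op classes. A minimal sketch of what they compute, assuming the 0.3.2 wheel and its bundled shared modules are importable:

```python
# Sketch only: worked examples for the shape helpers defined in op_base.py.
from emx_onnx_cgen.ir.op_base import BroadcastingOpBase, ReduceOpBase

# Negative axes are normalized against the rank; duplicates are dropped,
# first-seen order is preserved.
assert ReduceOpBase.normalize_axes((-1, 0, -1), rank=3) == (2, 0)

# keepdims keeps reduced axes as size-1 dimensions instead of dropping them.
assert ReduceOpBase.reduced_shape((2, 3, 4), (1,), keepdims=False) == (2, 4)
assert ReduceOpBase.reduced_shape((2, 3, 4), (1,), keepdims=True) == (2, 1, 4)

# NumPy-style broadcasting: shapes are right-aligned and size-1 dims stretch;
# a mismatch raises ShapeInferenceError.
assert BroadcastingOpBase.broadcast_shapes((2, 1, 4), (3, 4)) == (2, 3, 4)
```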
emx_onnx_cgen/ir/op_context.py (new file)
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+from shared.scalar_types import ScalarType
+
+from .context import GraphContext
+
+
+_MISSING = object()
+
+
+@dataclass
+class OpContext:
+    graph: GraphContext
+    _dtype_overrides: dict[str, ScalarType] = field(default_factory=dict)
+    _shape_overrides: dict[str, tuple[int, ...]] = field(default_factory=dict)
+    _derived: dict[int, dict[str, object]] = field(default_factory=dict)
+
+    def dtype(self, name: str) -> ScalarType:
+        if name in self._dtype_overrides:
+            return self._dtype_overrides[name]
+        return self.graph.dtype(name)
+
+    def shape(self, name: str) -> tuple[int, ...]:
+        if name in self._shape_overrides:
+            return self._shape_overrides[name]
+        return self.graph.shape(name)
+
+    def set_dtype(self, name: str, dtype: ScalarType) -> None:
+        self._dtype_overrides[name] = dtype
+        self.graph.set_dtype(name, dtype)
+
+    def set_shape(self, name: str, shape: tuple[int, ...]) -> None:
+        self._shape_overrides[name] = shape
+        self.graph.set_shape(name, shape)
+
+    def set_derived(self, op: object, key: str, value: object) -> None:
+        self._derived.setdefault(id(op), {})[key] = value
+
+    def get_derived(
+        self, op: object, key: str, default: object = _MISSING
+    ) -> object:
+        derived = self._derived.get(id(op), {})
+        if key in derived:
+            return derived[key]
+        if default is _MISSING:
+            return _MISSING
+        return default
+
+    def require_derived(self, op: object, key: str) -> object:
+        derived = self._derived.get(id(op), {})
+        if key in derived:
+            return derived[key]
+        raise KeyError(
+            f"Missing derived value '{key}' for op {op.__class__.__name__}"
+        )
+
+    def copy_derived(self, source_op: object, target_op: object) -> None:
+        derived = self._derived.get(id(source_op))
+        if derived:
+            self._derived[id(target_op)] = dict(derived)
+
+    def __getattr__(self, name: str):
+        return getattr(self.graph, name)
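OpContext layers per-pass dtype/shape overrides and an id()-keyed derived-value store on top of GraphContext, so the op dataclasses can stay frozen. A hedged usage sketch: FakeGraph is a hypothetical stand-in for the real GraphContext (emx_onnx_cgen/ir/context.py), and ScalarType.BOOL is used only because it is the one member visible in this diff.

```python
# Sketch only: exercising OpContext with a dict-backed stand-in graph.
from shared.scalar_types import ScalarType
from emx_onnx_cgen.ir.op_context import OpContext


class FakeGraph:
    """Hypothetical stand-in; only the methods OpContext delegates to here."""

    def __init__(self) -> None:
        self._dtypes: dict[str, ScalarType] = {}
        self._shapes: dict[str, tuple[int, ...]] = {}

    def dtype(self, name): return self._dtypes[name]
    def shape(self, name): return self._shapes[name]
    def set_dtype(self, name, dtype): self._dtypes[name] = dtype
    def set_shape(self, name, shape): self._shapes[name] = shape


ctx = OpContext(graph=FakeGraph())

# Overrides are recorded locally and pushed through to the graph.
ctx.set_dtype("x", ScalarType.BOOL)
ctx.set_shape("x", (2, 3))
assert ctx.shape("x") == (2, 3)

# Derived values are keyed by id(op), so the op object itself is never mutated.
op = object()
ctx.set_derived(op, "axis", 1)
assert ctx.require_derived(op, "axis") == 1
assert ctx.get_derived(op, "missing", default=None) is None
```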
emx_onnx_cgen/ir/ops/__init__.py (new file)
@@ -0,0 +1,130 @@
+from .elementwise import BinaryOp, ClipOp, IdentityOp, MultiInputBinaryOp, UnaryOp, WhereOp
+from .misc import (
+    CastOp,
+    ConcatOp,
+    ConstantOfShapeOp,
+    CumSumOp,
+    DepthToSpaceOp,
+    ExpandOp,
+    EyeLikeOp,
+    GatherElementsOp,
+    GatherNDOp,
+    GatherOp,
+    GridSampleOp,
+    NonMaxSuppressionOp,
+    NonZeroOp,
+    OneHotOp,
+    PadOp,
+    QuantizeLinearOp,
+    RangeOp,
+    ReshapeOp,
+    ResizeOp,
+    ScatterNDOp,
+    ShapeOp,
+    SizeOp,
+    SliceOp,
+    SpaceToDepthOp,
+    SplitOp,
+    TensorScatterOp,
+    TileOp,
+    TransposeOp,
+    TriluOp,
+)
+from .nn import (
+    AdagradOp,
+    AttentionOp,
+    AveragePoolOp,
+    BatchNormOp,
+    ConvOp,
+    ConvTransposeOp,
+    EinsumKind,
+    EinsumOp,
+    GemmOp,
+    GroupNormalizationOp,
+    HardmaxOp,
+    InstanceNormalizationOp,
+    LayerNormalizationOp,
+    LogSoftmaxOp,
+    LpNormalizationOp,
+    LpPoolOp,
+    LrnOp,
+    LstmOp,
+    MatMulOp,
+    MaxPoolOp,
+    MeanVarianceNormalizationOp,
+    NegativeLogLikelihoodLossOp,
+    QLinearMatMulOp,
+    RMSNormalizationOp,
+    RotaryEmbeddingOp,
+    SoftmaxCrossEntropyLossOp,
+    SoftmaxOp,
+)
+from .reduce import ArgReduceOp, ReduceOp, TopKOp
+
+__all__ = [
+    "AdagradOp",
+    "ArgReduceOp",
+    "AttentionOp",
+    "AveragePoolOp",
+    "BatchNormOp",
+    "BinaryOp",
+    "CastOp",
+    "ClipOp",
+    "ConcatOp",
+    "ConstantOfShapeOp",
+    "ConvOp",
+    "ConvTransposeOp",
+    "CumSumOp",
+    "DepthToSpaceOp",
+    "EinsumKind",
+    "EinsumOp",
+    "ExpandOp",
+    "EyeLikeOp",
+    "GatherElementsOp",
+    "GatherNDOp",
+    "GatherOp",
+    "GemmOp",
+    "GridSampleOp",
+    "GroupNormalizationOp",
+    "HardmaxOp",
+    "IdentityOp",
+    "InstanceNormalizationOp",
+    "LayerNormalizationOp",
+    "LogSoftmaxOp",
+    "LpNormalizationOp",
+    "LpPoolOp",
+    "LrnOp",
+    "LstmOp",
+    "MatMulOp",
+    "MaxPoolOp",
+    "MeanVarianceNormalizationOp",
+    "MultiInputBinaryOp",
+    "NegativeLogLikelihoodLossOp",
+    "NonMaxSuppressionOp",
+    "NonZeroOp",
+    "OneHotOp",
+    "PadOp",
+    "QuantizeLinearOp",
+    "QLinearMatMulOp",
+    "RangeOp",
+    "ReduceOp",
+    "ReshapeOp",
+    "ResizeOp",
+    "RMSNormalizationOp",
+    "RotaryEmbeddingOp",
+    "ScatterNDOp",
+    "ShapeOp",
+    "SizeOp",
+    "SliceOp",
+    "SoftmaxCrossEntropyLossOp",
+    "SoftmaxOp",
+    "SpaceToDepthOp",
+    "SplitOp",
+    "TensorScatterOp",
+    "TileOp",
+    "TopKOp",
+    "TransposeOp",
+    "TriluOp",
+    "UnaryOp",
+    "WhereOp",
+]
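This __init__ re-exports every op dataclass behind a single namespace, so downstream code (for example the lowering modules) does not need to know which submodule defines each op. A small sketch, assuming the wheel is installed:

```python
# Sketch only: the aggregated namespace and the defining submodule expose the same class.
from emx_onnx_cgen.ir.ops import GemmOp, ReduceOp, WhereOp   # re-exported here
from emx_onnx_cgen.ir.ops.nn import GemmOp as GemmOpFromNN   # defining submodule

assert GemmOp is GemmOpFromNN
```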
emx_onnx_cgen/ir/ops/elementwise.py (new file)
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_functions import ScalarFunction
+from shared.scalar_types import ScalarType
+
+from ...ops import COMPARE_FUNCTIONS, OperatorKind
+from ..op_base import ElementwiseOpBase
+from ..op_context import OpContext
+
+
+@dataclass(frozen=True)
+class BinaryOp(ElementwiseOpBase):
+    input0: str
+    input1: str
+    output: str
+    function: ScalarFunction
+    operator_kind: OperatorKind
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.input0, self.input1)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def _elementwise_compare(self) -> bool:
+        return self.function in COMPARE_FUNCTIONS
+
+
+@dataclass(frozen=True)
+class MultiInputBinaryOp(ElementwiseOpBase):
+    inputs: tuple[str, ...]
+    output: str
+    function: ScalarFunction
+    operator_kind: OperatorKind
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return self.inputs
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def _elementwise_compare(self) -> bool:
+        return self.function in COMPARE_FUNCTIONS
+
+
+@dataclass(frozen=True)
+class WhereOp(ElementwiseOpBase):
+    condition: str
+    input_x: str
+    input_y: str
+    output: str
+    condition_shape: tuple[int, ...]
+    x_shape: tuple[int, ...]
+    y_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.condition, self.input_x, self.input_y)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def _elementwise_condition_inputs(self) -> tuple[str, ...]:
+        return (self.condition,)
+
+
+@dataclass(frozen=True)
+class UnaryOp(ElementwiseOpBase):
+    input0: str
+    output: str
+    function: ScalarFunction
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+    params: tuple[float, ...] = ()
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.input0,)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def validate(self, ctx: OpContext) -> None:
+        super().validate(ctx)
+        return None
+
+    def _elementwise_compare(self) -> bool:
+        return self.function in {ScalarFunction.ISINF, ScalarFunction.ISNAN}
+
+
+@dataclass(frozen=True)
+class ClipOp(ElementwiseOpBase):
+    input0: str
+    input_min: str | None
+    input_max: str | None
+    output: str
+    input_shape: tuple[int, ...]
+    min_shape: tuple[int, ...] | None
+    max_shape: tuple[int, ...] | None
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        inputs = [self.input0]
+        if self.input_min is not None:
+            inputs.append(self.input_min)
+        if self.input_max is not None:
+            inputs.append(self.input_max)
+        return tuple(inputs)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def validate(self, ctx: OpContext) -> None:
+        super().validate(ctx)
+        return None
+
+
+@dataclass(frozen=True)
+class IdentityOp(ElementwiseOpBase):
+    input0: str
+    output: str
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.input0,)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def validate(self, ctx: OpContext) -> None:
+        super().validate(ctx)
+        return None