emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/ir/op_base.py
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Protocol
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from .op_context import OpContext
+
+
+class Emitter(Protocol):
+    def render_op(self, op: "OpBase", ctx: "EmitContext") -> str:
+        ...
+
+
+@dataclass(frozen=True)
+class EmitContext:
+    op_index: int
+
+
+class OpBase(ABC):
+    """Ops should not mutate themselves; store derived values in OpContext."""
+    inputs: tuple[str, ...]
+    outputs: tuple[str, ...]
+
+    def __getattr__(self, name: str) -> str:
+        if name == "kind":
+            return self.__class__.__name__
+        raise AttributeError(
+            f"'{self.__class__.__name__}' object has no attribute '{name}'"
+        )
+
+    def validate(self, ctx: OpContext) -> None:
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        return None
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        return None
+
+    @abstractmethod
+    def emit(self, emitter: Emitter, ctx: EmitContext) -> str:
+        raise NotImplementedError
+
+
+class RenderableOpBase(OpBase):
+    def emit(self, emitter: Emitter, ctx: EmitContext) -> str:
+        return emitter.render_op(self, ctx)
+
+
+class ElementwiseOpBase(RenderableOpBase):
+    """Elementwise ops should validate against OpContext and store no derived state."""
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        raise NotImplementedError
+
+    def _elementwise_output(self) -> str:
+        raise NotImplementedError
+
+    def _elementwise_condition_inputs(self) -> tuple[str, ...]:
+        return ()
+
+    def _elementwise_compare(self) -> bool:
+        return False
+
+    def _elementwise_data_inputs(self) -> tuple[str, ...]:
+        inputs = self._elementwise_inputs()
+        condition_inputs = set(self._elementwise_condition_inputs())
+        return tuple(name for name in inputs if name not in condition_inputs)
+
+    def validate(self, ctx: OpContext) -> None:
+        condition_inputs = self._elementwise_condition_inputs()
+        for name in condition_inputs:
+            dtype = ctx.dtype(name)
+            if dtype != ScalarType.BOOL:
+                raise UnsupportedOpError(
+                    f"{self.kind} expects bool condition, got {dtype.onnx_name}"
+                )
+        data_inputs = self._elementwise_data_inputs()
+        if not data_inputs:
+            return None
+        data_dtypes = tuple(ctx.dtype(name) for name in data_inputs)
+        if any(dtype != data_dtypes[0] for dtype in data_dtypes[1:]):
+            dtype_names = ", ".join(dtype.onnx_name for dtype in data_dtypes)
+            raise UnsupportedOpError(
+                f"{self.kind} expects matching input dtypes, got {dtype_names}"
+            )
+        output_dtype = ctx.dtype(self._elementwise_output())
+        if self._elementwise_compare():
+            if output_dtype != ScalarType.BOOL:
+                raise UnsupportedOpError(
+                    f"{self.kind} expects bool output, got {output_dtype.onnx_name}"
+                )
+            return None
+        if output_dtype != data_dtypes[0]:
+            raise UnsupportedOpError(
+                f"{self.kind} expects output dtype {data_dtypes[0].onnx_name}, "
+                f"got {output_dtype.onnx_name}"
+            )
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_names = self._elementwise_inputs()
+        output_name = self._elementwise_output()
+        for name in input_names:
+            ctx.dtype(name)
+        ctx.dtype(output_name)
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        input_names = self._elementwise_inputs()
+        output_name = self._elementwise_output()
+        input_shapes = tuple(ctx.shape(name) for name in input_names)
+        if len(input_shapes) == 1:
+            output_shape = input_shapes[0]
+        else:
+            output_shape = BroadcastingOpBase.broadcast_shapes(*input_shapes)
+        ctx.set_shape(output_name, output_shape)
+        return None
+
+
+class ReduceOpBase(RenderableOpBase):
+    @staticmethod
+    def normalize_axes(
+        axes: tuple[int, ...] | None, rank: int
+    ) -> tuple[int, ...]:
+        if axes is None:
+            axes = tuple(range(rank))
+        normalized: list[int] = []
+        for axis in axes:
+            if axis < 0:
+                axis += rank
+            if axis < 0 or axis >= rank:
+                raise ShapeInferenceError(
+                    f"Reduce axis {axis} is out of bounds for rank {rank}"
+                )
+            normalized.append(axis)
+        return tuple(dict.fromkeys(normalized))
+
+    @staticmethod
+    def reduced_shape(
+        input_shape: tuple[int, ...],
+        axes: tuple[int, ...] | None,
+        *,
+        keepdims: bool,
+    ) -> tuple[int, ...]:
+        rank = len(input_shape)
+        normalized_axes = ReduceOpBase.normalize_axes(axes, rank)
+        if keepdims:
+            return tuple(
+                1 if axis in normalized_axes else dim
+                for axis, dim in enumerate(input_shape)
+            )
+        return tuple(
+            dim for axis, dim in enumerate(input_shape) if axis not in normalized_axes
+        )
+
+
+class BroadcastingOpBase(RenderableOpBase):
+    @staticmethod
+    def broadcast_shapes(
+        *shapes: tuple[int, ...],
+    ) -> tuple[int, ...]:
+        if not shapes:
+            return ()
+        max_rank = max(len(shape) for shape in shapes)
+        padded_shapes = [
+            (1,) * (max_rank - len(shape)) + shape for shape in shapes
+        ]
+        result: list[int] = []
+        for dims in zip(*padded_shapes):
+            dim = max(dims)
+            if any(d not in {1, dim} for d in dims):
+                raise ShapeInferenceError(
+                    "Broadcasting mismatch for shapes: "
+                    + ", ".join(str(shape) for shape in shapes)
+                )
+            result.append(dim)
+        return tuple(result)
+
+
+class MatMulLikeOpBase(RenderableOpBase):
+    pass
+
+
+class GemmLikeOpBase(RenderableOpBase):
+    pass
+
+
+class ConvLikeOpBase(RenderableOpBase):
+    pass
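
The shape helpers on ReduceOpBase and BroadcastingOpBase are pure static methods, so their behavior follows directly from the code above. A minimal sketch of what they compute (illustrative session assuming the 0.3.2 wheel is installed; not part of the diff):

    from emx_onnx_cgen.ir.op_base import BroadcastingOpBase, ReduceOpBase

    # Shapes are right-aligned and size-1 dims stretch, NumPy/ONNX style.
    BroadcastingOpBase.broadcast_shapes((2, 1, 4), (3, 1))  # -> (2, 3, 4)

    # Negative axes wrap once; duplicates collapse with order preserved.
    ReduceOpBase.normalize_axes((-1, 0), 3)  # -> (2, 0)

    # keepdims=True keeps reduced axes as size 1; otherwise they are dropped.
    ReduceOpBase.reduced_shape((2, 3, 4), (1,), keepdims=True)   # -> (2, 1, 4)
    ReduceOpBase.reduced_shape((2, 3, 4), (1,), keepdims=False)  # -> (2, 4)

Dimensions that are neither equal nor 1 raise ShapeInferenceError, as do out-of-range axes.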
emx_onnx_cgen/ir/op_context.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+from shared.scalar_types import ScalarType
+
+from .context import GraphContext
+
+
+_MISSING = object()
+
+
+@dataclass
+class OpContext:
+    graph: GraphContext
+    _dtype_overrides: dict[str, ScalarType] = field(default_factory=dict)
+    _shape_overrides: dict[str, tuple[int, ...]] = field(default_factory=dict)
+    _derived: dict[int, dict[str, object]] = field(default_factory=dict)
+
+    def dtype(self, name: str) -> ScalarType:
+        if name in self._dtype_overrides:
+            return self._dtype_overrides[name]
+        return self.graph.dtype(name)
+
+    def shape(self, name: str) -> tuple[int, ...]:
+        if name in self._shape_overrides:
+            return self._shape_overrides[name]
+        return self.graph.shape(name)
+
+    def set_dtype(self, name: str, dtype: ScalarType) -> None:
+        self._dtype_overrides[name] = dtype
+        self.graph.set_dtype(name, dtype)
+
+    def set_shape(self, name: str, shape: tuple[int, ...]) -> None:
+        self._shape_overrides[name] = shape
+        self.graph.set_shape(name, shape)
+
+    def set_derived(self, op: object, key: str, value: object) -> None:
+        self._derived.setdefault(id(op), {})[key] = value
+
+    def get_derived(
+        self, op: object, key: str, default: object = _MISSING
+    ) -> object:
+        derived = self._derived.get(id(op), {})
+        if key in derived:
+            return derived[key]
+        if default is _MISSING:
+            return _MISSING
+        return default
+
+    def require_derived(self, op: object, key: str) -> object:
+        derived = self._derived.get(id(op), {})
+        if key in derived:
+            return derived[key]
+        raise KeyError(
+            f"Missing derived value '{key}' for op {op.__class__.__name__}"
+        )
+
+    def copy_derived(self, source_op: object, target_op: object) -> None:
+        derived = self._derived.get(id(source_op))
+        if derived:
+            self._derived[id(target_op)] = dict(derived)
+
+    def __getattr__(self, name: str):
+        return getattr(self.graph, name)
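
OpContext keys its derived-value store by id(op), which matches the rule in op_base.py that ops stay immutable and park derived state here. A minimal sketch of that side of the API (the _StubGraph below is a hypothetical stand-in for the real GraphContext added in emx_onnx_cgen/ir/context.py, which this excerpt does not show):

    from emx_onnx_cgen.ir.op_context import OpContext

    class _StubGraph:  # hypothetical; only present so OpContext has a graph to delegate to
        def dtype(self, name):
            raise KeyError(name)

        def shape(self, name):
            raise KeyError(name)

    ctx = OpContext(graph=_StubGraph())
    op = object()  # any op instance works; entries are keyed by id(op)
    ctx.set_derived(op, "axis", 2)
    ctx.require_derived(op, "axis")       # -> 2
    ctx.get_derived(op, "missing", None)  # -> None (default, since key is absent)

Note that get_derived with no default returns the _MISSING sentinel rather than raising; require_derived is the raising variant.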
emx_onnx_cgen/ir/ops/__init__.py
@@ -0,0 +1,130 @@
+from .elementwise import BinaryOp, ClipOp, IdentityOp, MultiInputBinaryOp, UnaryOp, WhereOp
+from .misc import (
+    CastOp,
+    ConcatOp,
+    ConstantOfShapeOp,
+    CumSumOp,
+    DepthToSpaceOp,
+    ExpandOp,
+    EyeLikeOp,
+    GatherElementsOp,
+    GatherNDOp,
+    GatherOp,
+    GridSampleOp,
+    NonMaxSuppressionOp,
+    NonZeroOp,
+    OneHotOp,
+    PadOp,
+    QuantizeLinearOp,
+    RangeOp,
+    ReshapeOp,
+    ResizeOp,
+    ScatterNDOp,
+    ShapeOp,
+    SizeOp,
+    SliceOp,
+    SpaceToDepthOp,
+    SplitOp,
+    TensorScatterOp,
+    TileOp,
+    TransposeOp,
+    TriluOp,
+)
+from .nn import (
+    AdagradOp,
+    AttentionOp,
+    AveragePoolOp,
+    BatchNormOp,
+    ConvOp,
+    ConvTransposeOp,
+    EinsumKind,
+    EinsumOp,
+    GemmOp,
+    GroupNormalizationOp,
+    HardmaxOp,
+    InstanceNormalizationOp,
+    LayerNormalizationOp,
+    LogSoftmaxOp,
+    LpNormalizationOp,
+    LpPoolOp,
+    LrnOp,
+    LstmOp,
+    MatMulOp,
+    MaxPoolOp,
+    MeanVarianceNormalizationOp,
+    NegativeLogLikelihoodLossOp,
+    QLinearMatMulOp,
+    RMSNormalizationOp,
+    RotaryEmbeddingOp,
+    SoftmaxCrossEntropyLossOp,
+    SoftmaxOp,
+)
+from .reduce import ArgReduceOp, ReduceOp, TopKOp
+
+__all__ = [
+    "AdagradOp",
+    "ArgReduceOp",
+    "AttentionOp",
+    "AveragePoolOp",
+    "BatchNormOp",
+    "BinaryOp",
+    "CastOp",
+    "ClipOp",
+    "ConcatOp",
+    "ConstantOfShapeOp",
+    "ConvOp",
+    "ConvTransposeOp",
+    "CumSumOp",
+    "DepthToSpaceOp",
+    "EinsumKind",
+    "EinsumOp",
+    "ExpandOp",
+    "EyeLikeOp",
+    "GatherElementsOp",
+    "GatherNDOp",
+    "GatherOp",
+    "GemmOp",
+    "GridSampleOp",
+    "GroupNormalizationOp",
+    "HardmaxOp",
+    "IdentityOp",
+    "InstanceNormalizationOp",
+    "LayerNormalizationOp",
+    "LogSoftmaxOp",
+    "LpNormalizationOp",
+    "LpPoolOp",
+    "LrnOp",
+    "LstmOp",
+    "MatMulOp",
+    "MaxPoolOp",
+    "MeanVarianceNormalizationOp",
+    "MultiInputBinaryOp",
+    "NegativeLogLikelihoodLossOp",
+    "NonMaxSuppressionOp",
+    "NonZeroOp",
+    "OneHotOp",
+    "PadOp",
+    "QuantizeLinearOp",
+    "QLinearMatMulOp",
+    "RangeOp",
+    "ReduceOp",
+    "ReshapeOp",
+    "ResizeOp",
+    "RMSNormalizationOp",
+    "RotaryEmbeddingOp",
+    "ScatterNDOp",
+    "ShapeOp",
+    "SizeOp",
+    "SliceOp",
+    "SoftmaxCrossEntropyLossOp",
+    "SoftmaxOp",
+    "SpaceToDepthOp",
+    "SplitOp",
+    "TensorScatterOp",
+    "TileOp",
+    "TopKOp",
+    "TransposeOp",
+    "TriluOp",
+    "UnaryOp",
+    "WhereOp",
+]
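
With these re-exports, lowering and codegen code can import every op dataclass from a single namespace rather than the per-module paths, e.g. (illustrative):

    from emx_onnx_cgen.ir.ops import BinaryOp, ReduceOp, WhereOp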
emx_onnx_cgen/ir/ops/elementwise.py
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_functions import ScalarFunction
+from shared.scalar_types import ScalarType
+
+from ...ops import COMPARE_FUNCTIONS, OperatorKind
+from ..op_base import ElementwiseOpBase
+from ..op_context import OpContext
+
+
+@dataclass(frozen=True)
+class BinaryOp(ElementwiseOpBase):
+    input0: str
+    input1: str
+    output: str
+    function: ScalarFunction
+    operator_kind: OperatorKind
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.input0, self.input1)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def _elementwise_compare(self) -> bool:
+        return self.function in COMPARE_FUNCTIONS
+
+
+@dataclass(frozen=True)
+class MultiInputBinaryOp(ElementwiseOpBase):
+    inputs: tuple[str, ...]
+    output: str
+    function: ScalarFunction
+    operator_kind: OperatorKind
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return self.inputs
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def _elementwise_compare(self) -> bool:
+        return self.function in COMPARE_FUNCTIONS
+
+
+@dataclass(frozen=True)
+class WhereOp(ElementwiseOpBase):
+    condition: str
+    input_x: str
+    input_y: str
+    output: str
+    condition_shape: tuple[int, ...]
+    x_shape: tuple[int, ...]
+    y_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.condition, self.input_x, self.input_y)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def _elementwise_condition_inputs(self) -> tuple[str, ...]:
+        return (self.condition,)
+
+
+@dataclass(frozen=True)
+class UnaryOp(ElementwiseOpBase):
+    input0: str
+    output: str
+    function: ScalarFunction
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+    params: tuple[float, ...] = ()
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.input0,)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def validate(self, ctx: OpContext) -> None:
+        super().validate(ctx)
+        return None
+
+    def _elementwise_compare(self) -> bool:
+        return self.function in {ScalarFunction.ISINF, ScalarFunction.ISNAN}
+
+
+@dataclass(frozen=True)
+class ClipOp(ElementwiseOpBase):
+    input0: str
+    input_min: str | None
+    input_max: str | None
+    output: str
+    input_shape: tuple[int, ...]
+    min_shape: tuple[int, ...] | None
+    max_shape: tuple[int, ...] | None
+    output_shape: tuple[int, ...]
+    dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        inputs = [self.input0]
+        if self.input_min is not None:
+            inputs.append(self.input_min)
+        if self.input_max is not None:
+            inputs.append(self.input_max)
+        return tuple(inputs)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def validate(self, ctx: OpContext) -> None:
+        super().validate(ctx)
+        return None
+
+
+@dataclass(frozen=True)
+class IdentityOp(ElementwiseOpBase):
+    input0: str
+    output: str
+    shape: tuple[int, ...]
+    dtype: ScalarType
+    input_dtype: ScalarType
+
+    def _elementwise_inputs(self) -> tuple[str, ...]:
+        return (self.input0,)
+
+    def _elementwise_output(self) -> str:
+        return self.output
+
+    def validate(self, ctx: OpContext) -> None:
+        super().validate(ctx)
+        return None
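
All of these dataclasses inherit their validate/infer hooks from ElementwiseOpBase, so dtype checks are shared rather than repeated per op. A minimal sketch of the condition check on WhereOp (the _Ctx and _NotBool classes below are hypothetical stand-ins; a real OpContext and a second ScalarType member would normally fill those roles):

    from shared.scalar_types import ScalarType
    from emx_onnx_cgen.ir.ops.elementwise import WhereOp

    class _Ctx:  # hypothetical stand-in for OpContext; validate only calls .dtype()
        def __init__(self, dtypes):
            self._dtypes = dtypes

        def dtype(self, name):
            return self._dtypes[name]

    class _NotBool:  # hypothetical dtype; validate only reads .onnx_name on failure
        onnx_name = "float32"

    op = WhereOp(
        condition="c", input_x="x", input_y="y", output="out",
        condition_shape=(2, 3), x_shape=(2, 3), y_shape=(3,),
        output_shape=(2, 3), dtype=ScalarType.BOOL,
    )
    ctx = _Ctx({"c": _NotBool(), "x": ScalarType.BOOL,
                "y": ScalarType.BOOL, "out": ScalarType.BOOL})
    op.validate(ctx)  # raises UnsupportedOpError: "WhereOp expects bool condition, got float32"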