onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,281 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Arithmetic and element-wise operators."""
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ import onnx
7
+ import torch
8
+
9
+ from ..op_registry import register
10
+ from ..utils.attributes import get_attribute
11
+ from ..utils.op_helpers import binary_op, get_optional_input, unary_op
12
+
13
+ if TYPE_CHECKING:
14
+ from ..graph_builder import GraphBuilder
15
+
16
+
17
+ # =============================================================================
18
+ # Binary arithmetic operators
19
+ # =============================================================================
20
+
21
+
22
register("Add")(binary_op(torch.add, "Element-wise addition."))
register("Sub")(binary_op(torch.sub, "Element-wise subtraction."))
register("Mul")(binary_op(torch.mul, "Element-wise multiplication."))


def _onnx_pow(base: torch.Tensor, exponent: torch.Tensor) -> torch.Tensor:
    """ONNX-compatible power.

    ONNX Pow (opset 12+) allows the exponent to have a different type from the
    base, but the output always has the base's type. ``torch.pow`` instead
    applies type promotion (e.g. int base ** float exponent -> float,
    float16 base ** float32 exponent -> float32), so the result is cast back
    to ``base.dtype``.
    """
    result = torch.pow(base, exponent)
    if result.dtype != base.dtype:
        result = result.to(base.dtype)
    return result


register("Pow")(binary_op(_onnx_pow, "Element-wise power."))
26
+
27
+
28
+ def _onnx_div(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
29
+ """ONNX-compatible division.
30
+
31
+ For integer types, ONNX Div truncates toward zero (like C integer division),
32
+ and the result must have the same integer dtype as the inputs.
33
+ For floating-point types, it performs standard division.
34
+ """
35
+ if not lhs.dtype.is_floating_point and not lhs.dtype.is_complex:
36
+ # Integer types: use truncation toward zero
37
+ return torch.div(lhs, rhs, rounding_mode="trunc")
38
+ return torch.div(lhs, rhs)
39
+
40
+
41
@register("Div")
def div(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit an ONNX-compatible element-wise division node.

    Delegates to ``_onnx_div`` so integer inputs truncate toward zero.
    """
    operands = (
        builder.get_value(node.input[0]),
        builder.get_value(node.input[1]),
    )
    return builder.call_function(_onnx_div, args=operands)
47
+
48
+
49
@register("Mod")
def mod(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit an element-wise modulo node.

    A truthy ``fmod`` attribute selects ``torch.fmod``. Otherwise (the default
    of 0) ``torch.remainder`` is used.
    """
    lhs = builder.get_value(node.input[0])
    rhs = builder.get_value(node.input[1])
    op = torch.fmod if get_attribute(node, "fmod", 0) else torch.remainder
    return builder.call_function(op, args=(lhs, rhs))
58
+
59
+
60
@register("Min")
def min_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit an element-wise minimum over all inputs.

    The inputs are folded left to right with pairwise ``torch.minimum`` calls.
    With a single input, that input's value is returned unchanged.
    """
    acc = builder.get_value(node.input[0])
    for name in node.input[1:]:
        acc = builder.call_function(
            torch.minimum, args=(acc, builder.get_value(name))
        )
    return acc
68
+
69
+
70
@register("Max")
def max_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit an element-wise maximum over all inputs.

    The inputs are folded left to right with pairwise ``torch.maximum`` calls.
    With a single input, that input's value is returned unchanged.
    """
    acc = builder.get_value(node.input[0])
    for name in node.input[1:]:
        acc = builder.call_function(
            torch.maximum, args=(acc, builder.get_value(name))
        )
    return acc
78
+
79
+
80
def _onnx_mean(*tensors: torch.Tensor) -> torch.Tensor:
    """ONNX-compatible element-wise mean of ``tensors``.

    ONNX Mean (opset 8+) accepts inputs with different but
    broadcast-compatible shapes. ``torch.stack`` requires identical shapes,
    so the inputs are first broadcast to a common shape, then stacked and
    averaged along the new leading dimension.
    """
    stacked = torch.stack(torch.broadcast_tensors(*tensors))
    return torch.mean(stacked, dim=0)


@register("Mean")
def mean(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise mean of inputs, with ONNX multidirectional broadcasting."""
    inputs = [builder.get_value(name) for name in node.input]
    return builder.call_function(_onnx_mean, args=tuple(inputs))
87
+
88
+
89
@register("Sum")
def sum_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit an element-wise sum over all inputs.

    The inputs are folded left to right with pairwise ``torch.add`` calls.
    With a single input, that input's value is returned unchanged.
    """
    total = builder.get_value(node.input[0])
    for name in node.input[1:]:
        total = builder.call_function(torch.add, args=(total, builder.get_value(name)))
    return total
97
+
98
+
99
+ # =============================================================================
100
+ # Unary arithmetic operators
101
+ # =============================================================================
102
+
103
+
104
# Each entry: (ONNX op type, torch implementation, handler docstring).
for _op_type, _torch_fn, _doc in (
    ("Neg", torch.neg, "Element-wise negation."),
    ("Abs", torch.abs, "Element-wise absolute value."),
    ("Sign", torch.sign, "Element-wise sign."),
    ("Ceil", torch.ceil, "Element-wise ceiling."),
    ("Floor", torch.floor, "Element-wise floor."),
    ("Round", torch.round, "Element-wise rounding."),
    ("Reciprocal", torch.reciprocal, "Element-wise reciprocal."),
    ("Sqrt", torch.sqrt, "Element-wise square root."),
    ("Exp", torch.exp, "Element-wise exponential."),
    ("Log", torch.log, "Element-wise natural logarithm."),
):
    register(_op_type)(unary_op(_torch_fn, _doc))
# Keep the loop variables out of the module namespace.
del _op_type, _torch_fn, _doc
114
+
115
+
116
+ # =============================================================================
117
+ # Comparison operators
118
+ # =============================================================================
119
+
120
+
121
# Each entry: (ONNX op type, handler factory, torch implementation, docstring).
# "Not" is the only unary op in this group, so it uses unary_op.
for _op_type, _make_handler, _torch_fn, _doc in (
    ("Equal", binary_op, torch.eq, "Element-wise equality comparison."),
    ("Greater", binary_op, torch.gt, "Element-wise greater-than comparison."),
    (
        "GreaterOrEqual",
        binary_op,
        torch.ge,
        "Element-wise greater-than-or-equal comparison.",
    ),
    ("Less", binary_op, torch.lt, "Element-wise less-than comparison."),
    (
        "LessOrEqual",
        binary_op,
        torch.le,
        "Element-wise less-than-or-equal comparison.",
    ),
    ("And", binary_op, torch.logical_and, "Element-wise logical AND."),
    ("Or", binary_op, torch.logical_or, "Element-wise logical OR."),
    ("Xor", binary_op, torch.logical_xor, "Element-wise logical XOR."),
    ("Not", unary_op, torch.logical_not, "Element-wise logical NOT."),
):
    register(_op_type)(_make_handler(_torch_fn, _doc))
# Keep the loop variables out of the module namespace.
del _op_type, _make_handler, _torch_fn, _doc
134
+
135
+
136
@register("Where")
def where(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit ``torch.where(condition, x, y)`` for an ONNX Where node.

    Inputs 0, 1 and 2 are the condition, the value used where the condition
    is true, and the value used where it is false.
    """
    condition, x, y = (builder.get_value(node.input[i]) for i in range(3))
    return builder.call_function(torch.where, args=(condition, x, y))
143
+
144
+
145
+ # =============================================================================
146
+ # Clip operator
147
+ # =============================================================================
148
+
149
+
150
@register("Clip", since_version=1)
def clip_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Clip tensor values to a range for opset 1-10.

    In these opsets the bounds are node attributes rather than inputs. A
    missing bound falls back to -inf / +inf, leaving that side unbounded.
    """
    x = builder.get_value(node.input[0])
    bounds = {
        "min": get_attribute(node, "min", float("-inf")),
        "max": get_attribute(node, "max", float("inf")),
    }
    return builder.call_function(torch.clamp, args=(x,), kwargs=bounds)
164
+
165
+
166
@register("Clip", since_version=11)
def clip_v11(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Clip tensor values to a range for opset 11+.

    In opset 11+, ``min`` (input 1) and ``max`` (input 2) are optional inputs.
    When both are absent the op is an identity. ``torch.clamp`` rejects
    ``min=None, max=None``, so that case returns the input value directly
    instead of emitting a clamp node.
    """
    x = builder.get_value(node.input[0])

    min_val = get_optional_input(builder, node, 1)
    max_val = get_optional_input(builder, node, 2)

    if min_val is None and max_val is None:
        # Nothing to clip against: forward the input unchanged.
        return x

    return builder.call_function(
        torch.clamp, args=(x,), kwargs={"min": min_val, "max": max_val}
    )
180
+
181
+
182
+ # =============================================================================
183
+ # Trigonometric functions
184
+ # =============================================================================
185
+
186
+
187
# Each entry: (ONNX op type, torch implementation, handler docstring).
for _op_type, _torch_fn, _doc in (
    ("Sin", torch.sin, "Sine."),
    ("Cos", torch.cos, "Cosine."),
    ("Tan", torch.tan, "Tangent."),
    ("Asin", torch.asin, "Arc sine."),
    ("Acos", torch.acos, "Arc cosine."),
    ("Atan", torch.atan, "Arc tangent."),
):
    register(_op_type)(unary_op(_torch_fn, _doc))
# Keep the loop variables out of the module namespace.
del _op_type, _torch_fn, _doc
193
+
194
+
195
+ # =============================================================================
196
+ # Hyperbolic functions
197
+ # =============================================================================
198
+
199
+
200
# Each entry: (ONNX op type, torch implementation, handler docstring).
for _op_type, _torch_fn, _doc in (
    ("Sinh", torch.sinh, "Hyperbolic sine."),
    ("Cosh", torch.cosh, "Hyperbolic cosine."),
    ("Asinh", torch.asinh, "Inverse hyperbolic sine."),
    ("Acosh", torch.acosh, "Inverse hyperbolic cosine."),
    ("Atanh", torch.atanh, "Inverse hyperbolic tangent."),
):
    register(_op_type)(unary_op(_torch_fn, _doc))
# Keep the loop variables out of the module namespace.
del _op_type, _torch_fn, _doc
205
+
206
+
207
+ # =============================================================================
208
+ # Additional math functions
209
+ # =============================================================================
210
+
211
+
212
# Each entry: (ONNX op type, torch implementation, handler docstring).
for _op_type, _torch_fn, _doc in (
    ("Erf", torch.erf, "Error function."),
    ("IsNaN", torch.isnan, "Check for NaN."),
):
    register(_op_type)(unary_op(_torch_fn, _doc))
# Keep the loop variables out of the module namespace.
del _op_type, _torch_fn, _doc
214
+
215
+
216
@register("IsInf")
def isinf(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit an infinity check for an ONNX IsInf node.

    The ``detect_negative`` and ``detect_positive`` attributes both default to
    1 and select which infinities are reported. With both disabled, the
    result is an all-False boolean tensor shaped like the input.
    """
    x = builder.get_value(node.input[0])
    detect_negative = get_attribute(node, "detect_negative", 1)
    detect_positive = get_attribute(node, "detect_positive", 1)

    def _isinf(x, detect_neg, detect_pos):
        # Guard clauses: handle each "one side disabled" case before the
        # general case where both signs are detected.
        if not (detect_neg or detect_pos):
            return torch.zeros_like(x, dtype=torch.bool)
        if not detect_neg:
            return torch.isposinf(x)
        if not detect_pos:
            return torch.isneginf(x)
        return torch.isinf(x)

    return builder.call_function(_isinf, args=(x, detect_negative, detect_positive))
234
+
235
+
236
+ # =============================================================================
237
+ # Bitwise operations
238
+ # =============================================================================
239
+
240
+
241
# Each entry: (ONNX op type, torch implementation, handler docstring).
for _op_type, _torch_fn, _doc in (
    ("BitwiseAnd", torch.bitwise_and, "Bitwise AND."),
    ("BitwiseOr", torch.bitwise_or, "Bitwise OR."),
    ("BitwiseXor", torch.bitwise_xor, "Bitwise XOR."),
):
    register(_op_type)(binary_op(_torch_fn, _doc))
# Keep the loop variables out of the module namespace.
del _op_type, _torch_fn, _doc
244
+
245
+
246
# Unsigned dtypes mapped to the signed dtype of the same bit width.
# PyTorch's bitwise_not lacks CPU support for some unsigned types (e.g. uint16).
# uint64 is assumed to share that limitation; the signed round trip is harmless
# if it does not. The map uses getattr/hasattr so it is built once and
# importing this module still works on PyTorch builds without these dtypes.
_UNSIGNED_TO_SIGNED = {
    getattr(torch, unsigned): getattr(torch, signed)
    for unsigned, signed in (
        ("uint16", "int16"),
        ("uint32", "int32"),
        ("uint64", "int64"),
    )
    if hasattr(torch, unsigned)
}


@register("BitwiseNot")
def bitwise_not(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Bitwise NOT.

    Inputs whose dtype appears in ``_UNSIGNED_TO_SIGNED`` are cast to the
    same-width signed dtype, inverted, and cast back to the original dtype.
    All other dtypes go straight to ``torch.bitwise_not``.
    """
    x = builder.get_value(node.input[0])

    def _bitwise_not(x):
        signed_dtype = _UNSIGNED_TO_SIGNED.get(x.dtype)
        if signed_dtype is None:
            return torch.bitwise_not(x)
        # Relies on same-width integer casts preserving the bit pattern
        # (wrap-around), so the inverted bits survive the round trip.
        return torch.bitwise_not(x.to(signed_dtype)).to(x.dtype)

    return builder.call_function(_bitwise_not, args=(x,))
268
+
269
+
270
@register("BitShift")
def bit_shift(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Bitwise shift operation.

    The ``direction`` attribute selects ``torch.bitwise_left_shift``
    ("LEFT", the fallback when absent) or ``torch.bitwise_right_shift``
    ("RIGHT").

    Raises:
        ValueError: If ``direction`` is neither "LEFT" nor "RIGHT".
    """
    x = builder.get_value(node.input[0])
    y = builder.get_value(node.input[1])

    direction = get_attribute(node, "direction", "LEFT")
    if isinstance(direction, bytes):
        # ONNX stores string attributes as bytes. Normalize in case the
        # attribute helper hands back the raw value; otherwise the comparison
        # below would silently fall through to the wrong shift.
        direction = direction.decode("utf-8")

    if direction == "LEFT":
        shift = torch.bitwise_left_shift
    elif direction == "RIGHT":
        shift = torch.bitwise_right_shift
    else:
        raise ValueError(
            f"BitShift: unsupported direction {direction!r}; "
            "expected 'LEFT' or 'RIGHT'"
        )
    return builder.call_function(shift, args=(x, y))