onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. The information is provided for informational purposes only and reflects changes between package versions as published.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
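
Usage sketch (hypothetical): the converter entry point lives in onnx2fx/converter.py, which this excerpt does not show, so the `convert` name and return type below are assumptions, not the confirmed API:

    import onnx
    import onnx2fx  # entry point name assumed; not shown in this diff

    model = onnx.load("model.onnx")
    gm = onnx2fx.convert(model)  # assumed to return a torch.fx.GraphModule
    print(gm.graph)

The diff below adds the operator handlers that such a conversion would dispatch to for ONNX arithmetic, comparison, and bitwise nodes.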
onnx2fx/ops/arithmetic.py
@@ -0,0 +1,281 @@
# SPDX-License-Identifier: Apache-2.0
"""Arithmetic and element-wise operators."""

from typing import TYPE_CHECKING

import onnx
import torch

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import binary_op, get_optional_input, unary_op

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Binary arithmetic operators
# =============================================================================


register("Add")(binary_op(torch.add, "Element-wise addition."))
register("Sub")(binary_op(torch.sub, "Element-wise subtraction."))
register("Mul")(binary_op(torch.mul, "Element-wise multiplication."))
register("Pow")(binary_op(torch.pow, "Element-wise power."))


def _onnx_div(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    """ONNX-compatible division.

    For integer types, ONNX Div truncates toward zero (like C integer division),
    and the result must have the same integer dtype as the inputs.
    For floating-point types, it performs standard division.
    """
    if not lhs.dtype.is_floating_point and not lhs.dtype.is_complex:
        # Integer types: use truncation toward zero
        return torch.div(lhs, rhs, rounding_mode="trunc")
    return torch.div(lhs, rhs)
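
# Illustrative: torch.div(torch.tensor(-7), torch.tensor(2), rounding_mode="trunc")
# yields -3 (truncation toward zero), whereas floor division would yield -4.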


@register("Div")
def div(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise division."""
    lhs = builder.get_value(node.input[0])
    rhs = builder.get_value(node.input[1])
    return builder.call_function(_onnx_div, args=(lhs, rhs))


@register("Mod")
def mod(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise modulo."""
    lhs = builder.get_value(node.input[0])
    rhs = builder.get_value(node.input[1])
    fmod = get_attribute(node, "fmod", 0)
    if fmod:
        return builder.call_function(torch.fmod, args=(lhs, rhs))
    return builder.call_function(torch.remainder, args=(lhs, rhs))
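
# Sign conventions, worked out: torch.remainder(-7, 3) yields 2 (sign follows
# the divisor, as with fmod=0), while torch.fmod(-7, 3) yields -1 (sign follows
# the dividend, C-style fmod, as with fmod=1).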


@register("Min")
def min_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise minimum of inputs."""
    result = builder.get_value(node.input[0])
    for i in range(1, len(node.input)):
        other = builder.get_value(node.input[i])
        result = builder.call_function(torch.minimum, args=(result, other))
    return result


@register("Max")
def max_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise maximum of inputs."""
    result = builder.get_value(node.input[0])
    for i in range(1, len(node.input)):
        other = builder.get_value(node.input[i])
        result = builder.call_function(torch.maximum, args=(result, other))
    return result


@register("Mean")
def mean(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise mean of inputs."""
    inputs = [builder.get_value(name) for name in node.input]
    # Stack and compute mean along first dimension
    stacked = builder.call_function(torch.stack, args=(inputs,))
    return builder.call_function(torch.mean, args=(stacked,), kwargs={"dim": 0})
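
# Note: torch.stack requires equally shaped inputs, so this Mean assumes a
# common input shape; ONNX Mean also permits multidirectional broadcasting,
# which would require broadcasting the inputs first.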


@register("Sum")
def sum_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise sum of inputs."""
    result = builder.get_value(node.input[0])
    for i in range(1, len(node.input)):
        other = builder.get_value(node.input[i])
        result = builder.call_function(torch.add, args=(result, other))
    return result


# =============================================================================
# Unary arithmetic operators
# =============================================================================


register("Neg")(unary_op(torch.neg, "Element-wise negation."))
register("Abs")(unary_op(torch.abs, "Element-wise absolute value."))
register("Sign")(unary_op(torch.sign, "Element-wise sign."))
register("Ceil")(unary_op(torch.ceil, "Element-wise ceiling."))
register("Floor")(unary_op(torch.floor, "Element-wise floor."))
register("Round")(unary_op(torch.round, "Element-wise rounding."))
register("Reciprocal")(unary_op(torch.reciprocal, "Element-wise reciprocal."))
register("Sqrt")(unary_op(torch.sqrt, "Element-wise square root."))
register("Exp")(unary_op(torch.exp, "Element-wise exponential."))
register("Log")(unary_op(torch.log, "Element-wise natural logarithm."))


# =============================================================================
# Comparison operators
# =============================================================================


register("Equal")(binary_op(torch.eq, "Element-wise equality comparison."))
register("Greater")(binary_op(torch.gt, "Element-wise greater-than comparison."))
register("GreaterOrEqual")(
    binary_op(torch.ge, "Element-wise greater-than-or-equal comparison.")
)
register("Less")(binary_op(torch.lt, "Element-wise less-than comparison."))
register("LessOrEqual")(
    binary_op(torch.le, "Element-wise less-than-or-equal comparison.")
)
register("And")(binary_op(torch.logical_and, "Element-wise logical AND."))
register("Or")(binary_op(torch.logical_or, "Element-wise logical OR."))
register("Xor")(binary_op(torch.logical_xor, "Element-wise logical XOR."))
register("Not")(unary_op(torch.logical_not, "Element-wise logical NOT."))


@register("Where")
def where(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Element-wise conditional selection."""
    condition = builder.get_value(node.input[0])
    x = builder.get_value(node.input[1])
    y = builder.get_value(node.input[2])
    return builder.call_function(torch.where, args=(condition, x, y))


# =============================================================================
# Clip operator
# =============================================================================


@register("Clip", since_version=1)
def clip_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Clip tensor values to a range for opset 1-10.

    In opset < 11, min and max are attributes rather than inputs.
    """
    x = builder.get_value(node.input[0])

    min_val = get_attribute(node, "min", float("-inf"))
    max_val = get_attribute(node, "max", float("inf"))

    return builder.call_function(
        torch.clamp, args=(x,), kwargs={"min": min_val, "max": max_val}
    )
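
# When the attributes are absent, the ±inf defaults make clamping a no-op on
# that side (opset 6 specifies the float32 lowest/max limits as defaults,
# which is equivalent for finite inputs).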


@register("Clip", since_version=11)
def clip_v11(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Clip tensor values to a range for opset 11+.

    In opset 11+, min and max are optional inputs.
    """
    x = builder.get_value(node.input[0])

    min_val = get_optional_input(builder, node, 1)
    max_val = get_optional_input(builder, node, 2)

    return builder.call_function(
        torch.clamp, args=(x,), kwargs={"min": min_val, "max": max_val}
    )
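
# Assumes get_optional_input yields None for absent inputs; torch.clamp accepts
# a None bound, though it raises if both min and max are None (ONNX permits a
# Clip with neither bound, which is the identity).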


# =============================================================================
# Trigonometric functions
# =============================================================================


register("Sin")(unary_op(torch.sin, "Sine."))
register("Cos")(unary_op(torch.cos, "Cosine."))
register("Tan")(unary_op(torch.tan, "Tangent."))
register("Asin")(unary_op(torch.asin, "Arc sine."))
register("Acos")(unary_op(torch.acos, "Arc cosine."))
register("Atan")(unary_op(torch.atan, "Arc tangent."))


# =============================================================================
# Hyperbolic functions
# =============================================================================


register("Sinh")(unary_op(torch.sinh, "Hyperbolic sine."))
register("Cosh")(unary_op(torch.cosh, "Hyperbolic cosine."))
register("Asinh")(unary_op(torch.asinh, "Inverse hyperbolic sine."))
register("Acosh")(unary_op(torch.acosh, "Inverse hyperbolic cosine."))
register("Atanh")(unary_op(torch.atanh, "Inverse hyperbolic tangent."))


# =============================================================================
# Additional math functions
# =============================================================================


register("Erf")(unary_op(torch.erf, "Error function."))
register("IsNaN")(unary_op(torch.isnan, "Check for NaN."))


@register("IsInf")
def isinf(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Check for Inf."""
    x = builder.get_value(node.input[0])
    detect_negative = get_attribute(node, "detect_negative", 1)
    detect_positive = get_attribute(node, "detect_positive", 1)

    def _isinf(x, detect_neg, detect_pos):
        if detect_neg and detect_pos:
            return torch.isinf(x)
        elif detect_pos:
            return torch.isposinf(x)
        elif detect_neg:
            return torch.isneginf(x)
        else:
            return torch.zeros_like(x, dtype=torch.bool)

    return builder.call_function(_isinf, args=(x, detect_negative, detect_positive))
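
# Illustrative: _isinf(torch.tensor([1.0, float("inf"), -float("inf")]), 0, 1)
# yields tensor([False, True, False]); only positive infinity is detected.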


# =============================================================================
# Bitwise operations
# =============================================================================


register("BitwiseAnd")(binary_op(torch.bitwise_and, "Bitwise AND."))
register("BitwiseOr")(binary_op(torch.bitwise_or, "Bitwise OR."))
register("BitwiseXor")(binary_op(torch.bitwise_xor, "Bitwise XOR."))


@register("BitwiseNot")
def bitwise_not(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Bitwise NOT."""
    x = builder.get_value(node.input[0])

    # PyTorch bitwise_not doesn't support some unsigned types (e.g., uint16) on CPU.
    # We handle this by casting to a signed type with the same bit width,
    # performing the operation, and casting back.
    def _bitwise_not(x):
        original_dtype = x.dtype
        # Map unsigned types to signed equivalents with same bit width
        dtype_map = {
            torch.uint16: torch.int16,
            torch.uint32: torch.int32,
        }
        if original_dtype in dtype_map:
            x = x.to(dtype_map[original_dtype])
            result = torch.bitwise_not(x)
            return result.to(original_dtype)
        return torch.bitwise_not(x)

    return builder.call_function(_bitwise_not, args=(x,))
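
# Note: torch.uint16 and torch.uint32 have limited operator support in PyTorch,
# hence the signed round-trip above; the two's-complement casts preserve the
# underlying bits, so NOT produces the same bit pattern either way.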


@register("BitShift")
def bit_shift(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Bitwise shift operation."""
    x = builder.get_value(node.input[0])
    y = builder.get_value(node.input[1])

    direction = get_attribute(node, "direction", "LEFT")

    if direction == "LEFT":
        return builder.call_function(torch.bitwise_left_shift, args=(x, y))
    else:
        return builder.call_function(torch.bitwise_right_shift, args=(x, y))
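
# Note: the ONNX spec marks BitShift's "direction" attribute as required, so
# the "LEFT" default here is only a defensive fallback.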