onnx2fx-0.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/op_registry.py
ADDED
@@ -0,0 +1,345 @@
# SPDX-License-Identifier: Apache-2.0
"""ONNX operator registry with custom operator and opset version support."""

import copy
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Callable, Optional, Union, List, Tuple

import onnx

if TYPE_CHECKING:  # pragma: no cover - only for type checking
    from .graph_builder import GraphBuilder

import torch.fx

# Type alias for operator handler functions.
# Handlers take a GraphBuilder and ONNX node, returning one or more FX nodes.
OpHandler = Callable[
    ["GraphBuilder", onnx.NodeProto], Union[torch.fx.Node, Tuple[torch.fx.Node, ...]]
]

# Registry: {domain: {op_type: [(since_version, handler), ...]}}
# Handlers are stored in descending version order for efficient lookup.
# Empty string "" represents the default ONNX domain.
_VERSIONED_REGISTRY: Dict[str, Dict[str, List[Tuple[int, OpHandler]]]] = {"": {}}


@contextmanager
def registry_context():
    """Temporarily isolate registry mutations (intended for tests)."""
    snapshot = copy.deepcopy(_VERSIONED_REGISTRY)
    try:
        yield
    finally:
        _VERSIONED_REGISTRY.clear()
        _VERSIONED_REGISTRY.update(snapshot)


def register(
    op_type: str, domain: str = "", since_version: int = 1
) -> Callable[[OpHandler], OpHandler]:
    """Decorator to register an ONNX operator handler with version support.

    Parameters
    ----------
    op_type : str
        The ONNX operator type name (e.g., "Add", "Relu").
    domain : str, optional
        The ONNX domain (e.g., "com.microsoft"). Default is "" (standard ONNX domain).
    since_version : int, optional
        The minimum opset version this handler supports. Default is 1.

    Returns
    -------
    Callable
        Decorator function.

    Examples
    --------
    Register a standard ONNX operator (all versions):

    >>> @register("MyOp")
    ... def my_op(builder, node):
    ...     x = builder.get_value(node.input[0])
    ...     return builder.call_function(torch.relu, args=(x,))

    Register version-specific handlers:

    >>> @register("Softmax", since_version=1)
    ... def softmax_v1(builder, node):
    ...     # opset 1-12: axis defaults to 1
    ...     ...

    >>> @register("Softmax", since_version=13)
    ... def softmax_v13(builder, node):
    ...     # opset 13+: axis defaults to -1
    ...     ...

    Register a custom domain operator:

    >>> @register("CustomOp", domain="com.mycompany")
    ... def custom_op(builder, node):
    ...     x = builder.get_value(node.input[0])
    ...     return builder.call_function(my_custom_function, args=(x,))
    """

    def decorator(func: OpHandler) -> OpHandler:
        if domain not in _VERSIONED_REGISTRY:
            _VERSIONED_REGISTRY[domain] = {}
        if op_type not in _VERSIONED_REGISTRY[domain]:
            _VERSIONED_REGISTRY[domain][op_type] = []

        handlers = _VERSIONED_REGISTRY[domain][op_type]
        # Remove existing handler with same since_version to allow re-registration
        handlers[:] = [(v, h) for v, h in handlers if v != since_version]
        handlers.append((since_version, func))
        # Keep sorted in descending order by version for efficient lookup
        handlers.sort(key=lambda x: x[0], reverse=True)

        return func

    return decorator


def register_op(
    op_type: str,
    handler: Optional[OpHandler] = None,
    domain: str = "",
    since_version: int = 1,
) -> Union[OpHandler, Callable[[OpHandler], OpHandler]]:
    """Register an ONNX operator handler with version support.

    This function can be used as a decorator or called directly to register
    custom operator handlers for ONNX operators that are not natively supported.

    Parameters
    ----------
    op_type : str
        The ONNX operator type name.
    handler : OpHandler, optional
        The handler function. If not provided, returns a decorator.
    domain : str, optional
        The ONNX domain. Default is "" (standard ONNX domain).
    since_version : int, optional
        The minimum opset version this handler supports. Default is 1.

    Returns
    -------
    OpHandler or Callable
        If handler is provided, returns the handler.
        Otherwise, returns a decorator function.

    Examples
    --------
    Using as a decorator:

    >>> @register_op("MyCustomOp")
    ... def my_custom_op(builder, node):
    ...     x = builder.get_value(node.input[0])
    ...     return builder.call_function(torch.sigmoid, args=(x,))

    Using as a function:

    >>> def my_handler(builder, node):
    ...     x = builder.get_value(node.input[0])
    ...     return builder.call_function(torch.tanh, args=(x,))
    >>> register_op("TanhCustom", my_handler)

    Registering for a custom domain:

    >>> @register_op("BiasGelu", domain="com.microsoft")
    ... def bias_gelu(builder, node):
    ...     x = builder.get_value(node.input[0])
    ...     bias = builder.get_value(node.input[1])
    ...     return builder.call_function(
    ...         lambda t, b: torch.nn.functional.gelu(t + b),
    ...         args=(x, bias)
    ...     )

    Registering version-specific handlers:

    >>> @register_op("MyOp", since_version=1)
    ... def my_op_v1(builder, node): ...

    >>> @register_op("MyOp", since_version=13)
    ... def my_op_v13(builder, node): ...
    """
    if handler is not None:
        # Direct call: register_op("Op", handler)
        if domain not in _VERSIONED_REGISTRY:
            _VERSIONED_REGISTRY[domain] = {}
        if op_type not in _VERSIONED_REGISTRY[domain]:
            _VERSIONED_REGISTRY[domain][op_type] = []

        handlers = _VERSIONED_REGISTRY[domain][op_type]
        # Remove existing handler with same since_version
        handlers[:] = [(v, h) for v, h in handlers if v != since_version]
        handlers.append((since_version, handler))
        handlers.sort(key=lambda x: x[0], reverse=True)

        return handler
    else:
        # Decorator usage: @register_op("Op")
        return register(op_type, domain, since_version)


def unregister_op(
    op_type: str, domain: str = "", since_version: Optional[int] = None
) -> bool:
    """Unregister an operator handler.

    Parameters
    ----------
    op_type : str
        The ONNX operator type name.
    domain : str, optional
        The ONNX domain. Default is "" (standard ONNX domain).
    since_version : int, optional
        The specific version handler to remove. If None, removes all versions.

    Returns
    -------
    bool
        True if the operator was unregistered, False if it wasn't registered.
    """
    if domain not in _VERSIONED_REGISTRY:
        return False
    if op_type not in _VERSIONED_REGISTRY[domain]:
        return False

    handlers = _VERSIONED_REGISTRY[domain][op_type]
    if since_version is None:
        # Remove all versions
        del _VERSIONED_REGISTRY[domain][op_type]
        return True
    else:
        # Remove specific version
        original_len = len(handlers)
        handlers[:] = [(v, h) for v, h in handlers if v != since_version]
        if len(handlers) < original_len:
            if not handlers:
                del _VERSIONED_REGISTRY[domain][op_type]
            return True
        return False


def get_handler(
    op_type: str, domain: str = "", opset_version: int = 23
) -> Optional[OpHandler]:
    """Get the handler for an operator at a specific opset version.

    Finds the handler with the highest since_version that is <= opset_version.

    Parameters
    ----------
    op_type : str
        The ONNX operator type name.
    domain : str, optional
        The ONNX domain. Default is "" (standard ONNX domain).
    opset_version : int, optional
        The target opset version. Default is 23 (current latest).

    Returns
    -------
    OpHandler or None
        The appropriate handler function, or None if not found.
    """
    # Normalize domain: "ai.onnx" is equivalent to ""
    if domain == "ai.onnx":
        domain = ""

    if domain not in _VERSIONED_REGISTRY:
        return None

    handlers = _VERSIONED_REGISTRY[domain].get(op_type)
    if not handlers:
        return None

    # Handlers are sorted in descending order by since_version
    # Find the first handler where since_version <= opset_version
    for since_version, handler in handlers:
        if since_version <= opset_version:
            return handler

    return None


def is_supported(op_type: str, domain: str = "", opset_version: int = 23) -> bool:
    """Check if an operator is supported.

    Parameters
    ----------
    op_type : str
        The ONNX operator type name.
    domain : str, optional
        The ONNX domain. Default is "" (standard ONNX domain).
    opset_version : int, optional
        The target opset version. Default is 23.

    Returns
    -------
    bool
        True if the operator is supported.
    """
    return get_handler(op_type, domain, opset_version) is not None


def get_supported_ops(domain: str = "") -> list:
    """Get list of supported ONNX operators for a domain.

    Parameters
    ----------
    domain : str, optional
        The ONNX domain. Default is "" (standard ONNX domain).

    Returns
    -------
    list
        Sorted list of supported operator names.
    """
    if domain in _VERSIONED_REGISTRY:
        return sorted(_VERSIONED_REGISTRY[domain].keys())
    return []


def get_all_supported_ops() -> Dict[str, list]:
    """Get all supported operators across all domains.

    Returns
    -------
    Dict[str, list]
        Dictionary mapping domain names to sorted lists of operator names.
    """
    return {domain: sorted(ops.keys()) for domain, ops in _VERSIONED_REGISTRY.items()}


def get_registered_domains() -> list:
    """Get list of registered domains.

    Returns
    -------
    list
        List of domain names.
    """
    return list(_VERSIONED_REGISTRY.keys())


def get_handler_versions(op_type: str, domain: str = "") -> List[int]:
    """Get all registered opset versions for an operator.

    Parameters
    ----------
    op_type : str
        The ONNX operator type name.
    domain : str, optional
        The ONNX domain. Default is "" (standard ONNX domain).

    Returns
    -------
    List[int]
        List of registered since_version values, sorted in ascending order.
    """
    if domain in _VERSIONED_REGISTRY:
        handlers = _VERSIONED_REGISTRY[domain].get(op_type, [])
        return sorted([v for v, _ in handlers])
    return []
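
The registry resolves a node to the handler whose since_version is the highest one not exceeding the model's opset. The following sketch (not part of the wheel; the "MySoftmax" name and the dummy handlers are illustrative placeholders) shows that resolution behaviour using only the functions defined in this file:

from onnx2fx.op_registry import (
    get_handler,
    get_handler_versions,
    register_op,
    registry_context,
)


def softmax_v1(builder, node):  # hypothetical handler for opset 1-12
    ...


def softmax_v13(builder, node):  # hypothetical handler for opset 13+
    ...


with registry_context():  # keep these demo registrations out of the global registry
    register_op("MySoftmax", softmax_v1, since_version=1)
    register_op("MySoftmax", softmax_v13, since_version=13)

    # Resolution picks the highest since_version that is <= the target opset.
    assert get_handler("MySoftmax", opset_version=12) is softmax_v1
    assert get_handler("MySoftmax", opset_version=17) is softmax_v13
    assert get_handler_versions("MySoftmax") == [1, 13]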
onnx2fx/ops/__init__.py
ADDED
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
"""ONNX operator implementations.

This package contains all ONNX operator implementations organized by category:

- activation.py: Activation functions (Relu, Sigmoid, Softmax, etc.)
- arithmetic.py: Arithmetic and math ops (Add, Mul, Sin, Cos, etc.)
- attention.py: Attention mechanisms (standard ONNX domain)
- attention_msft.py: Attention mechanisms (com.microsoft domain)
- control_flow.py: Control flow ops (Loop, If, Scan)
- convolution.py: Convolution ops (Conv, ConvTranspose, DeformConv)
- image.py: Image processing ops (Resize, DepthToSpace, etc.)
- linalg.py: Linear algebra ops (Einsum, Det)
- loss.py: Loss functions (SoftmaxCrossEntropyLoss, etc.)
- nn.py: Core neural network ops (MatMul, Gemm, Dropout)
- normalization.py: Normalization ops (BatchNorm, LayerNorm, etc.)
- pooling.py: Pooling ops (MaxPool, AveragePool, etc.)
- quantization.py: Quantization ops (QLinearConv, etc.)
- random.py: Random number generation (RandomNormal, etc.)
- recurrent.py: Recurrent neural networks (LSTM, GRU, RNN)
- reduction.py: Reduction ops (ReduceSum, ReduceMean, etc.)
- sequence.py: Sequence ops (SequenceConstruct, etc.)
- signal.py: Signal processing (STFT, MelWeightMatrix, window functions, NMS)
- string.py: String ops (StringNormalizer)
- tensor.py: Tensor manipulation ops (Reshape, Transpose, etc.)
- training.py: Training ops (Gradient, Momentum, Adagrad)
"""

# Import all operator modules to register handlers
from . import activation
from . import arithmetic
from . import attention
from . import attention_msft
from . import control_flow
from . import convolution
from . import image
from . import linalg
from . import loss
from . import nn
from . import normalization
from . import pooling
from . import quantization
from . import random
from . import recurrent
from . import reduction
from . import sequence
from . import signal
from . import string
from . import tensor
from . import training

__all__ = [
    "activation",
    "arithmetic",
    "attention",
    "attention_msft",
    "control_flow",
    "convolution",
    "image",
    "linalg",
    "loss",
    "nn",
    "normalization",
    "pooling",
    "quantization",
    "random",
    "recurrent",
    "reduction",
    "sequence",
    "signal",
    "string",
    "tensor",
    "training",
]
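
Because each operator module registers its handlers at import time via the decorators in op_registry.py, importing onnx2fx.ops is enough to populate the registry. A small sketch of what one might then query, assuming the package layout shown in this diff (the printed values are indicative, not guaranteed):

import onnx2fx.ops  # noqa: F401  - imported only for its registration side effect

from onnx2fx.op_registry import (
    get_registered_domains,
    get_supported_ops,
    is_supported,
)

print(get_registered_domains())            # e.g. ['', 'com.microsoft', ...]
print(len(get_supported_ops()))            # count of standard-domain operators
print(is_supported("Relu"))                # True - registered by ops/activation.py
print(is_supported("Softmax", opset_version=13))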
onnx2fx/ops/activation.py
ADDED
@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
"""Activation function operators."""

from typing import TYPE_CHECKING, Callable

import onnx
import torch
import torch.nn.functional as F

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import unary_op, unary_op_with_kwargs

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


def _coerced_softmax_handler(
    torch_fn: Callable[..., torch.Tensor],
    *,
    default_axis: int,
    doc: str,
) -> Callable[["GraphBuilder", onnx.NodeProto], torch.fx.Node]:
    def handler(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
        x = builder.get_value(node.input[0])
        axis = get_attribute(node, "axis", default_axis)

        def _coerced_softmax(t: torch.Tensor, axis: int) -> torch.Tensor:
            # Handle negative axis
            if axis < 0:
                axis = t.dim() + axis

            # Coerce to 2D: flatten [0:axis] and [axis:]
            orig_shape = t.shape
            pre_dim = 1
            for i in range(axis):
                pre_dim *= t.shape[i]

            t_2d = t.reshape(pre_dim, -1)
            result_2d = torch_fn(t_2d, dim=1)
            return result_2d.reshape(orig_shape)

        return builder.call_function(_coerced_softmax, args=(x, axis))

    handler.__doc__ = doc
    return handler


register("Relu")(unary_op(F.relu, "ReLU activation."))


register("LeakyRelu")(
    unary_op_with_kwargs(
        F.leaky_relu,
        attr_map={"negative_slope": ("alpha", 0.01)},
        doc="Leaky ReLU activation.",
    )
)


@register("PRelu")
def prelu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Parametric ReLU activation.

    ONNX PRelu allows slope to be broadcastable to the input tensor,
    while PyTorch's F.prelu expects slope to match the channel dimension.
    We implement it using torch.where for proper broadcasting support.

    When slope has shape [C] and input has shape [N, C, ...], we need to
    reshape slope to [1, C, 1, ...] for proper broadcasting.
    """
    x = builder.get_value(node.input[0])
    slope = builder.get_value(node.input[1])

    def _prelu(x: torch.Tensor, slope: torch.Tensor) -> torch.Tensor:
        # If slope is 1D with size matching channels and input is ND with N > 1,
        # reshape slope for proper broadcasting along channel dimension (dim=1)
        if slope.ndim == 1 and x.ndim > 1 and slope.numel() == x.shape[1]:
            # Reshape [C] to [1, C, 1, 1, ...] for broadcasting
            shape = [1, slope.numel()] + [1] * (x.ndim - 2)
            slope = slope.view(shape)
        return torch.where(x >= 0, x, x * slope)

    return builder.call_function(_prelu, args=(x, slope))


register("Elu")(
    unary_op_with_kwargs(
        F.elu,
        attr_map={"alpha": ("alpha", 1.0)},
        doc="Exponential Linear Unit activation.",
    )
)


@register("Selu")
def selu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Scaled Exponential Linear Unit activation.

    ONNX SELU: y = gamma * (alpha * (exp(x) - 1) for x < 0, x for x >= 0)
    PyTorch SELU uses fixed alpha=1.6732... and gamma=1.0507...
    ONNX defaults: alpha=1.67326..., gamma=1.0507... but allows custom values.
    """
    x = builder.get_value(node.input[0])
    alpha = get_attribute(node, "alpha", 1.67326319217681884765625)
    gamma = get_attribute(node, "gamma", 1.05070102214813232421875)

    # PyTorch's fixed SELU values
    pytorch_alpha = 1.6732632423543772848170429916717
    pytorch_gamma = 1.0507009873554804934193349852946

    # If using PyTorch's fixed values (within tolerance), use F.selu for efficiency
    if abs(alpha - pytorch_alpha) < 1e-5 and abs(gamma - pytorch_gamma) < 1e-5:
        return builder.call_function(F.selu, args=(x,))

    # Otherwise implement manually: gamma * (alpha * (exp(x) - 1) for x < 0, x for x >= 0)
    def _custom_selu(x: torch.Tensor, alpha: float, gamma: float) -> torch.Tensor:
        return gamma * torch.where(x > 0, x, alpha * (torch.exp(x) - 1))

    return builder.call_function(_custom_selu, args=(x, alpha, gamma))


register("Celu")(
    unary_op_with_kwargs(
        F.celu,
        attr_map={"alpha": ("alpha", 1.0)},
        doc="Continuously Differentiable Exponential Linear Unit activation.",
    )
)


register("Sigmoid")(unary_op(torch.sigmoid, "Sigmoid activation."))


@register("HardSigmoid")
def hard_sigmoid(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Hard Sigmoid activation.

    ONNX HardSigmoid: max(0, min(1, alpha * x + beta))
    ONNX defaults: alpha=0.2, beta=0.5
    PyTorch hardsigmoid uses fixed alpha=1/6, beta=0.5
    """
    x = builder.get_value(node.input[0])
    alpha = get_attribute(node, "alpha", 0.2)
    beta = get_attribute(node, "beta", 0.5)

    # PyTorch hardsigmoid uses alpha=1/6 ≈ 0.16667, beta=0.5
    pytorch_alpha = 1.0 / 6.0

    # If using PyTorch's fixed values (within tolerance), use F.hardsigmoid
    if abs(alpha - pytorch_alpha) < 1e-5 and abs(beta - 0.5) < 1e-5:
        return builder.call_function(F.hardsigmoid, args=(x,))

    # Otherwise implement manually: max(0, min(1, alpha * x + beta))
    def _custom_hardsigmoid(x: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
        return torch.clamp(alpha * x + beta, 0.0, 1.0)

    return builder.call_function(_custom_hardsigmoid, args=(x, alpha, beta))


register("Tanh")(unary_op(torch.tanh, "Tanh activation."))


register("Softmax", since_version=1)(
    _coerced_softmax_handler(
        F.softmax,
        default_axis=1,
        doc=(
            "Softmax activation for opset 1-12 with axis defaulting to 1 and "
            "2D coercion."
        ),
    )
)


register("Softmax", since_version=13)(
    unary_op_with_kwargs(
        F.softmax,
        attr_map={"dim": ("axis", -1)},
        doc=(
            "Softmax activation for opset 13+ with axis defaulting to the last "
            "dimension."
        ),
    )
)


register("LogSoftmax", since_version=1)(
    _coerced_softmax_handler(
        F.log_softmax,
        default_axis=1,
        doc=(
            "Log Softmax activation for opset 1-12 with axis defaulting to 1 "
            "and 2D coercion."
        ),
    )
)


register("LogSoftmax", since_version=13)(
    unary_op_with_kwargs(
        F.log_softmax,
        attr_map={"dim": ("axis", -1)},
        doc="Log Softmax activation for opset 13+.",
    )
)


register("Softplus")(unary_op(F.softplus, "Softplus activation."))
register("Softsign")(unary_op(F.softsign, "Softsign activation."))


register("Gelu")(
    unary_op_with_kwargs(
        F.gelu,
        attr_map={"approximate": ("approximate", "none")},
        doc="Gaussian Error Linear Unit activation.",
    )
)


register("Silu")(unary_op(F.silu, "Sigmoid Linear Unit (SiLU/Swish) activation."))
register("Swish")(unary_op(F.silu, "Swish activation (alias for SiLU)."))
register("Mish")(unary_op(F.mish, "Mish activation."))


register("ThresholdedRelu")(
    unary_op_with_kwargs(
        F.threshold,
        attr_map={"threshold": ("alpha", 1.0)},
        fixed_kwargs={"value": 0.0},
        doc="Thresholded ReLU activation.",
    )
)


register("HardSwish")(unary_op(F.hardswish, "Hard Swish activation."))


@register("Hardmax")
def hardmax(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Hardmax - one-hot encoding of argmax."""
    x = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", -1)

    def _hardmax(t: torch.Tensor, ax: int) -> torch.Tensor:
        # Normalize axis to positive
        if ax < 0:
            ax = t.dim() + ax
        # one_hot appends the class dimension at the end
        one_hot = torch.nn.functional.one_hot(
            torch.argmax(t, dim=ax), num_classes=t.shape[ax]
        ).to(t.dtype)
        # Move the one-hot dimension from the end back to the original axis position.
        # one_hot has shape: [...dims before ax..., ...dims after ax..., num_classes],
        # so we need to move the last dim to position ax.
        return torch.movedim(one_hot, -1, ax)

    return builder.call_function(_hardmax, args=(x, axis))


@register("Shrink")
def shrink(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Shrink activation.

    If x < -lambd: y = x + bias
    If x > lambd: y = x - bias
    Otherwise: y = 0
    """
    x = builder.get_value(node.input[0])
    bias = get_attribute(node, "bias", 0.0)
    lambd = get_attribute(node, "lambd", 0.5)

    def _shrink(t: torch.Tensor, bias: float, lambd: float) -> torch.Tensor:
        result = torch.zeros_like(t)
        mask_neg = t < -lambd
        mask_pos = t > lambd
        result = torch.where(mask_neg, t + bias, result)
        result = torch.where(mask_pos, t - bias, result)
        return result

    return builder.call_function(_shrink, args=(x, bias, lambd))