onnx2fx 0.0.0__py3-none-any.whl

onnx2fx/op_registry.py ADDED
@@ -0,0 +1,345 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """ONNX operator registry with custom operator and opset version support."""
+
+ import copy
+ from contextlib import contextmanager
+ from typing import TYPE_CHECKING, Dict, Callable, Optional, Union, List, Tuple
+
+ import onnx
+
+ if TYPE_CHECKING:  # pragma: no cover - only for type checking
+     from .graph_builder import GraphBuilder
+
+ import torch.fx
+
+ # Type alias for operator handler functions.
+ # Handlers take a GraphBuilder and ONNX node, returning one or more FX nodes.
+ OpHandler = Callable[
+     ["GraphBuilder", onnx.NodeProto], Union[torch.fx.Node, Tuple[torch.fx.Node, ...]]
+ ]
+
+ # Registry: {domain: {op_type: [(since_version, handler), ...]}}
+ # Handlers are stored in descending version order for efficient lookup.
+ # Empty string "" represents the default ONNX domain.
+ _VERSIONED_REGISTRY: Dict[str, Dict[str, List[Tuple[int, OpHandler]]]] = {"": {}}
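+ # Illustrative sketch of the layout: after registering Softmax handlers for
+ # opsets 1 and 13, the registry would contain (handler names are placeholders)
+ #     {"": {"Softmax": [(13, softmax_v13), (1, softmax_v1)]}}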
+
+
+ @contextmanager
+ def registry_context():
+     """Temporarily isolate registry mutations (intended for tests)."""
+     snapshot = copy.deepcopy(_VERSIONED_REGISTRY)
+     try:
+         yield
+     finally:
+         _VERSIONED_REGISTRY.clear()
+         _VERSIONED_REGISTRY.update(snapshot)
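+
+ # Typical test usage (illustrative sketch; "TempOp" and temp_handler are
+ # placeholder names):
+ #     with registry_context():
+ #         register_op("TempOp", temp_handler)   # visible only inside the block
+ #     assert not is_supported("TempOp")         # original registry restored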
+
+
+ def register(
+     op_type: str, domain: str = "", since_version: int = 1
+ ) -> Callable[[OpHandler], OpHandler]:
+     """Decorator to register an ONNX operator handler with version support.
+
+     Parameters
+     ----------
+     op_type : str
+         The ONNX operator type name (e.g., "Add", "Relu").
+     domain : str, optional
+         The ONNX domain (e.g., "com.microsoft"). Default is "" (standard ONNX domain).
+     since_version : int, optional
+         The minimum opset version this handler supports. Default is 1.
+
+     Returns
+     -------
+     Callable
+         Decorator function.
+
+     Examples
+     --------
+     Register a standard ONNX operator (all versions):
+
+     >>> @register("MyOp")
+     ... def my_op(builder, node):
+     ...     x = builder.get_value(node.input[0])
+     ...     return builder.call_function(torch.relu, args=(x,))
+
+     Register version-specific handlers:
+
+     >>> @register("Softmax", since_version=1)
+     ... def softmax_v1(builder, node):
+     ...     # opset 1-12: axis defaults to 1
+     ...     ...
+
+     >>> @register("Softmax", since_version=13)
+     ... def softmax_v13(builder, node):
+     ...     # opset 13+: axis defaults to -1
+     ...     ...
+
+     Register a custom domain operator:
+
+     >>> @register("CustomOp", domain="com.mycompany")
+     ... def custom_op(builder, node):
+     ...     x = builder.get_value(node.input[0])
+     ...     return builder.call_function(my_custom_function, args=(x,))
+     """
+
+     def decorator(func: OpHandler) -> OpHandler:
+         if domain not in _VERSIONED_REGISTRY:
+             _VERSIONED_REGISTRY[domain] = {}
+         if op_type not in _VERSIONED_REGISTRY[domain]:
+             _VERSIONED_REGISTRY[domain][op_type] = []
+
+         handlers = _VERSIONED_REGISTRY[domain][op_type]
+         # Remove existing handler with same since_version to allow re-registration
+         handlers[:] = [(v, h) for v, h in handlers if v != since_version]
+         handlers.append((since_version, func))
+         # Keep sorted in descending order by version for efficient lookup
+         handlers.sort(key=lambda x: x[0], reverse=True)
+
+         return func
+
+     return decorator
+
+
+ def register_op(
+     op_type: str,
+     handler: Optional[OpHandler] = None,
+     domain: str = "",
+     since_version: int = 1,
+ ) -> Union[OpHandler, Callable[[OpHandler], OpHandler]]:
+     """Register an ONNX operator handler with version support.
+
+     This function can be used as a decorator or called directly to register
+     custom operator handlers for ONNX operators that are not natively supported.
+
+     Parameters
+     ----------
+     op_type : str
+         The ONNX operator type name.
+     handler : OpHandler, optional
+         The handler function. If not provided, returns a decorator.
+     domain : str, optional
+         The ONNX domain. Default is "" (standard ONNX domain).
+     since_version : int, optional
+         The minimum opset version this handler supports. Default is 1.
+
+     Returns
+     -------
+     OpHandler or Callable
+         If handler is provided, returns the handler.
+         Otherwise, returns a decorator function.
+
+     Examples
+     --------
+     Using as a decorator:
+
+     >>> @register_op("MyCustomOp")
+     ... def my_custom_op(builder, node):
+     ...     x = builder.get_value(node.input[0])
+     ...     return builder.call_function(torch.sigmoid, args=(x,))
+
+     Using as a function:
+
+     >>> def my_handler(builder, node):
+     ...     x = builder.get_value(node.input[0])
+     ...     return builder.call_function(torch.tanh, args=(x,))
+     >>> register_op("TanhCustom", my_handler)
+
+     Registering for a custom domain:
+
+     >>> @register_op("BiasGelu", domain="com.microsoft")
+     ... def bias_gelu(builder, node):
+     ...     x = builder.get_value(node.input[0])
+     ...     bias = builder.get_value(node.input[1])
+     ...     return builder.call_function(
+     ...         lambda t, b: torch.nn.functional.gelu(t + b),
+     ...         args=(x, bias)
+     ...     )
+
+     Registering version-specific handlers:
+
+     >>> @register_op("MyOp", since_version=1)
+     ... def my_op_v1(builder, node): ...
+
+     >>> @register_op("MyOp", since_version=13)
+     ... def my_op_v13(builder, node): ...
+     """
+     if handler is not None:
+         # Direct call: register_op("Op", handler)
+         if domain not in _VERSIONED_REGISTRY:
+             _VERSIONED_REGISTRY[domain] = {}
+         if op_type not in _VERSIONED_REGISTRY[domain]:
+             _VERSIONED_REGISTRY[domain][op_type] = []
+
+         handlers = _VERSIONED_REGISTRY[domain][op_type]
+         # Remove existing handler with same since_version
+         handlers[:] = [(v, h) for v, h in handlers if v != since_version]
+         handlers.append((since_version, handler))
+         handlers.sort(key=lambda x: x[0], reverse=True)
+
+         return handler
+     else:
+         # Decorator usage: @register_op("Op")
+         return register(op_type, domain, since_version)
+
+
+ def unregister_op(
+     op_type: str, domain: str = "", since_version: Optional[int] = None
+ ) -> bool:
+     """Unregister an operator handler.
+
+     Parameters
+     ----------
+     op_type : str
+         The ONNX operator type name.
+     domain : str, optional
+         The ONNX domain. Default is "" (standard ONNX domain).
+     since_version : int, optional
+         The specific version handler to remove. If None, removes all versions.
+
+     Returns
+     -------
+     bool
+         True if the operator was unregistered, False if it wasn't registered.
+     """
+     if domain not in _VERSIONED_REGISTRY:
+         return False
+     if op_type not in _VERSIONED_REGISTRY[domain]:
+         return False
+
+     handlers = _VERSIONED_REGISTRY[domain][op_type]
+     if since_version is None:
+         # Remove all versions
+         del _VERSIONED_REGISTRY[domain][op_type]
+         return True
+     else:
+         # Remove specific version
+         original_len = len(handlers)
+         handlers[:] = [(v, h) for v, h in handlers if v != since_version]
+         if len(handlers) < original_len:
+             if not handlers:
+                 del _VERSIONED_REGISTRY[domain][op_type]
+             return True
+         return False
+
+
+ def get_handler(
+     op_type: str, domain: str = "", opset_version: int = 23
+ ) -> Optional[OpHandler]:
+     """Get the handler for an operator at a specific opset version.
+
+     Finds the handler with the highest since_version that is <= opset_version.
+
+     Parameters
+     ----------
+     op_type : str
+         The ONNX operator type name.
+     domain : str, optional
+         The ONNX domain. Default is "" (standard ONNX domain).
+     opset_version : int, optional
+         The target opset version. Default is 23 (current latest).
+
+     Returns
+     -------
+     OpHandler or None
+         The appropriate handler function, or None if not found.
+     """
+     # Normalize domain: "ai.onnx" is equivalent to ""
+     if domain == "ai.onnx":
+         domain = ""
+
+     if domain not in _VERSIONED_REGISTRY:
+         return None
+
+     handlers = _VERSIONED_REGISTRY[domain].get(op_type)
+     if not handlers:
+         return None
+
+     # Handlers are sorted in descending order by since_version
+     # Find the first handler where since_version <= opset_version
+     for since_version, handler in handlers:
+         if since_version <= opset_version:
+             return handler
+
+     return None
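+
+ # For example, with Softmax handlers registered at since_version 1 and 13, the
+ # descending scan above resolves (illustrative):
+ #     get_handler("Softmax", opset_version=11)  -> the since_version=1 handler
+ #     get_handler("Softmax", opset_version=17)  -> the since_version=13 handler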
+
+
+ def is_supported(op_type: str, domain: str = "", opset_version: int = 23) -> bool:
+     """Check if an operator is supported.
+
+     Parameters
+     ----------
+     op_type : str
+         The ONNX operator type name.
+     domain : str, optional
+         The ONNX domain. Default is "" (standard ONNX domain).
+     opset_version : int, optional
+         The target opset version. Default is 23.
+
+     Returns
+     -------
+     bool
+         True if the operator is supported.
+     """
+     return get_handler(op_type, domain, opset_version) is not None
+
+
+ def get_supported_ops(domain: str = "") -> list:
+     """Get list of supported ONNX operators for a domain.
+
+     Parameters
+     ----------
+     domain : str, optional
+         The ONNX domain. Default is "" (standard ONNX domain).
+
+     Returns
+     -------
+     list
+         Sorted list of supported operator names.
+     """
+     if domain in _VERSIONED_REGISTRY:
+         return sorted(_VERSIONED_REGISTRY[domain].keys())
+     return []
+
+
+ def get_all_supported_ops() -> Dict[str, list]:
+     """Get all supported operators across all domains.
+
+     Returns
+     -------
+     Dict[str, list]
+         Dictionary mapping domain names to sorted lists of operator names.
+     """
+     return {domain: sorted(ops.keys()) for domain, ops in _VERSIONED_REGISTRY.items()}
+
+
+ def get_registered_domains() -> list:
+     """Get list of registered domains.
+
+     Returns
+     -------
+     list
+         List of domain names.
+     """
+     return list(_VERSIONED_REGISTRY.keys())
+
+
+ def get_handler_versions(op_type: str, domain: str = "") -> List[int]:
+     """Get all registered opset versions for an operator.
+
+     Parameters
+     ----------
+     op_type : str
+         The ONNX operator type name.
+     domain : str, optional
+         The ONNX domain. Default is "" (standard ONNX domain).
+
+     Returns
+     -------
+     List[int]
+         List of registered since_version values, sorted in ascending order.
+     """
+     if domain in _VERSIONED_REGISTRY:
+         handlers = _VERSIONED_REGISTRY[domain].get(op_type, [])
+         return sorted([v for v, _ in handlers])
+     return []
@@ -0,0 +1,74 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """ONNX operator implementations.
+
+ This package contains all ONNX operator implementations organized by category:
+
+ - activation.py: Activation functions (Relu, Sigmoid, Softmax, etc.)
+ - arithmetic.py: Arithmetic and math ops (Add, Mul, Sin, Cos, etc.)
+ - attention.py: Attention mechanisms (standard ONNX domain)
+ - attention_msft.py: Attention mechanisms (com.microsoft domain)
+ - control_flow.py: Control flow ops (Loop, If, Scan)
+ - convolution.py: Convolution ops (Conv, ConvTranspose, DeformConv)
+ - image.py: Image processing ops (Resize, DepthToSpace, etc.)
+ - linalg.py: Linear algebra ops (Einsum, Det)
+ - loss.py: Loss functions (SoftmaxCrossEntropyLoss, etc.)
+ - nn.py: Core neural network ops (MatMul, Gemm, Dropout)
+ - normalization.py: Normalization ops (BatchNorm, LayerNorm, etc.)
+ - pooling.py: Pooling ops (MaxPool, AveragePool, etc.)
+ - quantization.py: Quantization ops (QLinearConv, etc.)
+ - random.py: Random number generation (RandomNormal, etc.)
+ - recurrent.py: Recurrent neural networks (LSTM, GRU, RNN)
+ - reduction.py: Reduction ops (ReduceSum, ReduceMean, etc.)
+ - sequence.py: Sequence ops (SequenceConstruct, etc.)
+ - signal.py: Signal processing (STFT, MelWeightMatrix, window functions, NMS)
+ - string.py: String ops (StringNormalizer)
+ - tensor.py: Tensor manipulation ops (Reshape, Transpose, etc.)
+ - training.py: Training ops (Gradient, Momentum, Adagrad)
+ """
+
+ # Import all operator modules to register handlers
+ from . import activation
+ from . import arithmetic
+ from . import attention
+ from . import attention_msft
+ from . import control_flow
+ from . import convolution
+ from . import image
+ from . import linalg
+ from . import loss
+ from . import nn
+ from . import normalization
+ from . import pooling
+ from . import quantization
+ from . import random
+ from . import recurrent
+ from . import reduction
+ from . import sequence
+ from . import signal
+ from . import string
+ from . import tensor
+ from . import training
+
+ __all__ = [
+     "activation",
+     "arithmetic",
+     "attention",
+     "attention_msft",
+     "control_flow",
+     "convolution",
+     "image",
+     "linalg",
+     "loss",
+     "nn",
+     "normalization",
+     "pooling",
+     "quantization",
+     "random",
+     "recurrent",
+     "reduction",
+     "sequence",
+     "signal",
+     "string",
+     "tensor",
+     "training",
+ ]
@@ -0,0 +1,282 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Activation function operators."""
+
+ from typing import TYPE_CHECKING, Callable
+
+ import onnx
+ import torch
+ import torch.nn.functional as F
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.op_helpers import unary_op, unary_op_with_kwargs
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ def _coerced_softmax_handler(
+     torch_fn: Callable[..., torch.Tensor],
+     *,
+     default_axis: int,
+     doc: str,
+ ) -> Callable[["GraphBuilder", onnx.NodeProto], torch.fx.Node]:
+     def handler(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+         x = builder.get_value(node.input[0])
+         axis = get_attribute(node, "axis", default_axis)
+
+         def _coerced_softmax(t: torch.Tensor, axis: int) -> torch.Tensor:
+             # Handle negative axis
+             if axis < 0:
+                 axis = t.dim() + axis
+
+             # Coerce to 2D: flatten [0:axis] and [axis:]
+             orig_shape = t.shape
+             pre_dim = 1
+             for i in range(axis):
+                 pre_dim *= t.shape[i]
+
+             t_2d = t.reshape(pre_dim, -1)
+             result_2d = torch_fn(t_2d, dim=1)
+             return result_2d.reshape(orig_shape)
+
+         return builder.call_function(_coerced_softmax, args=(x, axis))
+
+     handler.__doc__ = doc
+     return handler
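+
+ # Illustrative example of the coercion above: for opset < 13, an input of shape
+ # [2, 3, 4] with axis=1 is flattened to [2, 12], softmax is applied over the 12
+ # trailing elements of each row, and the result is reshaped back to [2, 3, 4].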
+
+
+ register("Relu")(unary_op(F.relu, "ReLU activation."))
+
+
+ register("LeakyRelu")(
+     unary_op_with_kwargs(
+         F.leaky_relu,
+         attr_map={"negative_slope": ("alpha", 0.01)},
+         doc="Leaky ReLU activation.",
+     )
+ )
+
+
+ @register("PRelu")
+ def prelu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Parametric ReLU activation.
+
+     ONNX PRelu allows slope to be broadcastable to the input tensor,
+     while PyTorch's F.prelu expects slope to match channel dimension.
+     We implement using torch.where for proper broadcasting support.
+
+     When slope has shape [C] and input has shape [N, C, ...], we need to
+     reshape slope to [1, C, 1, ...] for proper broadcasting.
+     """
+     x = builder.get_value(node.input[0])
+     slope = builder.get_value(node.input[1])
+
+     def _prelu(x: torch.Tensor, slope: torch.Tensor) -> torch.Tensor:
+         # If slope is 1D with size matching channels and input is ND with N > 1,
+         # reshape slope for proper broadcasting along channel dimension (dim=1)
+         if slope.ndim == 1 and x.ndim > 1 and slope.numel() == x.shape[1]:
+             # Reshape [C] to [1, C, 1, 1, ...] for broadcasting
+             shape = [1, slope.numel()] + [1] * (x.ndim - 2)
+             slope = slope.view(shape)
+         return torch.where(x >= 0, x, x * slope)
+
+     return builder.call_function(_prelu, args=(x, slope))
+
+
+ register("Elu")(
+     unary_op_with_kwargs(
+         F.elu,
+         attr_map={"alpha": ("alpha", 1.0)},
+         doc="Exponential Linear Unit activation.",
+     )
+ )
+
+
+ @register("Selu")
+ def selu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Scaled Exponential Linear Unit activation.
+
+     ONNX SELU: y = gamma * (alpha * (exp(x) - 1) for x < 0, x for x >= 0)
+     PyTorch SELU uses fixed alpha=1.6732... and gamma=1.0507...
+     ONNX defaults: alpha=1.67326..., gamma=1.0507... but allows custom values.
+     """
+     x = builder.get_value(node.input[0])
+     alpha = get_attribute(node, "alpha", 1.67326319217681884765625)
+     gamma = get_attribute(node, "gamma", 1.05070102214813232421875)
+
+     # PyTorch's fixed SELU values
+     pytorch_alpha = 1.6732632423543772848170429916717
+     pytorch_gamma = 1.0507009873554804934193349852946
+
+     # If using PyTorch's fixed values (within tolerance), use F.selu for efficiency
+     if abs(alpha - pytorch_alpha) < 1e-5 and abs(gamma - pytorch_gamma) < 1e-5:
+         return builder.call_function(F.selu, args=(x,))
+
+     # Otherwise implement manually: gamma * (alpha * (exp(x) - 1) for x < 0, x for x >= 0)
+     def _custom_selu(x: torch.Tensor, alpha: float, gamma: float) -> torch.Tensor:
+         return gamma * torch.where(x > 0, x, alpha * (torch.exp(x) - 1))
+
+     return builder.call_function(_custom_selu, args=(x, alpha, gamma))
+
+
+ register("Celu")(
+     unary_op_with_kwargs(
+         F.celu,
+         attr_map={"alpha": ("alpha", 1.0)},
+         doc="Continuously Differentiable Exponential Linear Unit activation.",
+     )
+ )
+
+
+ register("Sigmoid")(unary_op(torch.sigmoid, "Sigmoid activation."))
+
+
+ @register("HardSigmoid")
+ def hard_sigmoid(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Hard Sigmoid activation.
+
+     ONNX HardSigmoid: max(0, min(1, alpha * x + beta))
+     ONNX defaults: alpha=0.2, beta=0.5
+     PyTorch hardsigmoid uses fixed alpha=1/6, beta=0.5
+     """
+     x = builder.get_value(node.input[0])
+     alpha = get_attribute(node, "alpha", 0.2)
+     beta = get_attribute(node, "beta", 0.5)
+
+     # PyTorch hardsigmoid uses alpha=1/6 ≈ 0.16667, beta=0.5
+     pytorch_alpha = 1.0 / 6.0
+
+     # If using PyTorch's fixed values (within tolerance), use F.hardsigmoid
+     if abs(alpha - pytorch_alpha) < 1e-5 and abs(beta - 0.5) < 1e-5:
+         return builder.call_function(F.hardsigmoid, args=(x,))
+
+     # Otherwise implement manually: max(0, min(1, alpha * x + beta))
+     def _custom_hardsigmoid(x: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
+         return torch.clamp(alpha * x + beta, 0.0, 1.0)
+
+     return builder.call_function(_custom_hardsigmoid, args=(x, alpha, beta))
+
+
+ register("Tanh")(unary_op(torch.tanh, "Tanh activation."))
+
+
+ register("Softmax", since_version=1)(
+     _coerced_softmax_handler(
+         F.softmax,
+         default_axis=1,
+         doc=(
+             "Softmax activation for opset 1-12 with axis defaulting to 1 and "
+             "2D coercion."
+         ),
+     )
+ )
+
+
+ register("Softmax", since_version=13)(
+     unary_op_with_kwargs(
+         F.softmax,
+         attr_map={"dim": ("axis", -1)},
+         doc=(
+             "Softmax activation for opset 13+ with axis defaulting to the last "
+             "dimension."
+         ),
+     )
+ )
+
+
+ register("LogSoftmax", since_version=1)(
+     _coerced_softmax_handler(
+         F.log_softmax,
+         default_axis=1,
+         doc=(
+             "Log Softmax activation for opset 1-12 with axis defaulting to 1 "
+             "and 2D coercion."
+         ),
+     )
+ )
+
+
+ register("LogSoftmax", since_version=13)(
+     unary_op_with_kwargs(
+         F.log_softmax,
+         attr_map={"dim": ("axis", -1)},
+         doc="Log Softmax activation for opset 13+.",
+     )
+ )
+
+
+ register("Softplus")(unary_op(F.softplus, "Softplus activation."))
+ register("Softsign")(unary_op(F.softsign, "Softsign activation."))
+
+
+ register("Gelu")(
+     unary_op_with_kwargs(
+         F.gelu,
+         attr_map={"approximate": ("approximate", "none")},
+         doc="Gaussian Error Linear Unit activation.",
+     )
+ )
+
+
+ register("Silu")(unary_op(F.silu, "Sigmoid Linear Unit (SiLU/Swish) activation."))
+ register("Swish")(unary_op(F.silu, "Swish activation (alias for SiLU)."))
+ register("Mish")(unary_op(F.mish, "Mish activation."))
+
+
+ register("ThresholdedRelu")(
+     unary_op_with_kwargs(
+         F.threshold,
+         attr_map={"threshold": ("alpha", 1.0)},
+         fixed_kwargs={"value": 0.0},
+         doc="Thresholded ReLU activation.",
+     )
+ )
+
+
+ register("HardSwish")(unary_op(F.hardswish, "Hard Swish activation."))
+
+
+ @register("Hardmax")
+ def hardmax(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Hardmax - one-hot encoding of argmax."""
+     x = builder.get_value(node.input[0])
+     axis = get_attribute(node, "axis", -1)
+
+     def _hardmax(t: torch.Tensor, ax: int) -> torch.Tensor:
+         # Normalize axis to positive
+         if ax < 0:
+             ax = t.dim() + ax
+         # one_hot appends the class dimension at the end
+         one_hot = torch.nn.functional.one_hot(
+             torch.argmax(t, dim=ax), num_classes=t.shape[ax]
+         ).to(t.dtype)
+         # Move the one-hot dimension from the end back to the original axis position
+         # one_hot has shape: [...dims before ax..., ...dims after ax..., num_classes]
+         # We need to move the last dim to position ax
+         return torch.movedim(one_hot, -1, ax)
+
+     return builder.call_function(_hardmax, args=(x, axis))
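+
+ # For example (illustrative): hardmax of [[1.0, 3.0, 2.0]] with axis=-1 yields
+ # [[0.0, 1.0, 0.0]], a one-hot vector marking the argmax position.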
+
+
+ @register("Shrink")
+ def shrink(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Shrink activation.
+
+     If x < -lambd: y = x + bias
+     If x > lambd: y = x - bias
+     Otherwise: y = 0
+     """
+     x = builder.get_value(node.input[0])
+     bias = get_attribute(node, "bias", 0.0)
+     lambd = get_attribute(node, "lambd", 0.5)
+
+     def _shrink(t: torch.Tensor, bias: float, lambd: float) -> torch.Tensor:
+         result = torch.zeros_like(t)
+         mask_neg = t < -lambd
+         mask_pos = t > lambd
+         result = torch.where(mask_neg, t + bias, result)
+         result = torch.where(mask_pos, t - bias, result)
+         return result
+
+     return builder.call_function(_shrink, args=(x, bias, lambd))
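+
+ # For example (illustrative), with the defaults bias=0.0 and lambd=0.5:
+ # shrink([-1.0, 0.2, 1.0]) -> [-1.0, 0.0, 1.0]; with bias=0.1 the non-zero
+ # outputs shift toward zero: [-0.9, 0.0, 0.9].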