onnx2fx-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2fx/ops/linalg.py ADDED
@@ -0,0 +1,33 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Linear algebra operators."""
+
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ # =============================================================================
+ # Linear algebra operators
+ # =============================================================================
+
+
+ @register("Einsum")
+ def einsum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Einstein summation."""
+     equation = get_attribute(node, "equation")
+     inputs = [builder.get_value(name) for name in node.input]
+     return builder.call_function(torch.einsum, args=(equation, *inputs))
+
+
+ @register("Det")
+ def det(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Matrix determinant."""
+     x = builder.get_value(node.input[0])
+     return builder.call_function(torch.linalg.det, args=(x,))
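The Einsum handler passes the ONNX equation string straight through to torch.einsum, and Det maps to torch.linalg.det. A standalone sketch of those two PyTorch calls (not code from the wheel):

import torch

# "ij,jk->ik" is plain matrix multiplication.
a, b = torch.randn(2, 3), torch.randn(3, 4)
assert torch.allclose(torch.einsum("ij,jk->ik", a, b), a @ b)

# torch.linalg.det reduces the last two dimensions, so it batches naturally.
m = torch.randn(5, 3, 3)
print(torch.linalg.det(m).shape)  # torch.Size([5])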
onnx2fx/ops/loss.py ADDED
@@ -0,0 +1,56 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Loss function operators."""
+
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.op_helpers import get_optional_input
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ @register("SoftmaxCrossEntropyLoss")
+ def softmax_cross_entropy_loss(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Softmax cross entropy loss."""
+     scores = builder.get_value(node.input[0])
+     labels = builder.get_value(node.input[1])
+     weights = get_optional_input(builder, node, 2)
+
+     ignore_index = get_attribute(node, "ignore_index", -100)
+     reduction = get_attribute(node, "reduction", "mean")
+
+     kwargs = {"ignore_index": ignore_index, "reduction": reduction}
+     if weights is not None:
+         kwargs["weight"] = weights
+
+     return builder.call_function(
+         torch.nn.functional.cross_entropy, args=(scores, labels), kwargs=kwargs
+     )
+
+
+ @register("NegativeLogLikelihoodLoss")
+ def negative_log_likelihood_loss(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Negative log likelihood loss."""
+     input_node = builder.get_value(node.input[0])
+     target = builder.get_value(node.input[1])
+     weight = get_optional_input(builder, node, 2)
+
+     ignore_index = get_attribute(node, "ignore_index", -100)
+     reduction = get_attribute(node, "reduction", "mean")
+
+     kwargs = {"ignore_index": ignore_index, "reduction": reduction}
+     if weight is not None:
+         kwargs["weight"] = weight
+
+     return builder.call_function(
+         torch.nn.functional.nll_loss, args=(input_node, target), kwargs=kwargs
+     )
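For context (a standalone illustration, not part of the package): ONNX SoftmaxCrossEntropyLoss takes unnormalized scores while NegativeLogLikelihoodLoss expects log-probabilities, which is why the two handlers above map to F.cross_entropy and F.nll_loss. The two PyTorch functionals are related as follows:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # (N, C) unnormalized scores
labels = torch.randint(0, 10, (4,))  # (N,) class indices

# cross_entropy is log_softmax followed by nll_loss.
ce = F.cross_entropy(logits, labels, reduction="mean")
nll = F.nll_loss(F.log_softmax(logits, dim=1), labels, reduction="mean")
assert torch.allclose(ce, nll)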
onnx2fx/ops/nn.py ADDED
@@ -0,0 +1,96 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Neural network layer operators.
+
+ This module contains core neural network operators like MatMul, Gemm, and Dropout.
+ Other neural network operators are organized in specialized modules:
+ - convolution.py: Conv, ConvTranspose, DeformConv
+ - pooling.py: MaxPool, AveragePool, GlobalAveragePool, etc.
+ - normalization.py: BatchNormalization, LayerNormalization, etc.
+ - recurrent.py: LSTM, GRU, RNN
+ """
+
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.op_helpers import get_optional_input
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ # =============================================================================
+ # Matrix multiplication operators
+ # =============================================================================
+
+
+ @register("MatMul")
+ def matmul(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Matrix multiplication."""
+     a = builder.get_value(node.input[0])
+     b = builder.get_value(node.input[1])
+     return builder.call_function(torch.matmul, args=(a, b))
+
+
+ @register("Gemm")
+ def gemm(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """General Matrix Multiplication: Y = alpha * A' * B' + beta * C."""
+     a = builder.get_value(node.input[0])
+     b = builder.get_value(node.input[1])
+
+     alpha = get_attribute(node, "alpha", 1.0)
+     beta = get_attribute(node, "beta", 1.0)
+     trans_a = get_attribute(node, "transA", 0)
+     trans_b = get_attribute(node, "transB", 0)
+
+     def _gemm(a, b, c, alpha, beta, trans_a, trans_b):
+         if trans_a:
+             a = a.T
+         if trans_b:
+             b = b.T
+         result = alpha * torch.matmul(a, b)
+         if c is not None:
+             result = result + beta * c
+         return result
+
+     c = get_optional_input(builder, node, 2)
+
+     return builder.call_function(_gemm, args=(a, b, c, alpha, beta, trans_a, trans_b))
+
+
+ # =============================================================================
+ # Dropout and regularization
+ # =============================================================================
+
+
+ @register("Dropout")
+ def dropout(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Dropout (inference mode - identity).
+
+     ONNX Dropout can have 2 outputs:
+     - output: The result after dropout (same as input in inference mode)
+     - mask (optional): Boolean mask indicating which elements were kept (all True in inference mode)
+     """
+     x = builder.get_value(node.input[0])
+
+     # Check if mask output is requested (second output)
+     return_mask = len(node.output) > 1 and node.output[1] != ""
+
+     # In inference mode, dropout is identity
+     # ratio = get_attribute(node, "ratio", 0.5)
+     # training_mode from input or default to False
+
+     def _dropout_with_mask(x):
+         # In inference mode, output is identity and mask is all True
+         output = x
+         mask = torch.ones_like(x, dtype=torch.bool)
+         return output, mask
+
+     if return_mask:
+         return builder.call_function(_dropout_with_mask, args=(x,))
+     else:
+         # For inference without mask, just return input
+         return builder.call_function(lambda t: t, args=(x,))
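A standalone restatement of the Gemm arithmetic above (mirroring, not importing, the _gemm closure), handy for checking the transA/transB and alpha/beta handling against the ONNX definition Y = alpha * A' * B' + beta * C:

import torch

def gemm_ref(a, b, c=None, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    # Same logic as the _gemm closure: optional transposes, scaled matmul, optional bias.
    if trans_a:
        a = a.T
    if trans_b:
        b = b.T
    y = alpha * (a @ b)
    return y if c is None else y + beta * c

a, b, c = torch.randn(3, 2), torch.randn(3, 4), torch.randn(2, 4)
y = gemm_ref(a, b, c, alpha=0.5, beta=2.0, trans_a=1)
assert y.shape == (2, 4)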
onnx2fx/ops/normalization.py ADDED
@@ -0,0 +1,289 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Normalization operators."""
+
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+ import torch.nn.functional as F
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.dtype import stash_type_to_torch_dtype
+ from ..utils.op_helpers import get_optional_input
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ # =============================================================================
+ # Normalization operators
+ # =============================================================================
+
+
+ @register("LpNormalization")
+ def lp_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Lp Normalization.
+
+     Normalizes input element-wise by dividing by the Lp norm along the specified axis.
+
+     Attributes:
+         axis: The axis on which to apply normalization (default: -1)
+         p: The order of the normalization, only 1 or 2 are supported (default: 2)
+     """
+     x = builder.get_value(node.input[0])
+
+     axis = get_attribute(node, "axis", -1)
+     p = get_attribute(node, "p", 2)
+
+     def _lp_normalize(x, axis, p):
+         if p == 1:
+             # L1 normalization: x / sum(|x|)
+             norm = torch.sum(torch.abs(x), dim=axis, keepdim=True)
+             # Avoid division by zero
+             norm = torch.clamp(norm, min=1e-12)
+             return x / norm
+         elif p == 2:
+             # L2 normalization: x / sqrt(sum(x^2))
+             # Note: We don't use F.normalize because it returns 0 for zero vectors,
+             # but ONNX expects NaN (0/0 behavior)
+             norm = torch.sqrt(torch.sum(x * x, dim=axis, keepdim=True))
+             return x / norm
+         else:
+             raise ValueError(f"LpNormalization only supports p=1 or p=2, got p={p}")
+
+     return builder.call_function(_lp_normalize, args=(x, axis, p))
+
+
+ @register("BatchNormalization")
+ def batch_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Batch normalization."""
+     x = builder.get_value(node.input[0])
+     scale = builder.get_value(node.input[1])
+     bias = builder.get_value(node.input[2])
+     mean = builder.get_value(node.input[3])
+     var = builder.get_value(node.input[4])
+
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+     # Note: ONNX momentum attribute is not used in inference mode
+     training_mode = get_attribute(node, "training_mode", 0)
+
+     def _batch_norm(x, scale, bias, mean, var, epsilon, training_mode):
+         return F.batch_norm(
+             x,
+             mean,
+             var,
+             weight=scale,
+             bias=bias,
+             training=bool(training_mode),
+             eps=epsilon,
+         )
+
+     return builder.call_function(
+         _batch_norm, args=(x, scale, bias, mean, var, epsilon, training_mode)
+     )
+
+
+ @register("LayerNormalization")
+ def layer_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Layer normalization.
+
+     ONNX LayerNormalization returns up to 3 outputs:
+     - Y: normalized output (required)
+     - Mean: mean values (optional)
+     - InvStdDev: inverse standard deviation (optional)
+     """
+     x = builder.get_value(node.input[0])
+     scale = builder.get_value(node.input[1])
+
+     bias = get_optional_input(builder, node, 2)
+
+     axis = get_attribute(node, "axis", -1)
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+     stash_type = get_attribute(node, "stash_type", 1)  # 1 = float32
+
+     # Check how many outputs are requested
+     num_outputs = len([o for o in node.output if o])
+
+     def _layer_norm_single(x, scale, bias, axis, epsilon):
+         # Compute normalized shape from axis
+         if axis < 0:
+             axis = x.dim() + axis
+         normalized_shape = x.shape[axis:]
+         return F.layer_norm(x, normalized_shape, weight=scale, bias=bias, eps=epsilon)
+
+     def _layer_norm_with_stats(x, scale, bias, axis, epsilon, stash_type):
+         # Compute normalized shape from axis
+         if axis < 0:
+             axis = x.dim() + axis
+
+         # Determine stash dtype for mean/invstddev computation
+         stash_dtype = stash_type_to_torch_dtype(stash_type)
+
+         # Cast input to stash dtype for computing statistics
+         original_dtype = x.dtype
+         x_stash = x.to(stash_dtype)
+
+         # Compute mean and variance over the normalized dimensions
+         dims = list(range(axis, x.dim()))
+         mean = x_stash.mean(dim=dims, keepdim=True)
+         var = x_stash.var(dim=dims, unbiased=False, keepdim=True)
+         inv_std_dev = 1.0 / torch.sqrt(var + epsilon)
+
+         # Normalize
+         x_norm = (x_stash - mean) * inv_std_dev
+
+         # Apply scale and bias
+         if scale is not None:
+             x_norm = x_norm * scale.to(stash_dtype)
+         if bias is not None:
+             x_norm = x_norm + bias.to(stash_dtype)
+
+         # Cast back to original dtype
+         y = x_norm.to(original_dtype)
+
+         return (y, mean, inv_std_dev)
+
+     if num_outputs == 1:
+         return builder.call_function(
+             _layer_norm_single, args=(x, scale, bias, axis, epsilon)
+         )
+     else:
+         return builder.call_function(
+             _layer_norm_with_stats, args=(x, scale, bias, axis, epsilon, stash_type)
+         )
+
+
+ @register("RMSNormalization", since_version=23)
+ def rms_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """RMS Normalization (Root Mean Square Layer Normalization).
+
+     This is LayerNormalization without mean subtraction, also known as RMSNorm.
+     Formula: Y = X / sqrt(mean(X^2) + epsilon) * scale
+
+     Inputs:
+         X: Input tensor
+         scale: Scale tensor (broadcastable to normalized shape)
+
+     Attributes:
+         axis: First normalization dimension (default: -1)
+         epsilon: Small constant for numerical stability (default: 1e-5)
+         stash_type: Floating-point precision for computation (default: 1 = float32)
+     """
+     x = builder.get_value(node.input[0])
+     scale = builder.get_value(node.input[1])
+
+     axis = get_attribute(node, "axis", -1)
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+     stash_type = get_attribute(node, "stash_type", 1)
+
+     def _rms_norm(x, scale, axis, epsilon, stash_type):
+         # Determine stash dtype for computation
+         stash_dtype = stash_type_to_torch_dtype(stash_type)
+
+         # Normalize axis
+         if axis < 0:
+             axis_pos = x.dim() + axis
+         else:
+             axis_pos = axis
+
+         # Save original dtype for casting back
+         original_dtype = x.dtype
+
+         # Cast to stash dtype for computation
+         x_stash = x.to(stash_dtype)
+
+         # Compute dimensions to reduce over (from axis to end)
+         dims = list(range(axis_pos, x.dim()))
+
+         # Compute RMS: sqrt(mean(x^2) + epsilon)
+         x_squared = x_stash.pow(2)
+         mean_squared = x_squared.mean(dim=dims, keepdim=True)
+         rms = torch.sqrt(mean_squared + epsilon)
+
+         # Normalize
+         x_normalized = x_stash / rms
+
+         # Apply scale
+         scale_stash = scale.to(stash_dtype)
+         y = x_normalized * scale_stash
+
+         # Cast back to original dtype
+         return y.to(original_dtype)
+
+     return builder.call_function(_rms_norm, args=(x, scale, axis, epsilon, stash_type))
+
+
+ @register("InstanceNormalization")
+ def instance_normalization(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Instance normalization."""
+     x = builder.get_value(node.input[0])
+     scale = builder.get_value(node.input[1])
+     bias = builder.get_value(node.input[2])
+
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+
+     def _instance_norm(x, scale, bias, epsilon):
+         return F.instance_norm(x, weight=scale, bias=bias, eps=epsilon)
+
+     return builder.call_function(_instance_norm, args=(x, scale, bias, epsilon))
+
+
+ @register("GroupNormalization")
+ def group_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Group normalization."""
+     x = builder.get_value(node.input[0])
+     scale = builder.get_value(node.input[1])
+     bias = builder.get_value(node.input[2])
+
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+     num_groups = get_attribute(node, "num_groups")
+
+     def _group_norm(x, scale, bias, num_groups, epsilon):
+         return F.group_norm(x, num_groups, weight=scale, bias=bias, eps=epsilon)
+
+     return builder.call_function(
+         _group_norm, args=(x, scale, bias, num_groups, epsilon)
+     )
+
+
+ @register("LRN")
+ def lrn(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Local Response Normalization."""
+     x = builder.get_value(node.input[0])
+
+     alpha = get_attribute(node, "alpha", 0.0001)
+     beta = get_attribute(node, "beta", 0.75)
+     bias = get_attribute(node, "bias", 1.0)
+     size = get_attribute(node, "size")
+
+     return builder.call_function(
+         F.local_response_norm,
+         args=(x, size),
+         kwargs={"alpha": alpha, "beta": beta, "k": bias},
+     )
+
+
+ @register("MeanVarianceNormalization")
+ def mean_variance_normalization(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Mean Variance Normalization.
+
+     Performs normalization using formula: (X - E[X]) / sqrt(E[(X - E[X])^2])
+     Default axes are [0, 2, 3] for NCHW format (normalize across N, H, W).
+     """
+     x = builder.get_value(node.input[0])
+     axes = get_attribute(node, "axes", [0, 2, 3])
+
+     def _mvn(x, axes):
+         axes = tuple(axes)
+         eps = 1e-9
+         mean = x.mean(dim=axes, keepdim=True)
+         variance = ((x - mean) ** 2).mean(dim=axes, keepdim=True)
+         std = torch.sqrt(variance + eps)
+         return (x - mean) / std
+
+     return builder.call_function(_mvn, args=(x, axes))
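A standalone check (not part of the package) of the RMSNorm formula implemented above, assuming last-axis normalization; the hasattr guard is there because torch.nn.functional.rms_norm only exists in newer PyTorch releases:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
scale = torch.randn(8)
eps = 1e-5

# Y = X / sqrt(mean(X^2) + epsilon) * scale, reduced over the normalized axis.
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
y_manual = x / rms * scale

if hasattr(F, "rms_norm"):
    y_builtin = F.rms_norm(x, (8,), weight=scale, eps=eps)
    assert torch.allclose(y_manual, y_builtin, atol=1e-6)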