onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/linalg.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Linear algebra operators."""
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ..op_registry import register
|
|
10
|
+
from ..utils.attributes import get_attribute
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from ..graph_builder import GraphBuilder
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# =============================================================================
|
|
17
|
+
# Linear algebra operators
|
|
18
|
+
# =============================================================================
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register("Einsum")
def einsum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower ONNX ``Einsum`` to :func:`torch.einsum`.

    The ``equation`` attribute is forwarded unchanged as the first positional
    argument, followed by every node input in declaration order.
    """
    eqn = get_attribute(node, "equation")
    operands = tuple(map(builder.get_value, node.input))
    return builder.call_function(torch.einsum, args=(eqn,) + operands)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@register("Det")
def det(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower ONNX ``Det`` to :func:`torch.linalg.det`.

    Only the first node input is read; it is passed through as the single
    positional argument.
    """
    matrix = builder.get_value(node.input[0])
    call_args = (matrix,)
    return builder.call_function(torch.linalg.det, args=call_args)
|
onnx2fx/ops/loss.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Loss function operators."""
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ..op_registry import register
|
|
10
|
+
from ..utils.attributes import get_attribute
|
|
11
|
+
from ..utils.op_helpers import get_optional_input
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ..graph_builder import GraphBuilder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register("SoftmaxCrossEntropyLoss")
def softmax_cross_entropy_loss(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Softmax cross entropy loss.

    Lowered to :func:`torch.nn.functional.cross_entropy`. Reads ``scores``
    (input 0), ``labels`` (input 1) and optional per-class ``weights``
    (input 2). The attributes ``ignore_index`` (default -100 here) and
    ``reduction`` (default ``"mean"``) are forwarded unchanged.

    ONNX also declares an optional second output, ``log_prob``. When it is
    requested, the emitted node returns the tuple ``(loss, log_prob)``, where
    ``log_prob = log_softmax(scores, dim=1)``. Otherwise the node returns the
    loss alone.
    """
    scores = builder.get_value(node.input[0])
    labels = builder.get_value(node.input[1])
    weights = get_optional_input(builder, node, 2)

    ignore_index = get_attribute(node, "ignore_index", -100)
    reduction = get_attribute(node, "reduction", "mean")

    # Same positional-presence check as used for Dropout's optional mask output.
    return_log_prob = len(node.output) > 1 and node.output[1] != ""

    if not return_log_prob:
        kwargs = {"ignore_index": ignore_index, "reduction": reduction}
        if weights is not None:
            kwargs["weight"] = weights
        return builder.call_function(
            torch.nn.functional.cross_entropy, args=(scores, labels), kwargs=kwargs
        )

    def _cross_entropy_with_log_prob(scores, labels, ignore_index, reduction, weight):
        # cross_entropy(x) == nll_loss(log_softmax(x, dim=1)). Computing the
        # log-probabilities explicitly lets us return them as the second output.
        log_prob = torch.nn.functional.log_softmax(scores, dim=1)
        loss = torch.nn.functional.nll_loss(
            log_prob,
            labels,
            weight=weight,
            ignore_index=ignore_index,
            reduction=reduction,
        )
        return loss, log_prob

    return builder.call_function(
        _cross_entropy_with_log_prob,
        args=(scores, labels, ignore_index, reduction, weights),
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@register("NegativeLogLikelihoodLoss")
def negative_log_likelihood_loss(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Lower ONNX ``NegativeLogLikelihoodLoss`` to :func:`F.nll_loss`.

    Reads the input (0), the target (1) and an optional class weight (2).
    ``ignore_index`` (default -100) and ``reduction`` (default ``"mean"``)
    are passed as keyword arguments. ``weight`` is only passed when the
    optional input is present.
    """
    log_probs = builder.get_value(node.input[0])
    target = builder.get_value(node.input[1])
    weight = get_optional_input(builder, node, 2)

    loss_kwargs = {
        "ignore_index": get_attribute(node, "ignore_index", -100),
        "reduction": get_attribute(node, "reduction", "mean"),
    }
    if weight is not None:
        loss_kwargs["weight"] = weight

    return builder.call_function(
        torch.nn.functional.nll_loss,
        args=(log_probs, target),
        kwargs=loss_kwargs,
    )
|
onnx2fx/ops/nn.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Neural network layer operators.
|
|
3
|
+
|
|
4
|
+
This module contains core neural network operators like MatMul, Gemm, and Dropout.
|
|
5
|
+
Other neural network operators are organized in specialized modules:
|
|
6
|
+
- convolution.py: Conv, ConvTranspose, DeformConv
|
|
7
|
+
- pooling.py: MaxPool, AveragePool, GlobalAveragePool, etc.
|
|
8
|
+
- normalization.py: BatchNormalization, LayerNormalization, etc.
|
|
9
|
+
- recurrent.py: LSTM, GRU, RNN
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
import onnx
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
from ..op_registry import register
|
|
18
|
+
from ..utils.attributes import get_attribute
|
|
19
|
+
from ..utils.op_helpers import get_optional_input
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from ..graph_builder import GraphBuilder
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# =============================================================================
|
|
26
|
+
# Matrix multiplication operators
|
|
27
|
+
# =============================================================================
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@register("MatMul")
def matmul(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower ONNX ``MatMul`` to a single :func:`torch.matmul` call on inputs 0 and 1."""
    lhs = builder.get_value(node.input[0])
    rhs = builder.get_value(node.input[1])
    operands = (lhs, rhs)
    return builder.call_function(torch.matmul, args=operands)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@register("Gemm")
def gemm(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """General Matrix Multiplication: Y = alpha * A' * B' + beta * C.

    ``A'``/``B'`` are ``A``/``B`` transposed when ``transA``/``transB`` are set.
    ``C`` (input 2) is optional. Defaults: ``alpha=1.0``, ``beta=1.0``,
    ``transA=0``, ``transB=0``.

    The result keeps the dtype of ``A' @ B'``. ONNX Gemm also accepts integer
    tensors, and multiplying those by a Python float would otherwise promote
    the output to floating point.
    """
    a = builder.get_value(node.input[0])
    b = builder.get_value(node.input[1])

    alpha = get_attribute(node, "alpha", 1.0)
    beta = get_attribute(node, "beta", 1.0)
    trans_a = get_attribute(node, "transA", 0)
    trans_b = get_attribute(node, "transB", 0)

    def _gemm(a, b, c, alpha, beta, trans_a, trans_b):
        if trans_a:
            a = a.T
        if trans_b:
            b = b.T
        result = torch.matmul(a, b)
        out_dtype = result.dtype
        # Skip identity scaling. `1.0 * int_tensor` would promote to float.
        if alpha != 1.0:
            result = alpha * result
        # beta == 0 means C does not contribute. Skipping it also avoids
        # 0 * inf = nan from non-finite C values.
        if c is not None and beta != 0.0:
            result = result + (c if beta == 1.0 else beta * c)
        # Restore the input dtype if a non-unit alpha/beta promoted it.
        return result.to(out_dtype)

    c = get_optional_input(builder, node, 2)

    return builder.call_function(_gemm, args=(a, b, c, alpha, beta, trans_a, trans_b))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# =============================================================================
|
|
65
|
+
# Dropout and regularization
|
|
66
|
+
# =============================================================================
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@register("Dropout")
def dropout(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Dropout, lowered for inference (identity).

    ONNX Dropout declares up to two outputs:
      - output: the input, unchanged.
      - mask (optional): a boolean tensor shaped like the input, all True.

    The ``ratio`` attribute and the ``training_mode`` input are not read here,
    so no elements are ever dropped.
    """
    x = builder.get_value(node.input[0])

    def _dropout_with_mask(x):
        # Identity output paired with an all-True keep-mask of the same shape.
        return x, torch.ones_like(x, dtype=torch.bool)

    # The mask is requested only when a non-empty second output name exists.
    wants_mask = len(node.output) > 1 and node.output[1] != ""
    if not wants_mask:
        # Emitted as its own call node rather than returning `x` directly.
        return builder.call_function(lambda t: t, args=(x,))
    return builder.call_function(_dropout_with_mask, args=(x,))
|
|
onnx2fx/ops/normalization.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Normalization operators."""
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
from ..op_registry import register
|
|
11
|
+
from ..utils.attributes import get_attribute
|
|
12
|
+
from ..utils.dtype import stash_type_to_torch_dtype
|
|
13
|
+
from ..utils.op_helpers import get_optional_input
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from ..graph_builder import GraphBuilder
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# =============================================================================
|
|
20
|
+
# Normalization operators
|
|
21
|
+
# =============================================================================
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@register("LpNormalization")
def lp_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lp Normalization.

    Divides the input element-wise by its Lp norm, taken along ``axis`` with
    the reduced dimension kept for broadcasting.

    Attributes:
        axis: Axis to normalize over (default: -1).
        p: Norm order. Only 1 and 2 are supported (default: 2). Any other
            value raises ``ValueError`` when the emitted node executes.
    """
    x = builder.get_value(node.input[0])

    axis = get_attribute(node, "axis", -1)
    p = get_attribute(node, "p", 2)

    def _lp_normalize(x, axis, p):
        if p not in (1, 2):
            raise ValueError(f"LpNormalization only supports p=1 or p=2, got p={p}")
        if p == 1:
            # L1: sum of absolute values, floored at 1e-12 to avoid 0/0.
            # NOTE(review): the p=2 branch deliberately yields NaN for zero
            # vectors while this one does not. Confirm which is intended.
            denom = torch.abs(x).sum(dim=axis, keepdim=True).clamp(min=1e-12)
        else:
            # L2: no epsilon, so a zero vector produces NaN (0/0). F.normalize
            # is avoided because it would return 0 instead.
            denom = (x * x).sum(dim=axis, keepdim=True).sqrt()
        return x / denom

    return builder.call_function(_lp_normalize, args=(x, axis, p))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@register("BatchNormalization")
def batch_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Batch normalization.

    Inputs (all read as required): X, scale, B, input_mean, input_var.
    Attributes: ``epsilon`` (default 1e-5), ``momentum`` (default 0.9, used
    only for the running-statistics outputs) and ``training_mode`` (default 0).

    * ``training_mode == 0``: normalizes X with the given mean/var using
      :func:`F.batch_norm` in evaluation mode.
    * ``training_mode != 0``: normalizes X with its own batch statistics. The
      given mean/var tensors are NOT passed to ``F.batch_norm``. In training
      mode it would update them in place on every call, using PyTorch's
      momentum convention rather than ONNX's.

    When training and more than one output is requested, the node returns
    ``(Y, running_mean, running_var)``, following the ONNX definition:
    ``running = input * momentum + batch_stat * (1 - momentum)``, where the
    batch variance is the population (biased) variance.
    """
    x = builder.get_value(node.input[0])
    scale = builder.get_value(node.input[1])
    bias = builder.get_value(node.input[2])
    mean = builder.get_value(node.input[3])
    var = builder.get_value(node.input[4])

    epsilon = get_attribute(node, "epsilon", 1e-5)
    momentum = get_attribute(node, "momentum", 0.9)
    training_mode = get_attribute(node, "training_mode", 0)

    num_outputs = len([o for o in node.output if o])

    def _batch_norm(x, scale, bias, mean, var, epsilon, training_mode):
        if training_mode:
            # No running stats: prevents in-place mutation of mean/var.
            return F.batch_norm(
                x, None, None, weight=scale, bias=bias, training=True, eps=epsilon
            )
        return F.batch_norm(
            x, mean, var, weight=scale, bias=bias, training=False, eps=epsilon
        )

    def _batch_norm_training_with_stats(x, scale, bias, mean, var, epsilon, momentum):
        # Reduce over every axis except the channel axis (dim 1).
        dims = [d for d in range(x.dim()) if d != 1]
        batch_mean = x.mean(dim=dims)
        batch_var = x.var(dim=dims, unbiased=False)
        y = F.batch_norm(
            x, None, None, weight=scale, bias=bias, training=True, eps=epsilon
        )
        running_mean = mean * momentum + batch_mean * (1 - momentum)
        running_var = var * momentum + batch_var * (1 - momentum)
        return y, running_mean, running_var

    if training_mode and num_outputs > 1:
        return builder.call_function(
            _batch_norm_training_with_stats,
            args=(x, scale, bias, mean, var, epsilon, momentum),
        )

    return builder.call_function(
        _batch_norm, args=(x, scale, bias, mean, var, epsilon, training_mode)
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@register("LayerNormalization")
def layer_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Layer normalization over dimensions ``axis`` .. last.

    ONNX LayerNormalization returns up to 3 outputs:
      - Y: normalized output (required)
      - Mean: mean values (optional)
      - InvStdDev: inverse standard deviation (optional)

    When exactly one output name is non-empty, the node calls
    :func:`F.layer_norm` and returns Y alone. Otherwise it returns the tuple
    ``(Y, Mean, InvStdDev)``. In that case the statistics are computed in the
    ``stash_type`` dtype and stay in it, while Y is cast back to the input dtype.
    """
    x = builder.get_value(node.input[0])
    scale = builder.get_value(node.input[1])

    bias = get_optional_input(builder, node, 2)

    axis = get_attribute(node, "axis", -1)
    epsilon = get_attribute(node, "epsilon", 1e-5)
    stash_type = get_attribute(node, "stash_type", 1)  # 1 = float32

    requested = sum(1 for name in node.output if name)

    def _layer_norm_single(x, scale, bias, axis, epsilon):
        # Normalize over every dimension from `axis` to the last one.
        start = axis + x.dim() if axis < 0 else axis
        return F.layer_norm(x, x.shape[start:], weight=scale, bias=bias, eps=epsilon)

    def _layer_norm_with_stats(x, scale, bias, axis, epsilon, stash_type):
        start = axis + x.dim() if axis < 0 else axis
        compute_dtype = stash_type_to_torch_dtype(stash_type)
        out_dtype = x.dtype
        xs = x.to(compute_dtype)

        # Population statistics over the normalized dims, kept for broadcasting.
        reduce_dims = list(range(start, x.dim()))
        mean = xs.mean(dim=reduce_dims, keepdim=True)
        variance = xs.var(dim=reduce_dims, unbiased=False, keepdim=True)
        inv_std_dev = 1.0 / torch.sqrt(variance + epsilon)

        y = (xs - mean) * inv_std_dev
        if scale is not None:
            y = y * scale.to(compute_dtype)
        if bias is not None:
            y = y + bias.to(compute_dtype)

        return y.to(out_dtype), mean, inv_std_dev

    if requested == 1:
        return builder.call_function(
            _layer_norm_single, args=(x, scale, bias, axis, epsilon)
        )
    return builder.call_function(
        _layer_norm_with_stats, args=(x, scale, bias, axis, epsilon, stash_type)
    )
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@register("RMSNormalization", since_version=23)
def rms_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """RMS Normalization (Root Mean Square Layer Normalization).

    LayerNormalization without mean subtraction:
        Y = X / sqrt(mean(X^2 over dims axis..last) + epsilon) * scale

    Inputs:
        X: Input tensor.
        scale: Scale tensor, multiplied element-wise with broadcasting.

    Attributes:
        axis: First normalization dimension (default: -1).
        epsilon: Added inside the square root (default: 1e-5).
        stash_type: Computation dtype code (default: 1 = float32). The result
            is cast back to X's dtype.
    """
    x = builder.get_value(node.input[0])
    scale = builder.get_value(node.input[1])

    axis = get_attribute(node, "axis", -1)
    epsilon = get_attribute(node, "epsilon", 1e-5)
    stash_type = get_attribute(node, "stash_type", 1)

    def _rms_norm(x, scale, axis, epsilon, stash_type):
        compute_dtype = stash_type_to_torch_dtype(stash_type)
        first = x.dim() + axis if axis < 0 else axis
        out_dtype = x.dtype

        xs = x.to(compute_dtype)
        reduce_dims = list(range(first, x.dim()))
        rms = torch.sqrt(xs.pow(2).mean(dim=reduce_dims, keepdim=True) + epsilon)

        return (xs / rms * scale.to(compute_dtype)).to(out_dtype)

    return builder.call_function(_rms_norm, args=(x, scale, axis, epsilon, stash_type))
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@register("InstanceNormalization")
def instance_normalization(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Lower ONNX ``InstanceNormalization`` to :func:`F.instance_norm`.

    Inputs 0..2 are X, scale and bias. ``epsilon`` defaults to 1e-5.
    """
    x, scale, bias = [builder.get_value(node.input[i]) for i in range(3)]
    epsilon = get_attribute(node, "epsilon", 1e-5)

    def _instance_norm(x, scale, bias, epsilon):
        return F.instance_norm(x, weight=scale, bias=bias, eps=epsilon)

    return builder.call_function(_instance_norm, args=(x, scale, bias, epsilon))
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@register("GroupNormalization")
def group_normalization(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Group normalization.

    Lowered to :func:`F.group_norm`, which requires per-channel weight and
    bias of shape ``(C,)``. GroupNormalization-18 declares scale and bias with
    shape ``(num_groups,)``; opset 21 changed them to ``(C,)``. When a given
    scale/bias has ``num_groups`` elements and ``C != num_groups``, each
    per-group value is repeated for the ``C // num_groups`` channels of its
    group. Per-channel tensors are passed through unchanged.
    """
    x = builder.get_value(node.input[0])
    scale = builder.get_value(node.input[1])
    bias = builder.get_value(node.input[2])

    epsilon = get_attribute(node, "epsilon", 1e-5)
    num_groups = get_attribute(node, "num_groups")

    def _group_norm(x, scale, bias, num_groups, epsilon):
        channels = x.shape[1]
        if channels != num_groups:
            # Expand opset-18 style per-group parameters to per-channel ones.
            channels_per_group = channels // num_groups
            if scale.numel() == num_groups:
                scale = scale.repeat_interleave(channels_per_group)
            if bias.numel() == num_groups:
                bias = bias.repeat_interleave(channels_per_group)
        return F.group_norm(x, num_groups, weight=scale, bias=bias, eps=epsilon)

    return builder.call_function(
        _group_norm, args=(x, scale, bias, num_groups, epsilon)
    )
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@register("LRN")
def lrn(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Local Response Normalization.

    ONNX definition:
        Y = X / (bias + alpha / size * square_sum) ** beta
    where ``square_sum[c]`` sums ``X**2`` over the channel window
    ``[c - floor((size - 1) / 2), c + ceil((size - 1) / 2)]``, clipped to the
    valid channel range.

    For odd ``size`` the window is symmetric and ``F.local_response_norm``
    matches, so it is used directly. For even ``size``, PyTorch puts the
    larger half of the window in front of ``c`` (``size // 2`` channels) while
    ONNX puts it behind. That case is therefore computed explicitly.
    Defaults: ``alpha=1e-4``, ``beta=0.75``, ``bias=1.0``. ``size`` is required.
    """
    x = builder.get_value(node.input[0])

    alpha = get_attribute(node, "alpha", 0.0001)
    beta = get_attribute(node, "beta", 0.75)
    bias = get_attribute(node, "bias", 1.0)
    size = get_attribute(node, "size")

    if size % 2 == 1:
        return builder.call_function(
            F.local_response_norm,
            args=(x, size),
            kwargs={"alpha": alpha, "beta": beta, "k": bias},
        )

    def _lrn_even(x, size, alpha, beta, bias):
        front = (size - 1) // 2  # floor((size - 1) / 2)
        back = size - 1 - front  # ceil((size - 1) / 2)
        # Move channels to the last axis so pad/unfold operate on them.
        squares = (x * x).transpose(1, -1)
        squares = F.pad(squares, (front, back))
        # Each length-`size` window over the padded channels yields one sum.
        square_sum = squares.unfold(-1, size, 1).sum(dim=-1).transpose(1, -1)
        return x / (bias + alpha / size * square_sum) ** beta

    return builder.call_function(_lrn_even, args=(x, size, alpha, beta, bias))
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@register("MeanVarianceNormalization")
def mean_variance_normalization(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Mean Variance Normalization.

    Computes ``(X - E[X]) / (sqrt(E[(X - E[X])^2]) + 1e-9)`` over ``axes``.
    The default axes ``[0, 2, 3]`` normalize across N, H and W for NCHW input.

    Epsilon is added to the standard deviation, not to the variance, as in
    the ONNX function body (``Processed_STD = Add(STD, Epsilon)``). The
    variance is computed from the centered input, which stays non-negative.
    """
    x = builder.get_value(node.input[0])
    axes = get_attribute(node, "axes", [0, 2, 3])

    def _mvn(x, axes):
        axes = tuple(axes)
        eps = 1e-9
        mean = x.mean(dim=axes, keepdim=True)
        centered = x - mean
        std = torch.sqrt((centered**2).mean(dim=axes, keepdim=True))
        return centered / (std + eps)

    return builder.call_function(_mvn, args=(x, axes))
|