onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Helper utilities for operator implementations.
|
|
3
|
+
|
|
4
|
+
This module provides factory functions and helpers to reduce boilerplate
|
|
5
|
+
in operator implementations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
9
|
+
|
|
10
|
+
import onnx
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
|
|
14
|
+
from ..exceptions import ConversionError
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ..graph_builder import GraphBuilder
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_optional_input(
    builder: "GraphBuilder",
    node: onnx.NodeProto,
    index: int,
    default: Any = None,
) -> Any:
    """Fetch an optional node input, falling back to *default* when absent.

    ONNX marks an omitted optional input either by truncating ``node.input``
    or by an empty-string placeholder; both cases yield *default*.

    Args:
        builder: The graph builder instance.
        node: The ONNX node.
        index: The input index to retrieve.
        default: Default value if input is not present.

    Returns:
        The input value or default.
    """
    inputs = node.input
    if index >= len(inputs) or not inputs[index]:
        return default
    return builder.get_value(inputs[index])
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_attribute_or_input(
    builder: "GraphBuilder",
    node: onnx.NodeProto,
    *,
    attr_name: str,
    input_index: int,
    opset_version: int,
    attr_allowed_until: Optional[int] = None,
    input_allowed_since: Optional[int] = None,
    default: Any = None,
    as_python: bool = True,
) -> Any:
    """Resolve a value that may arrive as an attribute or as a node input.

    Many ONNX operators migrated a value from an attribute (old opsets) to an
    input (new opsets); this helper reads whichever form is present and
    validates it against the active opset version.

    Args:
        builder: The graph builder instance.
        node: The ONNX node.
        attr_name: Attribute name to read.
        input_index: Input index to read.
        opset_version: Active opset version.
        attr_allowed_until: Highest opset version that allows the attribute.
        input_allowed_since: Lowest opset version that allows the input.
        default: Value to return when neither attribute nor input is provided.
        as_python: Convert constant tensors to Python scalars/lists.

    Returns:
        The resolved value, or default.

    Raises:
        ConversionError: If the attribute/input form is used outside the
            opset range that permits it.
    """
    from .attributes import get_attribute

    has_input = input_index < len(node.input) and bool(node.input[input_index])

    # Reject an input supplied before the opset that introduced it.
    if (
        has_input
        and input_allowed_since is not None
        and opset_version < input_allowed_since
    ):
        raise ConversionError(
            f"Input[{input_index}] for '{attr_name}' is not valid before "
            f"opset {input_allowed_since}",
            node_name=node.name,
            op_type=node.op_type,
        )

    # Attribute takes precedence when present, but only within its opset range.
    attr_value = get_attribute(node, attr_name)
    if attr_value is not None:
        if attr_allowed_until is not None and opset_version > attr_allowed_until:
            raise ConversionError(
                f"Attribute '{attr_name}' is not valid after opset "
                f"{attr_allowed_until}",
                node_name=node.name,
                op_type=node.op_type,
            )
        return attr_value

    if not has_input:
        return default

    resolved = _resolve_input_value(builder, node.input[input_index])
    return _as_python_value(resolved) if as_python else resolved
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _resolve_input_value(builder: "GraphBuilder", name: str) -> Any:
|
|
111
|
+
if name in builder.initializer_map:
|
|
112
|
+
return builder.initializer_map[name]
|
|
113
|
+
return builder.get_value(name)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _as_python_value(value: Any) -> Any:
|
|
117
|
+
if isinstance(value, torch.Tensor):
|
|
118
|
+
result = value.tolist()
|
|
119
|
+
if isinstance(result, int):
|
|
120
|
+
return [result]
|
|
121
|
+
return result
|
|
122
|
+
return value
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def unary_op(
    torch_fn: Callable[..., torch.Tensor],
    doc: Optional[str] = None,
) -> Callable[["GraphBuilder", onnx.NodeProto], torch.fx.Node]:
    """Build a handler for a one-input operator around *torch_fn*.

    The produced handler fetches the node's single input and emits a
    ``call_function`` FX node that applies *torch_fn* to it, removing the
    per-operator boilerplate.

    Args:
        torch_fn: The PyTorch function to call.
        doc: Optional docstring for the handler.

    Returns:
        A handler function for the operator.
    """

    def handler(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
        operand = builder.get_value(node.input[0])
        return builder.call_function(torch_fn, args=(operand,))

    handler.__doc__ = doc if doc else f"Element-wise {torch_fn.__name__}."
    return handler
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def binary_op(
    torch_fn: Callable[..., torch.Tensor],
    doc: Optional[str] = None,
) -> Callable[["GraphBuilder", onnx.NodeProto], torch.fx.Node]:
    """Build a handler for a two-input operator around *torch_fn*.

    The produced handler fetches the node's first two inputs and emits a
    ``call_function`` FX node applying *torch_fn* to them, removing the
    per-operator boilerplate.

    Args:
        torch_fn: The PyTorch function to call.
        doc: Optional docstring for the handler.

    Returns:
        A handler function for the operator.
    """

    def handler(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
        first = builder.get_value(node.input[0])
        second = builder.get_value(node.input[1])
        return builder.call_function(torch_fn, args=(first, second))

    handler.__doc__ = doc if doc else f"Element-wise {torch_fn.__name__}."
    return handler
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def unary_op_with_kwargs(
    torch_fn: Callable[..., torch.Tensor],
    *,
    attr_map: dict[str, tuple[str, Any]],
    fixed_kwargs: Optional[dict[str, Any]] = None,
    doc: Optional[str] = None,
) -> Callable[["GraphBuilder", onnx.NodeProto], torch.fx.Node]:
    """Build a one-input handler whose keyword args come from node attributes.

    Args:
        torch_fn: The PyTorch function to call.
        attr_map: Mapping of kwarg name to (attribute name, default).
        fixed_kwargs: Extra keyword arguments passed verbatim on every call.
        doc: Optional docstring for the handler.

    Returns:
        A handler function for the operator.
    """

    def handler(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
        from .attributes import get_attribute

        operand = builder.get_value(node.input[0])
        # Pull each mapped kwarg from the node's attributes (with its default),
        # then overlay any fixed kwargs.
        call_kwargs = {}
        for kwarg, (attr_name, attr_default) in attr_map.items():
            call_kwargs[kwarg] = get_attribute(node, attr_name, attr_default)
        if fixed_kwargs:
            call_kwargs.update(fixed_kwargs)
        return builder.call_function(torch_fn, args=(operand,), kwargs=call_kwargs)

    handler.__doc__ = doc if doc else f"Element-wise {torch_fn.__name__} with attributes."
    return handler
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def compute_same_padding(
    input_shape: tuple[int, ...],
    kernel_shape: tuple[int, ...],
    strides: tuple[int, ...],
    dilations: tuple[int, ...],
    mode: str,
    use_effective_kernel: bool = False,
) -> list[int]:
    """Compute SAME_UPPER or SAME_LOWER padding.

    This consolidates the repeated auto_pad handling logic from
    convolution.py and pooling.py.

    Args:
        input_shape: Spatial dimensions of input (excluding batch and channel).
        kernel_shape: Kernel dimensions.
        strides: Stride values.
        dilations: Dilation values.
        mode: Either "SAME_UPPER" or "SAME_LOWER".
        use_effective_kernel: Retained for backward compatibility; the
            "effective kernel" formulation ``(o-1)*st + ((k-1)*d + 1) - i``
            and the "standard" formulation ``(o-1)*st + (k-1)*d + 1 - i``
            are algebraically identical, so this flag does not change the
            result.

    Returns:
        A list of padding values in F.pad format (reversed order):
        [xn_begin, xn_end, ..., x1_begin, x1_end]
    """
    # SAME padding targets output_spatial[i] = ceil(input_spatial[i] / strides[i]).
    output_shape = [(s + st - 1) // st for s, st in zip(input_shape, strides)]

    # Total padding per dimension. The former use_effective_kernel branch
    # computed exactly the same expression (ek = (k-1)*d + 1), so a single
    # formula serves both callers.
    pad_total = [
        max(0, (o - 1) * st + (k - 1) * d + 1 - i)
        for i, o, k, st, d in zip(
            input_shape, output_shape, kernel_shape, strides, dilations
        )
    ]

    # Build pad list in F.pad format (reversed spatial order).
    pad_list: list[int] = []
    if mode == "SAME_UPPER":
        # SAME_UPPER: the odd extra unit of padding goes at the end (right/bottom).
        for p in reversed(pad_total):
            pad_list.extend([p // 2, p - p // 2])
    else:  # SAME_LOWER: the odd extra unit goes at the beginning (left/top).
        for p in reversed(pad_total):
            pad_list.extend([p - p // 2, p // 2])

    return pad_list
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def pad_list_to_onnx_pads(pad_list: list[int], ndim: int) -> list[int]:
    """Convert F.pad format to ONNX pads format.

    F.pad format: [xn_begin, xn_end, ..., x1_begin, x1_end] (reversed)
    ONNX format: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]

    Args:
        pad_list: Padding in F.pad format.
        ndim: Number of spatial dimensions.

    Returns:
        Padding in ONNX format.
    """
    begins: list[int] = []
    ends: list[int] = []
    for dim in range(ndim):
        # F.pad lists spatial dimensions last-first, two entries per dim.
        offset = 2 * (ndim - 1 - dim)
        begins.append(pad_list[offset])
        ends.append(pad_list[offset + 1])
    return begins + ends
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def apply_auto_pad(
    x: torch.Tensor,
    kernel_shape: tuple[int, ...],
    strides: tuple[int, ...],
    dilations: tuple[int, ...],
    auto_pad: str,
    pad_value: Union[int, float] = 0,
) -> tuple[torch.Tensor, int]:
    """Pad the input tensor explicitly when auto_pad requests SAME padding.

    Args:
        x: Input tensor with shape (N, C, *spatial_dims).
        kernel_shape: Kernel dimensions.
        strides: Stride values.
        dilations: Dilation values.
        auto_pad: Auto-pad mode ("NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID").
        pad_value: Value to use for padding.

    Returns:
        Tuple of (padded tensor, padding value for conv/pool operation).
        The returned padding value is always 0: either padding was applied
        explicitly here, or no auto-padding was requested.
    """
    if auto_pad not in ("SAME_UPPER", "SAME_LOWER"):
        # NOTSET or VALID: leave the tensor untouched.
        return x, 0

    spatial_dims = tuple(x.shape[2:])
    pads = compute_same_padding(
        spatial_dims, kernel_shape, strides, dilations, auto_pad
    )
    return F.pad(x, pads, value=pad_value), 0
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Training utilities for converted FX modules.
|
|
3
|
+
|
|
4
|
+
This module provides utilities to make converted ONNX models trainable
|
|
5
|
+
by converting buffers to trainable parameters.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.fx
|
|
10
|
+
|
|
11
|
+
from ..exceptions import InferenceOnlyError
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def make_trainable(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Convert all buffers to trainable parameters.

    By default, ONNX initializers (weights) are registered as buffers in the
    converted FX module, making them non-trainable. This function converts
    all buffers to trainable parameters, enabling gradient computation and
    optimizer updates. Non-floating-point buffers (e.g. int64 indices) cannot
    require gradients and are re-registered as buffers.

    Args:
        module: A converted FX GraphModule from onnx2fx.convert().

    Returns:
        The same module with buffers converted to parameters (modified in-place).

    Raises:
        InferenceOnlyError: If the module was built in memmap-based
            inference-only mode.

    Example:
        >>> import onnx
        >>> from onnx2fx import convert, make_trainable
        >>> onnx_model = onnx.load("model.onnx")
        >>> fx_module = convert(onnx_model)
        >>> fx_module = make_trainable(fx_module)
        >>> optimizer = torch.optim.SGD(fx_module.parameters(), lr=0.01)
    """
    if getattr(module, "_onnx2fx_inference_only", False):
        raise InferenceOnlyError(
            "make_trainable is not supported for memmap-based inference-only models"
        )

    # Snapshot first: registering/deleting while iterating named_buffers()
    # would mutate the dict being traversed.
    buffers_to_convert = list(module.named_buffers())

    for name, buf in buffers_to_convert:
        # named_buffers() yields fully-qualified dotted names for buffers on
        # submodules; delattr/register_* must target the owning submodule,
        # not the root (delattr(module, "sub.buf") would raise AttributeError).
        owner_path, _, attr = name.rpartition(".")
        owner = module.get_submodule(owner_path) if owner_path else module
        delattr(owner, attr)
        # Only floating point / complex tensors can require gradients.
        if buf.is_floating_point() or buf.is_complex():
            owner.register_parameter(attr, torch.nn.Parameter(buf.clone()))
        else:
            # Re-register non-floating point tensors as buffers (e.g., int64 indices).
            owner.register_buffer(attr, buf.clone())

    return module
|