onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,339 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Helper utilities for operator implementations.
3
+
4
+ This module provides factory functions and helpers to reduce boilerplate
5
+ in operator implementations.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
9
+
10
+ import onnx
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from ..exceptions import ConversionError
15
+
16
+ if TYPE_CHECKING:
17
+ from ..graph_builder import GraphBuilder
18
+
19
+
20
def get_optional_input(
    builder: "GraphBuilder",
    node: "onnx.NodeProto",
    index: int,
    default: Any = None,
) -> Any:
    """Fetch input *index* of *node*, falling back to *default*.

    ONNX encodes an omitted optional input either by truncating the input
    list or by using an empty-string name; both cases resolve to *default*.
    Replaces the common pattern::

        if len(node.input) > N and node.input[N]:
            value = builder.get_value(node.input[N])

    Args:
        builder: The graph builder instance.
        node: The ONNX node.
        index: The input index to retrieve.
        default: Value returned when the input is absent.

    Returns:
        The resolved input value, or *default*.
    """
    inputs = node.input
    in_range = len(inputs) > index
    if not (in_range and inputs[index]):
        return default
    return builder.get_value(inputs[index])
44
+
45
+
46
def get_attribute_or_input(
    builder: "GraphBuilder",
    node: "onnx.NodeProto",
    *,
    attr_name: str,
    input_index: int,
    opset_version: int,
    attr_allowed_until: Optional[int] = None,
    input_allowed_since: Optional[int] = None,
    default: Any = None,
    as_python: bool = True,
) -> Any:
    """Resolve a value that may come from an attribute or an input.

    Many ONNX operators migrated a value from a node attribute (older
    opsets) to a node input (newer opsets); this helper reads whichever is
    present and enforces the opset boundaries. When both are supplied, the
    attribute wins.

    Args:
        builder: The graph builder instance.
        node: The ONNX node.
        attr_name: Attribute name to read.
        input_index: Input index to read.
        opset_version: Active opset version.
        attr_allowed_until: Highest opset version that allows the attribute.
        input_allowed_since: Lowest opset version that allows the input.
        default: Value to return when neither attribute nor input is provided.
        as_python: Convert constant tensors to Python scalars/lists.

    Returns:
        The resolved value, or default.

    Raises:
        ConversionError: If the attribute/input is used outside its
            permitted opset range.
    """
    from .attributes import get_attribute

    has_input = len(node.input) > input_index and node.input[input_index]

    # Validate a supplied input's opset legality up front, so an
    # out-of-range input is reported even when the attribute would win.
    if (
        has_input
        and input_allowed_since is not None
        and opset_version < input_allowed_since
    ):
        raise ConversionError(
            (
                f"Input[{input_index}] for '{attr_name}' is not valid before "
                f"opset {input_allowed_since}"
            ),
            node_name=node.name,
            op_type=node.op_type,
        )

    attr_value = get_attribute(node, attr_name)
    if attr_value is not None:
        if attr_allowed_until is not None and opset_version > attr_allowed_until:
            raise ConversionError(
                (
                    f"Attribute '{attr_name}' is not valid after opset "
                    f"{attr_allowed_until}"
                ),
                node_name=node.name,
                op_type=node.op_type,
            )
        return attr_value

    if not has_input:
        return default

    value = _resolve_input_value(builder, node.input[input_index])
    return _as_python_value(value) if as_python else value
108
+
109
+
110
+ def _resolve_input_value(builder: "GraphBuilder", name: str) -> Any:
111
+ if name in builder.initializer_map:
112
+ return builder.initializer_map[name]
113
+ return builder.get_value(name)
114
+
115
+
116
+ def _as_python_value(value: Any) -> Any:
117
+ if isinstance(value, torch.Tensor):
118
+ result = value.tolist()
119
+ if isinstance(result, int):
120
+ return [result]
121
+ return result
122
+ return value
123
+
124
+
125
def unary_op(
    torch_fn: Callable[..., torch.Tensor],
    doc: Optional[str] = None,
) -> Callable[["GraphBuilder", "onnx.NodeProto"], "torch.fx.Node"]:
    """Build a handler that applies *torch_fn* to a node's sole input.

    Replaces the common pattern::

        @register("OpName")
        def op_name(builder, node):
            x = builder.get_value(node.input[0])
            return builder.call_function(torch.op_name, args=(x,))

    Args:
        torch_fn: The PyTorch function to call.
        doc: Optional docstring for the handler.

    Returns:
        A handler function for the operator.
    """

    def handler(builder: "GraphBuilder", node: "onnx.NodeProto") -> "torch.fx.Node":
        operand = builder.get_value(node.input[0])
        return builder.call_function(torch_fn, args=(operand,))

    handler.__doc__ = doc if doc else f"Element-wise {torch_fn.__name__}."
    return handler
154
+
155
+
156
def binary_op(
    torch_fn: Callable[..., torch.Tensor],
    doc: Optional[str] = None,
) -> Callable[["GraphBuilder", "onnx.NodeProto"], "torch.fx.Node"]:
    """Build a handler that applies *torch_fn* to a node's first two inputs.

    Replaces the common pattern::

        @register("OpName")
        def op_name(builder, node):
            lhs = builder.get_value(node.input[0])
            rhs = builder.get_value(node.input[1])
            return builder.call_function(torch.op_name, args=(lhs, rhs))

    Args:
        torch_fn: The PyTorch function to call.
        doc: Optional docstring for the handler.

    Returns:
        A handler function for the operator.
    """

    def handler(builder: "GraphBuilder", node: "onnx.NodeProto") -> "torch.fx.Node":
        left = builder.get_value(node.input[0])
        right = builder.get_value(node.input[1])
        return builder.call_function(torch_fn, args=(left, right))

    handler.__doc__ = doc if doc else f"Element-wise {torch_fn.__name__}."
    return handler
187
+
188
+
189
def unary_op_with_kwargs(
    torch_fn: Callable[..., torch.Tensor],
    *,
    attr_map: dict[str, tuple[str, Any]],
    fixed_kwargs: Optional[dict[str, Any]] = None,
    doc: Optional[str] = None,
) -> Callable[["GraphBuilder", "onnx.NodeProto"], "torch.fx.Node"]:
    """Build a handler for a unary operator driven by node attributes.

    Args:
        torch_fn: The PyTorch function to call.
        attr_map: Mapping of kwarg name to (attribute name, default).
        fixed_kwargs: Extra keyword arguments always passed to ``torch_fn``;
            these override any same-named attribute-derived kwargs.
        doc: Optional docstring for the handler.

    Returns:
        A handler function for the operator.
    """

    def handler(builder: "GraphBuilder", node: "onnx.NodeProto") -> "torch.fx.Node":
        from .attributes import get_attribute

        operand = builder.get_value(node.input[0])
        call_kwargs: dict[str, Any] = {}
        for kwarg, (attr_name, attr_default) in attr_map.items():
            call_kwargs[kwarg] = get_attribute(node, attr_name, attr_default)
        if fixed_kwargs:
            call_kwargs.update(fixed_kwargs)
        return builder.call_function(torch_fn, args=(operand,), kwargs=call_kwargs)

    handler.__doc__ = (
        doc if doc else f"Element-wise {torch_fn.__name__} with attributes."
    )
    return handler
224
+
225
+
226
def compute_same_padding(
    input_shape: tuple[int, ...],
    kernel_shape: tuple[int, ...],
    strides: tuple[int, ...],
    dilations: tuple[int, ...],
    mode: str,
    use_effective_kernel: bool = False,
) -> list[int]:
    """Compute SAME_UPPER or SAME_LOWER padding.

    This consolidates the repeated auto_pad handling logic from
    convolution.py and pooling.py.

    Args:
        input_shape: Spatial dimensions of input (excluding batch and channel).
        kernel_shape: Kernel dimensions.
        strides: Stride values.
        dilations: Dilation values.
        mode: Either "SAME_UPPER" or "SAME_LOWER".
        use_effective_kernel: Kept for backward compatibility; has no effect
            on the result (see Note).

    Returns:
        A list of padding values in F.pad format (reversed order):
        [xn_begin, xn_end, ..., x1_begin, x1_end]

    Note:
        The "effective kernel" form ``(o - 1)*st + ((k - 1)*d + 1) - i``
        (previously used for AvgPool/LpPool) and the standard form
        ``(o - 1)*st + (k - 1)*d + 1 - i`` are algebraically identical,
        so both code paths were merged into one.
    """
    # SAME padding targets output_spatial[i] = ceil(input_spatial[i] / strides[i]).
    output_shape = [
        (dim + stride - 1) // stride for dim, stride in zip(input_shape, strides)
    ]

    # Total padding needed per dimension (clamped at zero: a large stride can
    # make the formula negative, meaning no padding is required).
    pad_total = [
        max(0, (out - 1) * stride + (kernel - 1) * dilation + 1 - dim)
        for dim, out, kernel, stride, dilation in zip(
            input_shape, output_shape, kernel_shape, strides, dilations
        )
    ]

    # F.pad expects spatial dims in reversed order as (begin, end) pairs.
    pad_list: list[int] = []
    if mode == "SAME_UPPER":
        # SAME_UPPER: the extra (odd) unit of padding goes at the end.
        for total in reversed(pad_total):
            pad_list.extend([total // 2, total - total // 2])
    else:  # SAME_LOWER: the extra unit of padding goes at the beginning.
        for total in reversed(pad_total):
            pad_list.extend([total - total // 2, total // 2])

    return pad_list
288
+
289
+
290
def pad_list_to_onnx_pads(pad_list: list[int], ndim: int) -> list[int]:
    """Convert F.pad format to ONNX pads format.

    F.pad lists spatial dims in reversed order as (begin, end) pairs:
        [xn_begin, xn_end, ..., x1_begin, x1_end]
    ONNX lists all begins then all ends in natural order:
        [x1_begin, x2_begin, ..., x1_end, x2_end, ...]

    Args:
        pad_list: Padding in F.pad format.
        ndim: Number of spatial dimensions.

    Returns:
        Padding in ONNX format.
    """
    begins: list[int] = []
    ends: list[int] = []
    # Walk spatial dims in natural order; dim i sits at the mirrored
    # position 2*(ndim-1-i) within the reversed F.pad list.
    for dim in range(ndim):
        offset = 2 * (ndim - 1 - dim)
        begins.append(pad_list[offset])
        ends.append(pad_list[offset + 1])
    return begins + ends
308
+
309
+
310
def apply_auto_pad(
    x: torch.Tensor,
    kernel_shape: tuple[int, ...],
    strides: tuple[int, ...],
    dilations: tuple[int, ...],
    auto_pad: str,
    pad_value: Union[int, float] = 0,
) -> tuple[torch.Tensor, int]:
    """Pre-pad the input tensor according to the ONNX ``auto_pad`` mode.

    Args:
        x: Input tensor with shape (N, C, *spatial_dims).
        kernel_shape: Kernel dimensions.
        strides: Stride values.
        dilations: Dilation values.
        auto_pad: Auto-pad mode ("NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID").
        pad_value: Fill value used for the added padding.

    Returns:
        Tuple of (possibly padded tensor, padding to pass to the downstream
        conv/pool op). The second element is always 0: SAME modes bake the
        padding into the tensor itself, and NOTSET/VALID add none here.
    """
    if auto_pad not in ("SAME_UPPER", "SAME_LOWER"):
        # NOTSET or VALID: no auto-padding applied.
        return x, 0

    spatial_dims = tuple(x.shape[2:])
    pad_list = compute_same_padding(
        spatial_dims, kernel_shape, strides, dilations, auto_pad
    )
    padded = F.pad(x, pad_list, value=pad_value)
    return padded, 0
@@ -0,0 +1,54 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Training utilities for converted FX modules.
3
+
4
+ This module provides utilities to make converted ONNX models trainable
5
+ by converting buffers to trainable parameters.
6
+ """
7
+
8
+ import torch
9
+ import torch.fx
10
+
11
+ from ..exceptions import InferenceOnlyError
12
+
13
+
14
def make_trainable(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Convert all buffers to trainable parameters.

    By default, ONNX initializers (weights) are registered as buffers in the
    converted FX module, making them non-trainable. This function converts
    floating-point and complex buffers to trainable parameters, enabling
    gradient computation and optimizer updates. Other dtypes (e.g. int64
    index tensors) are re-registered as buffers, since only float/complex
    tensors can require gradients.

    Buffers owned by nested submodules (qualified names like "sub.weight"
    from ``named_buffers()``) are resolved to their owning submodule before
    conversion, so non-flat modules are handled correctly.

    Args:
        module: A converted FX GraphModule from onnx2fx.convert().

    Returns:
        The same module with buffers converted to parameters (modified in-place).

    Raises:
        InferenceOnlyError: If the module was converted in memmap-based
            inference-only mode.

    Example:
        >>> import onnx
        >>> from onnx2fx import convert, make_trainable
        >>> onnx_model = onnx.load("model.onnx")
        >>> fx_module = convert(onnx_model)
        >>> fx_module = make_trainable(fx_module)
        >>> optimizer = torch.optim.SGD(fx_module.parameters(), lr=0.01)
    """
    if getattr(module, "_onnx2fx_inference_only", False):
        raise InferenceOnlyError(
            "make_trainable is not supported for memmap-based inference-only models"
        )

    # Snapshot first: registering/deleting mutates the buffer registries,
    # so we must not iterate named_buffers() while modifying it.
    buffers_to_convert = list(module.named_buffers())

    for name, buf in buffers_to_convert:
        # named_buffers() yields dot-qualified names for nested buffers;
        # delattr/register_* must run on the submodule that owns the buffer.
        *parent_path, attr = name.split(".")
        owner = module.get_submodule(".".join(parent_path)) if parent_path else module
        delattr(owner, attr)
        if buf.is_floating_point() or buf.is_complex():
            # Only floating point / complex tensors can require gradients.
            owner.register_parameter(attr, torch.nn.Parameter(buf.clone()))
        else:
            # Keep integer/bool tensors as buffers (e.g. int64 indices).
            owner.register_buffer(attr, buf.clone())

    return module