onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,402 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Training operators from the ai.onnx.preview.training domain."""
+
+ from typing import TYPE_CHECKING, Tuple
+
+ import onnx
+ import torch
+ import torch.fx
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ # =============================================================================
+ # Gradient operator
+ # =============================================================================
+
+
+ @register("Gradient", domain="ai.onnx.preview.training")
+ def gradient(builder: "GraphBuilder", node: onnx.NodeProto) -> list:
+     """Compute gradients of y with respect to xs using PyTorch autograd.
+
+     This operator computes symbolic gradients as specified in the ONNX
+     ai.onnx.preview.training domain. It recomputes the forward pass with
+     requires_grad=True to enable autograd.
+     """
+     # Get the xs (inputs to differentiate with respect to) and y (output to differentiate)
+     xs_names = get_attribute(node, "xs", [])
+     y_name = get_attribute(node, "y", "")
+
+     # ONNX string attributes may arrive as bytes; decode them (the same guard
+     # the Momentum handler below applies to its mode attribute).
+     xs_names = [x.decode("utf-8") if isinstance(x, bytes) else x for x in xs_names]
+     if isinstance(y_name, bytes):
+         y_name = y_name.decode("utf-8")
+
+     # Get the xs nodes - these are the graph inputs we differentiate with respect to
+     xs_nodes = [builder.get_value(name) for name in xs_names]
+
+     # Find the path from xs to y by analyzing the ONNX graph.
+     # We need to collect all nodes that contribute to y.
+     graph = builder.model.graph
+     node_outputs = {}  # output_name -> node
+     for n in graph.node:
+         for out in n.output:
+             node_outputs[out] = n
+
+     # Trace back from y to collect the operations needed to compute it.
+     # Post-order DFS appends a node only after all of its dependencies,
+     # which yields a valid inputs-to-output (topological) order even when
+     # a subexpression feeds several consumers.
+     def trace_ops(target_name, collected_nodes):
+         """Recursively collect all nodes needed to compute target_name."""
+         if target_name in xs_names or target_name not in node_outputs:
+             return
+         n = node_outputs[target_name]
+         if n in collected_nodes:
+             return
+         for inp in n.input:
+             trace_ops(inp, collected_nodes)
+         collected_nodes.append(n)
+
+     ops_to_y = []
+     trace_ops(y_name, ops_to_y)
+
+     # Create a function that recomputes y from xs with autograd enabled
+     ops_info = [(n.op_type, list(n.input), list(n.output)) for n in ops_to_y]
+
+     def _compute_gradient_symbolic(ops_info, xs_names, y_name, *xs_values):
+         """Recompute the forward pass and compute gradients."""
+         # Create tensors with requires_grad
+         env = {}
+         for name, val in zip(xs_names, xs_values):
+             if isinstance(val, torch.Tensor):
+                 env[name] = val.detach().clone().requires_grad_(True)
+             else:
+                 env[name] = torch.tensor(val, dtype=torch.float32, requires_grad=True)
+
+         # Replay operations
+         for op_type, inputs, outputs in ops_info:
+             input_tensors = [env[inp] for inp in inputs if inp in env]
+             if op_type == "Add":
+                 result = input_tensors[0] + input_tensors[1]
+             elif op_type == "Mul":
+                 result = input_tensors[0] * input_tensors[1]
+             elif op_type == "Sub":
+                 result = input_tensors[0] - input_tensors[1]
+             elif op_type == "Div":
+                 result = input_tensors[0] / input_tensors[1]
+             else:
+                 # Fallback for unsupported ops
+                 result = input_tensors[0] if input_tensors else torch.tensor(0.0)
+             for out in outputs:
+                 env[out] = result
+
+         # Get y and compute gradients
+         y = env.get(y_name)
+         if y is None:
+             return tuple(torch.zeros_like(x) for x in xs_values)
+
+         xs_tensors = [env[name] for name in xs_names]
+         grads = torch.autograd.grad(
+             outputs=y,
+             inputs=xs_tensors,
+             grad_outputs=torch.ones_like(y),
+             create_graph=False,
+             allow_unused=True,
+         )
+         # allow_unused=True yields None for inputs that do not affect y;
+         # replace those entries with zeros so every output is a tensor.
+         return tuple(
+             g if g is not None else torch.zeros_like(x)
+             for g, x in zip(grads, xs_tensors)
+         )
+
+     # Create the gradient computation node
+     result = builder.call_function(
+         _compute_gradient_symbolic,
+         args=(ops_info, xs_names, y_name, *xs_nodes),
+     )
+
+     # Return list of gradient outputs (one per xs input)
+     outputs = []
+     for i in range(len(xs_names)):
+
+         def _get_grad_i(grads, idx=i):
+             return grads[idx]
+
+         outputs.append(builder.call_function(_get_grad_i, args=(result,)))
+
+     return outputs
+
+
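For intuition, the replay-and-differentiate scheme above reduces to plain torch.autograd.grad once the subgraph has been rebuilt with requires_grad=True. A minimal self-contained sketch with hypothetical values (it does not call the handler itself, which needs a live GraphBuilder):

    # y = x1 * x2 + x1, i.e. a Mul followed by an Add, as the replay loop handles.
    # Expected: dy/dx1 = x2 + 1 = 4, dy/dx2 = x1 = 2.
    import torch

    x1 = torch.tensor(2.0, requires_grad=True)
    x2 = torch.tensor(3.0, requires_grad=True)
    y = x1 * x2 + x1
    g1, g2 = torch.autograd.grad(y, (x1, x2), grad_outputs=torch.ones_like(y))
    assert g1.item() == 4.0 and g2.item() == 2.0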
+ # =============================================================================
+ # Optimizer operators
+ # =============================================================================
+
+
+ def _momentum_update(
+     R: torch.Tensor,
+     T: torch.Tensor,
+     X: torch.Tensor,
+     G: torch.Tensor,
+     V: torch.Tensor,
+     alpha: float,
+     beta: float,
+     norm_coefficient: float,
+     mode: str,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Compute one iteration of stochastic gradient update with momentum.
+
+     Parameters
+     ----------
+     R : torch.Tensor
+         Learning rate (scalar).
+     T : torch.Tensor
+         Training iteration count (scalar, int64).
+     X : torch.Tensor
+         Parameter tensor to optimize.
+     G : torch.Tensor
+         Gradient of X.
+     V : torch.Tensor
+         Accumulated momentum of X.
+     alpha : float
+         Decay coefficient of the previous accumulated gradient (momentum).
+     beta : float
+         Scaling coefficient of the current gradient.
+     norm_coefficient : float
+         L2-norm regularization coefficient.
+     mode : str
+         Either "standard" or "nesterov".
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor]
+         (X_new, V_new) - updated parameter and momentum.
+     """
+     # Add L2 regularization gradient: gradient of 0.5 * norm_coefficient * ||X||^2
+     G_regularized = norm_coefficient * X + G
+
+     # In the first training iteration (T == 0), beta should always be 1
+     beta_adjusted = beta if T.item() > 0 else 1.0
+
+     # Compute new momentum
+     V_new = alpha * V + beta_adjusted * G_regularized
+
+     if mode == "nesterov":
+         # Nesterov momentum: X_new = X - R * (G_regularized + alpha * V_new)
+         X_new = X - R * (G_regularized + alpha * V_new)
+     else:
+         # Standard momentum: X_new = X - R * V_new
+         X_new = X - R * V_new
+
+     return X_new, V_new
+
+
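A quick numeric check of the standard-momentum branch, using hypothetical values (alpha and beta match the attribute defaults used by the handler below):

    import torch

    R, T = torch.tensor(0.1), torch.tensor(1)
    X, G, V = torch.tensor([1.0]), torch.tensor([0.5]), torch.tensor([0.0])
    # T > 0, so beta stays 1.0; V_new = 0.9 * 0 + 1.0 * 0.5 = 0.5
    # X_new = X - R * V_new = 1.0 - 0.1 * 0.5 = 0.95
    X_new, V_new = _momentum_update(R, T, X, G, V, 0.9, 1.0, 0.0, "standard")
    assert torch.allclose(V_new, torch.tensor([0.5]))
    assert torch.allclose(X_new, torch.tensor([0.95]))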
+ @register("Momentum", domain="ai.onnx.preview.training")
+ def momentum(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> Tuple[torch.fx.Node, ...]:
+     """Momentum optimizer operator.
+
+     Compute one iteration of stochastic gradient update with momentum.
+     This operator can conduct the optimization of multiple tensor variables.
+
+     Inputs:
+         R: Learning rate (scalar).
+         T: Training iteration count (scalar, int64).
+         inputs (variadic): X_1, X_2, ..., X_n (parameters), G_1, G_2, ..., G_n (gradients),
+             V_1, V_2, ..., V_n (momentums).
+
+     Outputs:
+         X_1_new, X_2_new, ..., X_n_new, V_1_new, V_2_new, ..., V_n_new.
+
+     Attributes:
+         alpha: Decay coefficient of previous momentum.
+         beta: Scaling coefficient of current gradient.
+         mode: "standard" or "nesterov".
+         norm_coefficient: L2-norm regularization coefficient.
+     """
+     # Get attributes
+     alpha = get_attribute(node, "alpha", 0.9)
+     beta = get_attribute(node, "beta", 1.0)
+     mode = get_attribute(node, "mode", "standard")
+     norm_coefficient = get_attribute(node, "norm_coefficient", 0.0)
+
+     # Decode mode if bytes
+     if isinstance(mode, bytes):
+         mode = mode.decode("utf-8")
+
+     # Get inputs: R, T, then groups of (X, G, V)
+     # Input format: R, T, X_1, X_2, ..., X_n, G_1, G_2, ..., G_n, V_1, V_2, ..., V_n
+     # Number of tensors to optimize: (num_inputs - 2) / 3
+     num_inputs = len(node.input)
+     num_tensors = (num_inputs - 2) // 3
+
+     R = builder.get_value(node.input[0])
+     T = builder.get_value(node.input[1])
+
+     # Collect X, G, V tensors
+     X_inputs = [builder.get_value(node.input[2 + i]) for i in range(num_tensors)]
+     G_inputs = [
+         builder.get_value(node.input[2 + num_tensors + i]) for i in range(num_tensors)
+     ]
+     V_inputs = [
+         builder.get_value(node.input[2 + 2 * num_tensors + i])
+         for i in range(num_tensors)
+     ]
+
+     # Process each (X, G, V) group and collect results
+     results = []
+     for i in range(num_tensors):
+         result = builder.call_function(
+             _momentum_update,
+             args=(
+                 R,
+                 T,
+                 X_inputs[i],
+                 G_inputs[i],
+                 V_inputs[i],
+                 alpha,
+                 beta,
+                 norm_coefficient,
+                 mode,
+             ),
+         )
+         results.append(result)
+
+     # If only one tensor, return the (X_new, V_new) tuple directly
+     if num_tensors == 1:
+         return results[0]
+
+     # For multiple tensors, flatten the results:
+     # Output format: X_1_new, X_2_new, ..., X_n_new, V_1_new, V_2_new, ..., V_n_new
+     # Use a helper function to extract and reorder the outputs
+     def _reorder_momentum_outputs(*results):
+         X_news = [r[0] for r in results]
+         V_news = [r[1] for r in results]
+         return tuple(X_news + V_news)
+
+     return builder.call_function(_reorder_momentum_outputs, args=tuple(results))
+
+
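For reference, a single-tensor Momentum node matching the input layout decoded above can be built with onnx.helper (tensor names here are illustrative):

    import onnx.helper

    momentum_node = onnx.helper.make_node(
        "Momentum",
        inputs=["R", "T", "X", "G", "V"],  # R, T, then the X/G/V groups
        outputs=["X_new", "V_new"],
        domain="ai.onnx.preview.training",
        alpha=0.9, beta=1.0, mode="standard", norm_coefficient=0.0,
    )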
+ def _adagrad_update(
+     R: torch.Tensor,
+     T: torch.Tensor,
+     X: torch.Tensor,
+     G: torch.Tensor,
+     H: torch.Tensor,
+     norm_coefficient: float,
+     decay_factor: float,
+     epsilon: float,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Compute one iteration of ADAGRAD update.
+
+     Parameters
+     ----------
+     R : torch.Tensor
+         Initial learning rate (scalar).
+     T : torch.Tensor
+         Update count (scalar, int64).
+     X : torch.Tensor
+         Parameter tensor to optimize.
+     G : torch.Tensor
+         Gradient of X.
+     H : torch.Tensor
+         Accumulated squared gradient of X.
+     norm_coefficient : float
+         L2-norm regularization coefficient.
+     decay_factor : float
+         Learning rate decay factor.
+     epsilon : float
+         Small constant to avoid division by zero.
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor]
+         (X_new, H_new) - updated parameter and accumulated squared gradient.
+     """
+     # Compute decayed learning rate: r = R / (1 + T * decay_factor)
+     r = R / (1 + T * decay_factor)
+
+     # Add L2 regularization gradient: gradient of 0.5 * norm_coefficient * ||X||^2
+     G_regularized = norm_coefficient * X + G
+
+     # Compute new accumulated squared gradient
+     H_new = H + G_regularized * G_regularized
+
+     # Compute the adaptive part of the per-coordinate learning rate
+     H_adaptive = torch.sqrt(H_new) + epsilon
+
+     # Compute the new value of X
+     X_new = X - r * G_regularized / H_adaptive
+
+     return X_new, H_new
+
+
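A minimal numeric check of one ADAGRAD step (hypothetical values):

    import torch

    R, T = torch.tensor(1.0), torch.tensor(0)
    X, G, H = torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([0.0])
    # r = 1 / (1 + 0) = 1; G_reg = 2; H_new = 4; X_new = 1 - 2 / (sqrt(4) + eps) ~ 0
    X_new, H_new = _adagrad_update(R, T, X, G, H, 0.0, 0.0, 1e-6)
    assert torch.allclose(H_new, torch.tensor([4.0]))
    assert torch.allclose(X_new, torch.tensor([0.0]), atol=1e-5)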
+ @register("Adagrad", domain="ai.onnx.preview.training")
+ def adagrad(builder: "GraphBuilder", node: onnx.NodeProto) -> Tuple[torch.fx.Node, ...]:
+     """Adagrad optimizer operator.
+
+     Compute one iteration of ADAGRAD, a stochastic gradient based optimization
+     algorithm. This operator can conduct the optimization of multiple tensor variables.
+
+     Inputs:
+         R: Initial learning rate (scalar).
+         T: Update count (scalar, int64).
+         inputs (variadic): X_1, X_2, ..., X_n (parameters), G_1, G_2, ..., G_n (gradients),
+             H_1, H_2, ..., H_n (accumulated squared gradients).
+
+     Outputs:
+         X_1_new, X_2_new, ..., X_n_new, H_1_new, H_2_new, ..., H_n_new.
+
+     Attributes:
+         norm_coefficient: L2-norm regularization coefficient (default: 0.0).
+         decay_factor: Learning rate decay factor (default: 0.0).
+         epsilon: Small constant to avoid division by zero (default: 1e-6).
+     """
+     # Get attributes
+     norm_coefficient = get_attribute(node, "norm_coefficient", 0.0)
+     decay_factor = get_attribute(node, "decay_factor", 0.0)
+     epsilon = get_attribute(node, "epsilon", 1e-6)
+
+     # Get inputs: R, T, then groups of (X, G, H)
+     # Input format: R, T, X_1, X_2, ..., X_n, G_1, G_2, ..., G_n, H_1, H_2, ..., H_n
+     # Number of tensors to optimize: (num_inputs - 2) / 3
+     num_inputs = len(node.input)
+     num_tensors = (num_inputs - 2) // 3
+
+     R = builder.get_value(node.input[0])
+     T = builder.get_value(node.input[1])
+
+     # Collect X, G, H tensors
+     X_inputs = [builder.get_value(node.input[2 + i]) for i in range(num_tensors)]
+     G_inputs = [
+         builder.get_value(node.input[2 + num_tensors + i]) for i in range(num_tensors)
+     ]
+     H_inputs = [
+         builder.get_value(node.input[2 + 2 * num_tensors + i])
+         for i in range(num_tensors)
+     ]
+
+     # Process each (X, G, H) group and collect results
+     results = []
+     for i in range(num_tensors):
+         result = builder.call_function(
+             _adagrad_update,
+             args=(
+                 R,
+                 T,
+                 X_inputs[i],
+                 G_inputs[i],
+                 H_inputs[i],
+                 norm_coefficient,
+                 decay_factor,
+                 epsilon,
+             ),
+         )
+         results.append(result)
+
+     # If only one tensor, return the (X_new, H_new) tuple directly
+     if num_tensors == 1:
+         return results[0]
+
+     # For multiple tensors, flatten the results:
+     # Output format: X_1_new, X_2_new, ..., X_n_new, H_1_new, H_2_new, ..., H_n_new
+     # Use a helper function to extract and reorder the outputs
+     def _reorder_adagrad_outputs(*results):
+         X_news = [r[0] for r in results]
+         H_news = [r[1] for r in results]
+         return tuple(X_news + H_news)
+
+     return builder.call_function(_reorder_adagrad_outputs, args=tuple(results))
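The grouped variadic layout (all Xs, then all Gs, then all Hs) that the index arithmetic above decodes looks like this for two tensors (illustrative names):

    import onnx.helper

    adagrad_node = onnx.helper.make_node(
        "Adagrad",
        inputs=["R", "T", "X1", "X2", "G1", "G2", "H1", "H2"],
        outputs=["X1_new", "X2_new", "H1_new", "H2_new"],
        domain="ai.onnx.preview.training",
        norm_coefficient=0.0, decay_factor=0.0, epsilon=1e-6,
    )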
onnx2fx/py.typed ADDED
File without changes
@@ -0,0 +1,45 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Utility modules for ONNX model handling.
+
+ This package provides utilities for:
+ - ONNX attribute parsing
+ - ONNX to PyTorch data type mapping
+ - ONNX model analysis
+ - Name sanitization for valid Python identifiers
+ - Operator implementation helpers
+ """
+
+ from .dtype import DTYPE_MAP, onnx_dtype_to_torch
+ from .attributes import get_attribute, get_attributes
+ from .analyze import analyze_model, AnalysisResult
+ from .training import make_trainable
+ from .names import sanitize_name
+ from .op_helpers import (
+     get_optional_input,
+     get_attribute_or_input,
+     unary_op,
+     unary_op_with_kwargs,
+     binary_op,
+     compute_same_padding,
+     pad_list_to_onnx_pads,
+     apply_auto_pad,
+ )
+
+ __all__ = [
+     "DTYPE_MAP",
+     "onnx_dtype_to_torch",
+     "get_attribute",
+     "get_attributes",
+     "analyze_model",
+     "AnalysisResult",
+     "make_trainable",
+     "sanitize_name",
+     "get_optional_input",
+     "get_attribute_or_input",
+     "unary_op",
+     "unary_op_with_kwargs",
+     "binary_op",
+     "compute_same_padding",
+     "pad_list_to_onnx_pads",
+     "apply_auto_pad",
+ ]
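As used throughout the training operators above, get_attribute(node, name, default) reads a named ONNX attribute with a fallback. A short sketch, assuming the raw ONNX attribute value is returned (string attributes may then arrive as bytes, which the Momentum handler decodes):

    import onnx.helper
    from onnx2fx.utils import get_attribute

    node = onnx.helper.make_node(
        "Momentum", ["R", "T", "X", "G", "V"], ["X_new", "V_new"],
        domain="ai.onnx.preview.training", alpha=0.95,
    )
    alpha = get_attribute(node, "alpha", 0.9)       # attribute present -> 0.95
    mode = get_attribute(node, "mode", "standard")  # attribute absent -> default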
@@ -0,0 +1,139 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """ONNX model analysis utilities for identifying operator support."""
+
+ from typing import Union, Dict, List, Set, Tuple
+ from dataclasses import dataclass, field
+
+ import onnx
+
+ from ..op_registry import is_supported
+
+
+ @dataclass
+ class AnalysisResult:
+     """Result of analyzing an ONNX model for operator support.
+
+     Attributes
+     ----------
+     total_nodes : int
+         Total number of nodes in the model graph.
+     unique_ops : Set[Tuple[str, str]]
+         Set of unique (op_type, domain) tuples.
+     supported_ops : List[Tuple[str, str]]
+         List of supported (op_type, domain) tuples.
+     unsupported_ops : List[Tuple[str, str, int]]
+         List of unsupported (op_type, domain, opset_version) tuples.
+     opset_versions : Dict[str, int]
+         Mapping of domain to opset version.
+     op_counts : Dict[Tuple[str, str], int]
+         Count of each (op_type, domain) in the model.
+     """
+
+     total_nodes: int = 0
+     unique_ops: Set[Tuple[str, str]] = field(default_factory=set)
+     supported_ops: List[Tuple[str, str]] = field(default_factory=list)
+     unsupported_ops: List[Tuple[str, str, int]] = field(default_factory=list)
+     opset_versions: Dict[str, int] = field(default_factory=dict)
+     op_counts: Dict[Tuple[str, str], int] = field(default_factory=dict)
+
+     def is_fully_supported(self) -> bool:
+         """Check if all operators in the model are supported."""
+         return len(self.unsupported_ops) == 0
+
+     def summary(self) -> str:
+         """Generate a human-readable summary of the analysis."""
+         lines = [
+             "ONNX Model Analysis Summary",
+             "=" * 40,
+             f"Total nodes: {self.total_nodes}",
+             f"Unique operators: {len(self.unique_ops)}",
+             f"Supported: {len(self.supported_ops)}",
+             f"Unsupported: {len(self.unsupported_ops)}",
+             "",
+             "Opset versions:",
+         ]
+         for domain, version in sorted(self.opset_versions.items()):
+             domain_display = domain if domain else "(default ONNX)"
+             lines.append(f" {domain_display}: {version}")
+
+         if self.unsupported_ops:
+             lines.append("")
+             lines.append("Unsupported operators:")
+             for op_type, domain, opset in self.unsupported_ops:
+                 domain_display = domain if domain else "(default)"
+                 lines.append(
+                     f" - {op_type} (domain: {domain_display}, opset: {opset})"
+                 )
+
+         if self.supported_ops:
+             lines.append("")
+             lines.append("Supported operators:")
+             for op_type, domain in sorted(self.supported_ops):
+                 domain_display = domain if domain else "(default)"
+                 count = self.op_counts.get((op_type, domain), 0)
+                 lines.append(f" - {op_type} (domain: {domain_display}) x{count}")
+
+         return "\n".join(lines)
+
+
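To illustrate the summary format, here is AnalysisResult filled with hand-written (hypothetical) values rather than a real analysis run:

    result = AnalysisResult(
        total_nodes=3,
        unique_ops={("Add", ""), ("FancyOp", "custom")},
        supported_ops=[("Add", "")],
        unsupported_ops=[("FancyOp", "custom", 1)],
        opset_versions={"": 17, "custom": 1},
        op_counts={("Add", ""): 2, ("FancyOp", "custom"): 1},
    )
    print(result.summary())
    assert not result.is_fully_supported()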
+ def analyze_model(model: Union[onnx.ModelProto, str]) -> AnalysisResult:
+     """Analyze an ONNX model and identify supported/unsupported operators.
+
+     This function iterates through all nodes in an ONNX model graph and
+     checks each operator against the onnx2fx registry to determine
+     which operators are supported for conversion.
+
+     Parameters
+     ----------
+     model : onnx.ModelProto or str
+         The ONNX model or path to the ONNX model file.
+
+     Returns
+     -------
+     AnalysisResult
+         Analysis result containing supported/unsupported operators,
+         opset versions, and operator counts.
+
+     Examples
+     --------
+     >>> import onnx
+     >>> from onnx2fx import analyze_model
+     >>> model = onnx.load("model.onnx")
+     >>> result = analyze_model(model)
+     >>> print(result.summary())
+     >>> if not result.is_fully_supported():
+     ...     print("Missing operators:", result.unsupported_ops)
+     """
+     if isinstance(model, str):
+         model = onnx.load(model)
+
+     result = AnalysisResult()
+
+     # Extract opset versions
+     for opset in model.opset_import:
+         domain = opset.domain if opset.domain else ""
+         result.opset_versions[domain] = opset.version
+
+     # Analyze all nodes
+     for node in model.graph.node:
+         result.total_nodes += 1
+
+         op_type = node.op_type
+         domain = node.domain if node.domain else ""
+         opset_version = result.opset_versions.get(domain, 1)
+
+         op_key = (op_type, domain)
+         result.unique_ops.add(op_key)
+
+         # Count occurrences
+         result.op_counts[op_key] = result.op_counts.get(op_key, 0) + 1
+
+         # Classify each unique op only once; unsupported entries carry an
+         # extra opset field, so compare on their (op_type, domain) prefix.
+         if op_key not in result.supported_ops and op_key not in [
+             (op, dom) for op, dom, _ in result.unsupported_ops
+         ]:
+             if is_supported(op_type, domain, opset_version):
+                 result.supported_ops.append(op_key)
+             else:
+                 result.unsupported_ops.append((op_type, domain, opset_version))
+
+     return result
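A self-contained sketch of running the analysis on an in-memory model; whether each op is reported as supported depends on what the onnx2fx registry implements:

    import onnx
    import onnx.helper
    from onnx2fx.utils import analyze_model

    FLOAT = onnx.TensorProto.FLOAT
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("Add", ["a", "b"], ["c"]),
            onnx.helper.make_node("Mul", ["c", "b"], ["d"]),
        ],
        "tiny",
        [onnx.helper.make_tensor_value_info(n, FLOAT, [2]) for n in ("a", "b")],
        [onnx.helper.make_tensor_value_info("d", FLOAT, [2])],
    )
    result = analyze_model(onnx.helper.make_model(graph))
    print(result.summary())  # 2 nodes, 2 unique default-domain ops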