onnx2fx-0.0.0-py3-none-any.whl
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
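The op modules excerpted below depend on two entry points from onnx2fx/op_registry.py, which this excerpt does not show: a register(op_type, domain=...) decorator and an is_supported(op_type, domain, opset_version) query. A minimal sketch of that contract, consistent with how the excerpted files call it (the internals here are assumed, not taken from the package):

# Hypothetical sketch of the op_registry contract; the real module is 345 lines.
from typing import Callable, Dict, Tuple

_REGISTRY: Dict[Tuple[str, str], Callable] = {}

def register(op_type: str, domain: str = "") -> Callable:
    """Decorator mapping (op_type, domain) to a converter function."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(op_type, domain)] = fn
        return fn
    return decorator

def is_supported(op_type: str, domain: str = "", opset_version: int = 1) -> bool:
    """Return True if a converter is registered for (op_type, domain)."""
    return (op_type, domain) in _REGISTRY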
onnx2fx/ops/training.py
ADDED
@@ -0,0 +1,402 @@
# SPDX-License-Identifier: Apache-2.0
"""Training operators from the ai.onnx.preview.training domain."""

from typing import TYPE_CHECKING, Tuple

import onnx
import torch
import torch.fx

from ..op_registry import register
from ..utils.attributes import get_attribute

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Gradient operator
# =============================================================================


@register("Gradient", domain="ai.onnx.preview.training")
def gradient(builder: "GraphBuilder", node: onnx.NodeProto) -> list:
    """Compute gradients of y with respect to xs using PyTorch autograd.

    This operator computes symbolic gradients as specified in the ONNX
    ai.onnx.preview.training domain. It recomputes the forward pass with
    requires_grad=True to enable autograd.
    """
    # Get xs (inputs to differentiate with respect to) and y (output to differentiate)
    xs_names = get_attribute(node, "xs", [])
    y_name = get_attribute(node, "y", "")

    # Get the xs nodes - these are the graph inputs we differentiate with respect to
    xs_nodes = [builder.get_value(name) for name in xs_names]

    # Find the path from xs to y by analyzing the ONNX graph.
    # We need to collect all nodes that contribute to y.
    graph = builder.model.graph
    node_outputs = {}  # output_name -> node
    for n in graph.node:
        for out in n.output:
            node_outputs[out] = n

    # Trace back from y to find all operations needed
    def trace_ops(target_name, collected_nodes):
        """Recursively collect all nodes needed to compute target_name."""
        if target_name in xs_names or target_name not in node_outputs:
            return
        n = node_outputs[target_name]
        if n not in collected_nodes:
            collected_nodes.append(n)
            for inp in n.input:
                trace_ops(inp, collected_nodes)

    ops_to_y = []
    trace_ops(y_name, ops_to_y)
    ops_to_y.reverse()  # Order from inputs to output

    # Create a function that recomputes y from xs with autograd enabled
    ops_info = [(n.op_type, list(n.input), list(n.output)) for n in ops_to_y]

    def _compute_gradient_symbolic(ops_info, xs_names, y_name, *xs_values):
        """Recompute the forward pass and compute gradients."""
        # Create tensors with requires_grad
        env = {}
        for name, val in zip(xs_names, xs_values):
            if isinstance(val, torch.Tensor):
                env[name] = val.detach().clone().requires_grad_(True)
            else:
                env[name] = torch.tensor(val, dtype=torch.float32, requires_grad=True)

        # Replay operations
        for op_type, inputs, outputs in ops_info:
            input_tensors = [env[inp] for inp in inputs if inp in env]
            if op_type == "Add":
                result = input_tensors[0] + input_tensors[1]
            elif op_type == "Mul":
                result = input_tensors[0] * input_tensors[1]
            elif op_type == "Sub":
                result = input_tensors[0] - input_tensors[1]
            elif op_type == "Div":
                result = input_tensors[0] / input_tensors[1]
            else:
                # Fallback for unsupported ops
                result = input_tensors[0] if input_tensors else torch.tensor(0.0)
            for out in outputs:
                env[out] = result

        # Get y and compute gradients
        y = env.get(y_name)
        if y is None:
            return tuple(torch.zeros_like(x) for x in xs_values)

        xs_tensors = [env[name] for name in xs_names]
        grads = torch.autograd.grad(
            outputs=y,
            inputs=xs_tensors,
            grad_outputs=torch.ones_like(y),
            create_graph=False,
            allow_unused=True,
        )
        return grads

    # Create the gradient computation node
    result = builder.call_function(
        _compute_gradient_symbolic,
        args=(ops_info, xs_names, y_name, *xs_nodes),
    )

    # Return list of gradient outputs (one per xs input)
    outputs = []
    for i in range(len(xs_names)):

        def _get_grad_i(grads, idx=i):
            return grads[idx]

        outputs.append(builder.call_function(_get_grad_i, args=(result,)))

    return outputs


# =============================================================================
# Optimizer operators
# =============================================================================


def _momentum_update(
    R: torch.Tensor,
    T: torch.Tensor,
    X: torch.Tensor,
    G: torch.Tensor,
    V: torch.Tensor,
    alpha: float,
    beta: float,
    norm_coefficient: float,
    mode: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute one iteration of stochastic gradient update with momentum.

    Parameters
    ----------
    R : torch.Tensor
        Learning rate (scalar).
    T : torch.Tensor
        Training iteration count (scalar, int64).
    X : torch.Tensor
        Parameter tensor to optimize.
    G : torch.Tensor
        Gradient of X.
    V : torch.Tensor
        Accumulated momentum of X.
    alpha : float
        Decay coefficient of previous accumulated gradient (momentum).
    beta : float
        Scaling coefficient of current gradient.
    norm_coefficient : float
        L2-norm regularization coefficient.
    mode : str
        Either "standard" or "nesterov".

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        (X_new, V_new) - updated parameter and momentum.
    """
    # Add L2 regularization gradient: gradient of 0.5 * norm_coefficient * ||X||^2
    G_regularized = norm_coefficient * X + G

    # In the first training iteration (T == 0), beta should always be 1
    beta_adjusted = beta if T.item() > 0 else 1.0

    # Compute new momentum
    V_new = alpha * V + beta_adjusted * G_regularized

    if mode == "nesterov":
        # Nesterov momentum: X_new = X - R * (G_regularized + alpha * V_new)
        X_new = X - R * (G_regularized + alpha * V_new)
    else:
        # Standard momentum: X_new = X - R * V_new
        X_new = X - R * V_new

    return X_new, V_new


@register("Momentum", domain="ai.onnx.preview.training")
def momentum(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> Tuple[torch.fx.Node, ...]:
    """Momentum optimizer operator.

    Compute one iteration of stochastic gradient update with momentum.
    This operator can conduct the optimization of multiple tensor variables.

    Inputs:
        R: Learning rate (scalar).
        T: Training iteration count (scalar, int64).
        inputs (variadic): X_1, X_2, ..., X_n (parameters), G_1, G_2, ..., G_n (gradients),
            V_1, V_2, ..., V_n (momentums).

    Outputs:
        X_1_new, X_2_new, ..., X_n_new, V_1_new, V_2_new, ..., V_n_new.

    Attributes:
        alpha: Decay coefficient of previous momentum.
        beta: Scaling coefficient of current gradient.
        mode: "standard" or "nesterov".
        norm_coefficient: L2-norm regularization coefficient.
    """
    # Get attributes
    alpha = get_attribute(node, "alpha", 0.9)
    beta = get_attribute(node, "beta", 1.0)
    mode = get_attribute(node, "mode", "standard")
    norm_coefficient = get_attribute(node, "norm_coefficient", 0.0)

    # Decode mode if bytes
    if isinstance(mode, bytes):
        mode = mode.decode("utf-8")

    # Get inputs: R, T, then groups of (X, G, V).
    # Input format: R, T, X_1, X_2, ..., X_n, G_1, G_2, ..., G_n, V_1, V_2, ..., V_n
    # Number of tensors to optimize: (num_inputs - 2) / 3
    num_inputs = len(node.input)
    num_tensors = (num_inputs - 2) // 3

    R = builder.get_value(node.input[0])
    T = builder.get_value(node.input[1])

    # Collect X, G, V tensors
    X_inputs = [builder.get_value(node.input[2 + i]) for i in range(num_tensors)]
    G_inputs = [
        builder.get_value(node.input[2 + num_tensors + i]) for i in range(num_tensors)
    ]
    V_inputs = [
        builder.get_value(node.input[2 + 2 * num_tensors + i])
        for i in range(num_tensors)
    ]

    # Process each tensor group and collect results
    results = []
    for i in range(num_tensors):
        result = builder.call_function(
            _momentum_update,
            args=(
                R,
                T,
                X_inputs[i],
                G_inputs[i],
                V_inputs[i],
                alpha,
                beta,
                norm_coefficient,
                mode,
            ),
        )
        results.append(result)

    # If only one tensor, return the tuple directly
    if num_tensors == 1:
        return results[0]

    # For multiple tensors, we need to flatten the results.
    # Output format: X_1_new, X_2_new, ..., X_n_new, V_1_new, V_2_new, ..., V_n_new.
    # A helper function extracts and reorders the outputs.
    def _reorder_momentum_outputs(*results):
        X_news = [r[0] for r in results]
        V_news = [r[1] for r in results]
        return tuple(X_news + V_news)

    return builder.call_function(_reorder_momentum_outputs, args=tuple(results))


def _adagrad_update(
    R: torch.Tensor,
    T: torch.Tensor,
    X: torch.Tensor,
    G: torch.Tensor,
    H: torch.Tensor,
    norm_coefficient: float,
    decay_factor: float,
    epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute one iteration of ADAGRAD update.

    Parameters
    ----------
    R : torch.Tensor
        Initial learning rate (scalar).
    T : torch.Tensor
        Update count (scalar, int64).
    X : torch.Tensor
        Parameter tensor to optimize.
    G : torch.Tensor
        Gradient of X.
    H : torch.Tensor
        Accumulated squared gradient of X.
    norm_coefficient : float
        L2-norm regularization coefficient.
    decay_factor : float
        Learning rate decay factor.
    epsilon : float
        Small constant to avoid dividing by zero.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        (X_new, H_new) - updated parameter and accumulated squared gradient.
    """
    # Compute decayed learning rate: r = R / (1 + T * decay_factor)
    r = R / (1 + T * decay_factor)

    # Add L2 regularization gradient: gradient of 0.5 * norm_coefficient * ||X||^2
    G_regularized = norm_coefficient * X + G

    # Compute new accumulated squared gradient
    H_new = H + G_regularized * G_regularized

    # Compute the adaptive part of the per-coordinate learning rate
    H_adaptive = torch.sqrt(H_new) + epsilon

    # Compute the new value of X
    X_new = X - r * G_regularized / H_adaptive

    return X_new, H_new


@register("Adagrad", domain="ai.onnx.preview.training")
def adagrad(builder: "GraphBuilder", node: onnx.NodeProto) -> Tuple[torch.fx.Node, ...]:
    """Adagrad optimizer operator.

    Compute one iteration of ADAGRAD, a stochastic gradient based optimization
    algorithm. This operator can conduct the optimization of multiple tensor variables.

    Inputs:
        R: Initial learning rate (scalar).
        T: Update count (scalar, int64).
        inputs (variadic): X_1, X_2, ..., X_n (parameters), G_1, G_2, ..., G_n (gradients),
            H_1, H_2, ..., H_n (accumulated squared gradients).

    Outputs:
        X_1_new, X_2_new, ..., X_n_new, H_1_new, H_2_new, ..., H_n_new.

    Attributes:
        norm_coefficient: L2-norm regularization coefficient (default: 0.0).
        decay_factor: Learning rate decay factor (default: 0.0).
        epsilon: Small constant to avoid dividing by zero (default: 1e-6).
    """
    # Get attributes
    norm_coefficient = get_attribute(node, "norm_coefficient", 0.0)
    decay_factor = get_attribute(node, "decay_factor", 0.0)
    epsilon = get_attribute(node, "epsilon", 1e-6)

    # Get inputs: R, T, then groups of (X, G, H).
    # Input format: R, T, X_1, X_2, ..., X_n, G_1, G_2, ..., G_n, H_1, H_2, ..., H_n
    # Number of tensors to optimize: (num_inputs - 2) / 3
    num_inputs = len(node.input)
    num_tensors = (num_inputs - 2) // 3

    R = builder.get_value(node.input[0])
    T = builder.get_value(node.input[1])

    # Collect X, G, H tensors
    X_inputs = [builder.get_value(node.input[2 + i]) for i in range(num_tensors)]
    G_inputs = [
        builder.get_value(node.input[2 + num_tensors + i]) for i in range(num_tensors)
    ]
    H_inputs = [
        builder.get_value(node.input[2 + 2 * num_tensors + i])
        for i in range(num_tensors)
    ]

    # Process each tensor group and collect results
    results = []
    for i in range(num_tensors):
        result = builder.call_function(
            _adagrad_update,
            args=(
                R,
                T,
                X_inputs[i],
                G_inputs[i],
                H_inputs[i],
                norm_coefficient,
                decay_factor,
                epsilon,
            ),
        )
        results.append(result)

    # If only one tensor, return the tuple directly
    if num_tensors == 1:
        return results[0]

    # For multiple tensors, we need to flatten the results.
    # Output format: X_1_new, X_2_new, ..., X_n_new, H_1_new, H_2_new, ..., H_n_new.
    # A helper function extracts and reorders the outputs.
    def _reorder_adagrad_outputs(*results):
        X_news = [r[0] for r in results]
        H_news = [r[1] for r in results]
        return tuple(X_news + H_news)

    return builder.call_function(_reorder_adagrad_outputs, args=tuple(results))
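For intuition, the pure-tensor helpers above can be exercised outside of FX tracing. A minimal sanity check of _momentum_update (this assumes the wheel is installed so the module is importable):

import torch
from onnx2fx.ops.training import _momentum_update

R = torch.tensor(0.1)           # learning rate
T = torch.tensor(1)             # iteration count > 0, so beta is used as-is
X = torch.tensor([1.0, 2.0])    # parameters
G = torch.tensor([0.5, 0.5])    # gradients
V = torch.zeros(2)              # zero momentum accumulator

X_new, V_new = _momentum_update(R, T, X, G, V, alpha=0.9, beta=1.0,
                                norm_coefficient=0.0, mode="standard")

# With zero momentum and no regularization this reduces to plain SGD:
# V_new = G and X_new = X - R * G.
assert torch.allclose(V_new, G)
assert torch.allclose(X_new, torch.tensor([0.95, 1.95]))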
onnx2fx/py.typed
ADDED
File without changes
onnx2fx/utils/__init__.py
ADDED
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility modules for ONNX model handling.

This package provides utilities for:
- ONNX attribute parsing
- ONNX to PyTorch data type mapping
- ONNX model analysis
- Name sanitization for valid Python identifiers
- Operator implementation helpers
"""

from .dtype import DTYPE_MAP, onnx_dtype_to_torch
from .attributes import get_attribute, get_attributes
from .analyze import analyze_model, AnalysisResult
from .training import make_trainable
from .names import sanitize_name
from .op_helpers import (
    get_optional_input,
    get_attribute_or_input,
    unary_op,
    unary_op_with_kwargs,
    binary_op,
    compute_same_padding,
    pad_list_to_onnx_pads,
    apply_auto_pad,
)

__all__ = [
    "DTYPE_MAP",
    "onnx_dtype_to_torch",
    "get_attribute",
    "get_attributes",
    "analyze_model",
    "AnalysisResult",
    "make_trainable",
    "sanitize_name",
    "get_optional_input",
    "get_attribute_or_input",
    "unary_op",
    "unary_op_with_kwargs",
    "binary_op",
    "compute_same_padding",
    "pad_list_to_onnx_pads",
    "apply_auto_pad",
]
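Because of these re-exports, downstream code can import the helpers from onnx2fx.utils rather than the individual submodules. A short usage sketch (the behavior shown for onnx_dtype_to_torch and sanitize_name is an assumption inferred from their names, not taken from the package):

import onnx
from onnx2fx.utils import onnx_dtype_to_torch, sanitize_name

# Presumably maps an ONNX TensorProto dtype enum value to a torch.dtype.
print(onnx_dtype_to_torch(onnx.TensorProto.FLOAT))   # e.g. torch.float32

# Presumably rewrites an arbitrary ONNX value name into a valid Python identifier.
print(sanitize_name("model/layer.0:out"))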
onnx2fx/utils/analyze.py
ADDED
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
"""ONNX model analysis utilities for identifying operator support."""

from typing import Union, Dict, List, Set, Tuple
from dataclasses import dataclass, field

import onnx

from ..op_registry import is_supported


@dataclass
class AnalysisResult:
    """Result of analyzing an ONNX model for operator support.

    Attributes
    ----------
    total_nodes : int
        Total number of nodes in the model graph.
    unique_ops : Set[Tuple[str, str]]
        Set of unique (op_type, domain) tuples.
    supported_ops : List[Tuple[str, str]]
        List of supported (op_type, domain) tuples.
    unsupported_ops : List[Tuple[str, str, int]]
        List of unsupported (op_type, domain, opset_version) tuples.
    opset_versions : Dict[str, int]
        Mapping of domain to opset version.
    op_counts : Dict[Tuple[str, str], int]
        Count of each (op_type, domain) in the model.
    """

    total_nodes: int = 0
    unique_ops: Set[Tuple[str, str]] = field(default_factory=set)
    supported_ops: List[Tuple[str, str]] = field(default_factory=list)
    unsupported_ops: List[Tuple[str, str, int]] = field(default_factory=list)
    opset_versions: Dict[str, int] = field(default_factory=dict)
    op_counts: Dict[Tuple[str, str], int] = field(default_factory=dict)

    def is_fully_supported(self) -> bool:
        """Check if all operators in the model are supported."""
        return len(self.unsupported_ops) == 0

    def summary(self) -> str:
        """Generate a human-readable summary of the analysis."""
        lines = [
            "ONNX Model Analysis Summary",
            "=" * 40,
            f"Total nodes: {self.total_nodes}",
            f"Unique operators: {len(self.unique_ops)}",
            f"Supported: {len(self.supported_ops)}",
            f"Unsupported: {len(self.unsupported_ops)}",
            "",
            "Opset versions:",
        ]
        for domain, version in sorted(self.opset_versions.items()):
            domain_display = domain if domain else "(default ONNX)"
            lines.append(f"  {domain_display}: {version}")

        if self.unsupported_ops:
            lines.append("")
            lines.append("Unsupported operators:")
            for op_type, domain, opset in self.unsupported_ops:
                domain_display = domain if domain else "(default)"
                lines.append(
                    f"  - {op_type} (domain: {domain_display}, opset: {opset})"
                )

        if self.supported_ops:
            lines.append("")
            lines.append("Supported operators:")
            for op_type, domain in sorted(self.supported_ops):
                domain_display = domain if domain else "(default)"
                count = self.op_counts.get((op_type, domain), 0)
                lines.append(f"  - {op_type} (domain: {domain_display}) x{count}")

        return "\n".join(lines)


def analyze_model(model: Union[onnx.ModelProto, str]) -> AnalysisResult:
    """Analyze an ONNX model and identify supported/unsupported operators.

    This function iterates through all nodes in an ONNX model graph and
    checks each operator against the onnx2fx registry to determine
    which operators are supported for conversion.

    Parameters
    ----------
    model : onnx.ModelProto or str
        The ONNX model or path to the ONNX model file.

    Returns
    -------
    AnalysisResult
        Analysis result containing supported/unsupported operators,
        opset versions, and operator counts.

    Examples
    --------
    >>> import onnx
    >>> from onnx2fx import analyze_model
    >>> model = onnx.load("model.onnx")
    >>> result = analyze_model(model)
    >>> print(result.summary())
    >>> if not result.is_fully_supported():
    ...     print("Missing operators:", result.unsupported_ops)
    """
    if isinstance(model, str):
        model = onnx.load(model)

    result = AnalysisResult()

    # Extract opset versions
    for opset in model.opset_import:
        domain = opset.domain if opset.domain else ""
        result.opset_versions[domain] = opset.version

    # Analyze all nodes
    for node in model.graph.node:
        result.total_nodes += 1

        op_type = node.op_type
        domain = node.domain if node.domain else ""
        opset_version = result.opset_versions.get(domain, 1)

        op_key = (op_type, domain)
        result.unique_ops.add(op_key)

        # Count occurrences
        result.op_counts[op_key] = result.op_counts.get(op_key, 0) + 1

        # Check if supported (only add to supported/unsupported once per unique op).
        # Note: unsupported_ops stores (op_type, domain, opset) triples, so the
        # membership test must project them down to (op_type, domain) pairs.
        if op_key not in [(op, dom) for op, dom in result.supported_ops]:
            if op_key not in [(op, dom) for op, dom, _ in result.unsupported_ops]:
                if is_supported(op_type, domain, opset_version):
                    result.supported_ops.append(op_key)
                else:
                    result.unsupported_ops.append((op_type, domain, opset_version))

    return result
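A small end-to-end run of analyze_model against a hand-built one-node graph (this assumes the wheel is installed; the Add node lives in the default ONNX domain):

import onnx
from onnx import helper, TensorProto
from onnx2fx.utils import analyze_model

a = helper.make_tensor_value_info("a", TensorProto.FLOAT, [2])
b = helper.make_tensor_value_info("b", TensorProto.FLOAT, [2])
c = helper.make_tensor_value_info("c", TensorProto.FLOAT, [2])
add = helper.make_node("Add", ["a", "b"], ["c"])
graph = helper.make_graph([add], "tiny", [a, b], [c])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

result = analyze_model(model)
print(result.summary())                  # 1 node, 1 unique op
assert result.op_counts[("Add", "")] == 1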