onnx2fx 0.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx-0.0.0/PKG-INFO +395 -0
- onnx2fx-0.0.0/README.md +384 -0
- onnx2fx-0.0.0/pyproject.toml +50 -0
- onnx2fx-0.0.0/src/onnx2fx/__init__.py +96 -0
- onnx2fx-0.0.0/src/onnx2fx/converter.py +62 -0
- onnx2fx-0.0.0/src/onnx2fx/exceptions.py +155 -0
- onnx2fx-0.0.0/src/onnx2fx/graph_builder.py +634 -0
- onnx2fx-0.0.0/src/onnx2fx/op_registry.py +345 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/__init__.py +74 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/activation.py +282 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/attention.py +1055 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/control_flow.py +947 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/convolution.py +406 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/image.py +748 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/linalg.py +33 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/loss.py +56 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/nn.py +96 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/normalization.py +289 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/pooling.py +897 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/quantization.py +524 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/random.py +102 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/recurrent.py +647 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/reduction.py +534 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/sequence.py +304 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/signal.py +444 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/string.py +126 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/tensor.py +1161 -0
- onnx2fx-0.0.0/src/onnx2fx/ops/training.py +402 -0
- onnx2fx-0.0.0/src/onnx2fx/py.typed +0 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/__init__.py +45 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/analyze.py +139 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/attributes.py +150 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/dtype.py +107 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/external_data.py +233 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/names.py +43 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx-0.0.0/src/onnx2fx/utils/training.py +54 -0
onnx2fx-0.0.0/PKG-INFO
ADDED
@@ -0,0 +1,395 @@
Metadata-Version: 2.3
Name: onnx2fx
Version: 0.0.0
Summary: Yet another ONNX to PyTorch FX converter
Author: Masahiro Hiramori
Author-email: Masahiro Hiramori <contact@mshr-h.com>
Requires-Dist: onnx>=1.19.1
Requires-Dist: torch>=2.9.0
Requires-Python: >=3.11
Description-Content-Type: text/markdown

# onnx2fx

Yet another ONNX to PyTorch FX converter.

> **⚠️ Note:** This project is under active development. The public API may change at any time.

`onnx2fx` converts ONNX models into PyTorch FX `GraphModule`s, enabling seamless integration with PyTorch's ecosystem for optimization, analysis, and deployment.

## Features

- **Simple API**: Convert ONNX models with a single function call
- **Extensive Operator Support**: Wide ONNX operator coverage, including standard and Microsoft domain operators
- **Multi-Opset Version Support**: Automatic selection of version-specific operator handlers based on the model's opset version
- **Custom Operator Registration**: Easily extend support for unsupported or custom ONNX operators
- **PyTorch FX Output**: Get a `torch.fx.GraphModule` for easy inspection, optimization, and compilation
- **Dynamic Shape Support**: Handle models with dynamic input dimensions (see the sketch after this list)
- **Quantization Support**: Support for quantized operators (QLinear*, DequantizeLinear, etc.)
- **Training Support**: Convert models to trainable modules with the `make_trainable()` utility
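A minimal sketch of the dynamic-shape feature, assuming `model.onnx` was exported with a dynamic batch dimension (the input shapes below are illustrative):

```python
import torch
from onnx2fx import convert

# Assumes model.onnx was exported with a dynamic batch dimension.
fx_module = convert("model.onnx")

# The same converted module serves inputs of different batch sizes.
out_small = fx_module(torch.randn(1, 3, 224, 224))
out_large = fx_module(torch.randn(8, 3, 224, 224))
```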
## Tested Models

The following models have been tested and verified to work with onnx2fx:

- **PaddleOCRv5**: Text detection and recognition models (mobile and server variants)
  - PP-OCRv5_mobile_det, PP-OCRv5_mobile_rec
  - PP-OCRv5_server_det, PP-OCRv5_server_rec
- **TorchVision Models**: ResNet, VGG, MobileNet, etc. (via ONNX export)
- **LFM2**: Liquid Foundation Model (LFM2-350M-ENJP-MT)
- **LFM2.5**: Liquid Foundation Model 2.5
- **TinyLlama**: TinyLlama-1.1B-Chat

## Installation

### Requirements

- Python >= 3.11
- PyTorch >= 2.9.0
- ONNX >= 1.19.1

### From Source

```bash
git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync
```

### Development Installation

```bash
git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync --dev
```

## Quick Start

### Basic Conversion

```python
import torch
import onnx
from onnx2fx import convert

# Load from file path
fx_module = convert("model.onnx")

# Or from onnx.ModelProto
onnx_model = onnx.load("model.onnx")
fx_module = convert(onnx_model)

# For models with external data, you can pass base_dir.
# memmap_external_data avoids loading external data into memory.
fx_module = convert("model.onnx", base_dir="/path/to/model_dir", memmap_external_data=True)

# Run inference
input_tensor = torch.randn(1, 3, 224, 224)
output = fx_module(input_tensor)
```

### Inspecting the Converted Graph

```python
from onnx2fx import convert

fx_module = convert("model.onnx")

# Print the FX graph
print(fx_module.graph)

# Get the graph code
print(fx_module.code)
```

### Registering Custom Operators

For unsupported or custom ONNX operators, you can register your own handlers:

```python
import torch
from onnx2fx import convert, register_op

# Using decorator
@register_op("MyCustomOp")
def my_custom_op(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.sigmoid, args=(x,))

# Or register directly
def my_handler(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.tanh, args=(x,))

register_op("TanhCustom", my_handler)

# For custom domains (e.g., Microsoft operators)
@register_op("BiasGelu", domain="com.microsoft")
def bias_gelu(builder, node):
    x = builder.get_value(node.input[0])
    bias = builder.get_value(node.input[1])
    return builder.call_function(
        lambda t, b: torch.nn.functional.gelu(t + b),
        args=(x, bias),
    )
```

> Note: `ai.onnx.ml` is treated as a distinct domain. If you register or query
> operators in that domain, pass `domain="ai.onnx.ml"` explicitly.
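As a minimal sketch, a handler for an `ai.onnx.ml` operator might be registered like this; the `Normalizer` handler with a simplified MAX normalization is purely illustrative and is not built into onnx2fx:

```python
import torch
from onnx2fx import register_op

@register_op("Normalizer", domain="ai.onnx.ml")
def normalizer(builder, node):
    x = builder.get_value(node.input[0])
    # Illustrative: divide each row by its maximum (simplified MAX norm,
    # ignoring the op's "norm" attribute).
    return builder.call_function(
        lambda t: t / t.max(dim=-1, keepdim=True).values,
        args=(x,),
    )
```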
### Multi-Opset Version Support

The library automatically selects the appropriate operator handler based on the model's opset version. For operators with version-specific behavior (e.g., `Softmax` changed its default axis in opset 13), the correct implementation is used automatically:

```python
from onnx2fx import convert

# Models with different opset versions are handled automatically
fx_module_v11 = convert("model_opset11.onnx")  # Uses opset 11 semantics
fx_module_v17 = convert("model_opset17.onnx")  # Uses opset 17 semantics
```
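The same mechanism extends to custom handlers through the `since_version` parameter of `register_op` (see API Reference). A minimal sketch, using a hypothetical `MyOp` whose semantics changed in opset 13:

```python
import torch
from onnx2fx import register_op

# Used for models whose opset is below 13 (hypothetical MyOp).
@register_op("MyOp", since_version=1)
def my_op_v1(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.relu, args=(x,))

# Used for models whose opset is 13 or higher.
@register_op("MyOp", since_version=13)
def my_op_v13(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.nn.functional.gelu, args=(x,))
```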
### Training Converted Models

By default, ONNX weights are loaded as non-trainable buffers. Use `make_trainable()` to enable training:

```python
import torch
from onnx2fx import convert, make_trainable

# Convert and make trainable
fx_module = convert("model.onnx")
fx_module = make_trainable(fx_module)  # Convert buffers to trainable parameters

# Now you can train the model
optimizer = torch.optim.Adam(fx_module.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = fx_module(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
```

### Querying Supported Operators

```python
from onnx2fx import (
    get_supported_ops,
    get_all_supported_ops,
    get_registered_domains,
    is_supported,
)

# Check if an operator is supported
print(is_supported("Conv"))  # True
print(is_supported("BiasGelu", domain="com.microsoft"))  # True

# Get all operators for a domain
standard_ops = get_supported_ops()  # Default ONNX domain
microsoft_ops = get_supported_ops("com.microsoft")

# Get all operators across all domains
all_ops = get_all_supported_ops()

# Get registered domains
domains = get_registered_domains()  # ['', 'com.microsoft']
```

### Analyzing Model Compatibility

Before converting, you can analyze a model to check operator support:

```python
from onnx2fx import analyze_model

# Analyze an ONNX model
result = analyze_model("model.onnx")

# Check results
print(f"Supported operators: {result.supported_ops}")
print(f"Unsupported operators: {result.unsupported_ops}")
print(f"Is fully supported: {result.is_fully_supported()}")

# Get detailed summary
print(result.summary())
```

### Exception Handling

Handle conversion errors gracefully:

```python
from onnx2fx import (
    convert,
    Onnx2FxError,
    UnsupportedOpError,
    ConversionError,
)

try:
    fx_module = convert("model.onnx")
except UnsupportedOpError as e:
    print(f"Unsupported operator: {e}")
except ConversionError as e:
    print(f"Conversion failed: {e}")
except Onnx2FxError as e:
    print(f"onnx2fx error: {e}")
```

## Supported Operators

### Standard ONNX Domain

This is a short list of representative operators. For the full list, call
`get_supported_ops()` or `get_all_supported_ops()`.

- **Core tensor & shape**: Reshape, Transpose, Concat, Split, Slice, Gather, Pad, Resize, Shape, Cast
- **Math & activations**: Add, Mul, MatMul, Gemm, Relu, Gelu, SiLU, Softmax, LogSoftmax
- **Normalization & pooling**: BatchNormalization, LayerNormalization, InstanceNormalization, GroupNormalization, MaxPool, AveragePool, GlobalAveragePool
- **Reductions & indexing**: ReduceSum, ReduceMean, ArgMax, ArgMin, TopK
- **Control flow & sequence**: If, Loop, SequenceConstruct, SplitToSequence, ConcatFromSequence
- **Quantization**: QuantizeLinear, DequantizeLinear, QLinearConv, QLinearMatMul
- **Other**: Einsum, NonMaxSuppression, StringNormalizer

#### Attention & Normalization Extensions

- Attention (opset 24+)
- RotaryEmbedding (opset 23+)
- GroupQueryAttention
- EmbedLayerNormalization
- SkipLayerNormalization
- SimplifiedLayerNormalization
- SkipSimplifiedLayerNormalization

### Microsoft Domain (`com.microsoft`)

> Note: Some operators are available in both the standard and Microsoft domains (e.g., Attention, RotaryEmbedding, SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention, SkipLayerNormalization, EmbedLayerNormalization).

- Attention
- RotaryEmbedding
- SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization
- SkipLayerNormalization, EmbedLayerNormalization
- GroupQueryAttention

## API Reference

### `convert(model)`

Converts an ONNX model to a PyTorch FX `GraphModule`.

**Parameters:**
- `model` (`Union[onnx.ModelProto, str]`): Either an in-memory `onnx.ModelProto` or a file path to an ONNX model.
- `base_dir` (`str`, optional): Base directory for resolving external data files (see Basic Conversion).
- `memmap_external_data` (`bool`, optional): If true, memory-map external data instead of loading it into memory.

**Returns:**
- `torch.fx.GraphModule`: A PyTorch FX `GraphModule`.

### `register_op(op_type, handler=None, domain="", since_version=1)`

Register a custom ONNX operator handler.

**Parameters:**
- `op_type` (`str`): The ONNX operator type name.
- `handler` (`OpHandler`, optional): The handler function. If not provided, returns a decorator.
- `domain` (`str`, optional): The ONNX domain. Default is `""` (standard ONNX domain).
- `since_version` (`int`, optional): The minimum opset version for this handler. Default is 1.

### `unregister_op(op_type, domain="", since_version=None)`

Unregister an operator handler.

**Parameters:**
- `op_type` (`str`): The ONNX operator type name.
- `domain` (`str`, optional): The ONNX domain.
- `since_version` (`int`, optional): The specific opset handler to remove. If `None`, removes all versions.

**Returns:**
- `bool`: `True` if the operator was unregistered.
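A small sketch pairing `register_op` with `unregister_op`; the `MyTempOp` handler is hypothetical:

```python
import torch
from onnx2fx import is_supported, register_op, unregister_op

# Register a handler for a hypothetical custom op...
@register_op("MyTempOp")
def my_temp_op(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.relu, args=(x,))

print(is_supported("MyTempOp"))  # True

# ...and remove it once it is no longer needed.
unregister_op("MyTempOp")
print(is_supported("MyTempOp"))  # False
```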

### `is_supported(op_type, domain="")`

Check if an operator is supported.

### `get_supported_ops(domain="")`

Get the list of supported ONNX operators for a domain.

### `get_all_supported_ops()`

Get all supported operators across all domains.

### `get_registered_domains()`

Get the list of registered domains.

### `analyze_model(model)`

Analyze an ONNX model for operator support.

**Parameters:**
- `model` (`Union[onnx.ModelProto, str]`): Either an in-memory `onnx.ModelProto` or a file path.

**Returns:**
- `AnalysisResult`: Analysis results with supported/unsupported operators.

### `AnalysisResult`

Dataclass containing model analysis results.

**Attributes:**
- `total_nodes` (`int`): Total number of nodes in the model graph.
- `unique_ops` (`Set[Tuple[str, str]]`): Set of unique (op_type, domain) tuples.
- `supported_ops` (`List[Tuple[str, str]]`): List of supported (op_type, domain) tuples.
- `unsupported_ops` (`List[Tuple[str, str, int]]`): List of unsupported (op_type, domain, opset_version) tuples.
- `opset_versions` (`Dict[str, int]`): Mapping of domain to opset version.
- `op_counts` (`Dict[Tuple[str, str], int]`): Count of each (op_type, domain) in the model.

**Methods:**
- `is_fully_supported()`: Returns `True` if all operators are supported.
- `summary()`: Returns a human-readable summary string.
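Putting the attributes to work, a small sketch that gates conversion on the analysis result:

```python
from onnx2fx import analyze_model, convert

result = analyze_model("model.onnx")

if result.is_fully_supported():
    fx_module = convert("model.onnx")
else:
    # Report each unsupported (op_type, domain, opset_version) tuple,
    # along with how often that op appears in the graph.
    for op_type, domain, opset in result.unsupported_ops:
        count = result.op_counts.get((op_type, domain), 0)
        print(f"{op_type} (domain='{domain}', opset {opset}): {count} node(s)")
```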

### Exceptions

- `Onnx2FxError`: Base exception for all onnx2fx errors.
- `UnsupportedOpError`: Raised when an operator is not supported.
- `ConversionError`: Raised when conversion fails.
- `ValueNotFoundError`: Raised when a value is not found in the environment.

## Development

### Running Tests

```bash
# Run all tests
pytest

# Run all tests in parallel for faster execution
pytest -n auto

# Run a specific test file
pytest tests/test_activation.py

# Skip slow tests
pytest -m "not slow"
```

### Code Formatting

```bash
# Format code with ruff
ruff format .

# Check linting
ruff check .
```

## License

This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.

## Author

Masahiro Hiramori (contact@mshr-h.com)