onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,395 @@
Metadata-Version: 2.3
Name: onnx2fx
Version: 0.0.0
Summary: Yet another ONNX to PyTorch FX converter
Author: Masahiro Hiramori
Author-email: Masahiro Hiramori <contact@mshr-h.com>
Requires-Dist: onnx>=1.19.1
Requires-Dist: torch>=2.9.0
Requires-Python: >=3.11
Description-Content-Type: text/markdown

# onnx2fx

Yet another ONNX to PyTorch FX converter.

> **⚠️ Note:** This project is under active development. The public API may change at any time.

`onnx2fx` converts ONNX models into PyTorch FX `GraphModule`s, enabling seamless integration with PyTorch's ecosystem for optimization, analysis, and deployment.

## Features

- **Simple API**: Convert ONNX models with a single function call
- **Extensive Operator Support**: Wide ONNX operator coverage, including standard and Microsoft domain operators
- **Multi-Opset Version Support**: Automatic selection of version-specific operator handlers based on the model's opset
- **Custom Operator Registration**: Easily extend support for unsupported or custom ONNX operators
- **PyTorch FX Output**: Get a `torch.fx.GraphModule` for easy inspection, optimization, and compilation
- **Dynamic Shape Support**: Handle models with dynamic input dimensions
- **Quantization Support**: Quantized operators such as QLinear*, QuantizeLinear, and DequantizeLinear
- **Training Support**: Convert models to trainable modules with the `make_trainable()` utility

## Tested Models

The following models have been tested and verified to work with onnx2fx:

- **PaddleOCRv5**: Text detection and recognition models (mobile and server variants)
  - PP-OCRv5_mobile_det, PP-OCRv5_mobile_rec
  - PP-OCRv5_server_det, PP-OCRv5_server_rec
- **TorchVision Models**: ResNet, VGG, MobileNet, etc. (via ONNX export)
- **LFM2**: Liquid Foundation Model (LFM2-350M-ENJP-MT)
- **LFM2.5**: Liquid Foundation Model 2.5
- **TinyLlama**: TinyLlama-1.1B-Chat

## Installation

### Requirements

- Python >= 3.11
- PyTorch >= 2.9.0
- ONNX >= 1.19.1

### From Source

```bash
git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync
```

### Development Installation

```bash
git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync --dev
```

## Quick Start

### Basic Conversion

```python
import torch
import onnx
from onnx2fx import convert

# Load from a file path
fx_module = convert("model.onnx")

# Or from an onnx.ModelProto
onnx_model = onnx.load("model.onnx")
fx_module = convert(onnx_model)

# For models with external data, pass base_dir.
# memmap_external_data avoids loading the external data into memory.
fx_module = convert("model.onnx", base_dir="/path/to/model_dir", memmap_external_data=True)

# Run inference
input_tensor = torch.randn(1, 3, 224, 224)
output = fx_module(input_tensor)
```

### Inspecting the Converted Graph

```python
from onnx2fx import convert

fx_module = convert("model.onnx")

# Print the FX graph
print(fx_module.graph)

# Get the generated Python code
print(fx_module.code)
```

### Registering Custom Operators

For unsupported or custom ONNX operators, you can register your own handlers:

```python
import torch
from onnx2fx import convert, register_op

# Using the decorator form
@register_op("MyCustomOp")
def my_custom_op(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.sigmoid, args=(x,))

# Or register a handler directly
def my_handler(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.tanh, args=(x,))

register_op("TanhCustom", my_handler)

# For custom domains (e.g., Microsoft operators)
@register_op("BiasGelu", domain="com.microsoft")
def bias_gelu(builder, node):
    x = builder.get_value(node.input[0])
    bias = builder.get_value(node.input[1])
    return builder.call_function(
        lambda t, b: torch.nn.functional.gelu(t + b),
        args=(x, bias),
    )
```

> Note: `ai.onnx.ml` is treated as a distinct domain. If you register or query
> operators in that domain, pass `domain="ai.onnx.ml"` explicitly.

### Multi-Opset Version Support

The library automatically selects the appropriate operator handler based on the model's opset version. For operators with version-specific behavior (e.g., `Softmax` changed its default axis in opset 13), the correct implementation is used automatically:

```python
from onnx2fx import convert

# Models with different opset versions are handled automatically
fx_module_v11 = convert("model_opset11.onnx")  # Uses opset 11 semantics
fx_module_v17 = convert("model_opset17.onnx")  # Uses opset 17 semantics
```
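Conceptually, version-aware dispatch keeps every registered handler together with its `since_version` and picks the newest one that does not exceed the model's opset. A minimal stdlib-only sketch of that selection logic (the names `_REGISTRY` and `resolve_handler` are illustrative, not onnx2fx internals):

```python
from bisect import bisect_right

# (op_type, domain) -> sorted list of (since_version, handler)
_REGISTRY: dict = {}

def register(op_type, handler, domain="", since_version=1):
    entries = _REGISTRY.setdefault((op_type, domain), [])
    entries.append((since_version, handler))
    entries.sort(key=lambda e: e[0])

def resolve_handler(op_type, model_opset, domain=""):
    """Return the handler with the greatest since_version <= model_opset."""
    entries = _REGISTRY.get((op_type, domain), [])
    idx = bisect_right([v for v, _ in entries], model_opset)
    if idx == 0:
        raise KeyError(f"{op_type} not supported at opset {model_opset}")
    return entries[idx - 1][1]

# Softmax changed its default axis in opset 13
register("Softmax", lambda: "axis=1 semantics", since_version=1)
register("Softmax", lambda: "axis=-1 semantics", since_version=13)

print(resolve_handler("Softmax", 11)())  # axis=1 semantics
print(resolve_handler("Softmax", 17)())  # axis=-1 semantics
```

An opset-13 model falls exactly on the boundary and gets the newer handler, which is the behavior the ONNX versioning scheme requires.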
### Training Converted Models

By default, ONNX weights are loaded as non-trainable buffers. Use `make_trainable()` to enable training:

```python
import torch
from onnx2fx import convert, make_trainable

# Convert and make trainable
fx_module = convert("model.onnx")
fx_module = make_trainable(fx_module)  # Convert buffers to trainable parameters

# Now you can train the model
optimizer = torch.optim.Adam(fx_module.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

for inputs, targets in dataloader:  # dataloader: your torch.utils.data.DataLoader
    optimizer.zero_grad()
    outputs = fx_module(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
```

### Querying Supported Operators

```python
from onnx2fx import (
    get_supported_ops,
    get_all_supported_ops,
    get_registered_domains,
    is_supported,
)

# Check whether an operator is supported
print(is_supported("Conv"))  # True
print(is_supported("BiasGelu", domain="com.microsoft"))  # True

# Get all operators for a domain
standard_ops = get_supported_ops()  # Default ONNX domain
microsoft_ops = get_supported_ops("com.microsoft")

# Get all operators across all domains
all_ops = get_all_supported_ops()

# Get the registered domains
domains = get_registered_domains()  # e.g., ['', 'com.microsoft']
```

### Analyzing Model Compatibility

Before converting, you can analyze a model to check operator support:

```python
from onnx2fx import analyze_model

# Analyze an ONNX model
result = analyze_model("model.onnx")

# Check the results
print(f"Supported operators: {result.supported_ops}")
print(f"Unsupported operators: {result.unsupported_ops}")
print(f"Is fully supported: {result.is_fully_supported()}")

# Get a detailed summary
print(result.summary())
```

### Exception Handling

Handle conversion errors gracefully:

```python
from onnx2fx import (
    convert,
    Onnx2FxError,
    UnsupportedOpError,
    ConversionError,
)

try:
    fx_module = convert("model.onnx")
except UnsupportedOpError as e:
    print(f"Unsupported operator: {e}")
except ConversionError as e:
    print(f"Conversion failed: {e}")
except Onnx2FxError as e:
    print(f"onnx2fx error: {e}")
```

## Supported Operators

### Standard ONNX Domain

This is a short list of representative operators. For the full list, call
`get_supported_ops()` or `get_all_supported_ops()`.

- **Core tensor & shape**: Reshape, Transpose, Concat, Split, Slice, Gather, Pad, Resize, Shape, Cast
- **Math & activations**: Add, Mul, MatMul, Gemm, Relu, Gelu, SiLU, Softmax, LogSoftmax
- **Normalization & pooling**: BatchNormalization, LayerNormalization, InstanceNormalization, GroupNormalization, MaxPool, AveragePool, GlobalAveragePool
- **Reductions & indexing**: ReduceSum, ReduceMean, ArgMax, ArgMin, TopK
- **Control flow & sequence**: If, Loop, SequenceConstruct, SplitToSequence, ConcatFromSequence
- **Quantization**: QuantizeLinear, DequantizeLinear, QLinearConv, QLinearMatMul
- **Other**: Einsum, NonMaxSuppression, StringNormalizer

#### Attention & Normalization Extensions

- Attention (opset 24+)
- RotaryEmbedding (opset 23+)
- GroupQueryAttention
- EmbedLayerNormalization
- SkipLayerNormalization
- SimplifiedLayerNormalization
- SkipSimplifiedLayerNormalization

### Microsoft Domain (`com.microsoft`)

> Note: Some operators are available in both the standard and Microsoft domains (e.g., Attention, RotaryEmbedding, SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention, SkipLayerNormalization, EmbedLayerNormalization).

- Attention
- RotaryEmbedding
- SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization
- SkipLayerNormalization, EmbedLayerNormalization
- GroupQueryAttention

## API Reference

### `convert(model)`

Converts an ONNX model to a PyTorch FX `GraphModule`.

**Parameters:**
- `model` (`Union[onnx.ModelProto, str]`): Either an in-memory `onnx.ModelProto` or a file path to an ONNX model.
- `base_dir`, `memmap_external_data` (optional): Control how external data is loaded; see Basic Conversion above.

**Returns:**
- `torch.fx.GraphModule`: A PyTorch FX graph module.

### `register_op(op_type, handler=None, domain="", since_version=1)`

Register a custom ONNX operator handler.

**Parameters:**
- `op_type` (`str`): The ONNX operator type name.
- `handler` (`OpHandler`, optional): The handler function. If not provided, returns a decorator.
- `domain` (`str`, optional): The ONNX domain. Default is `""` (the standard ONNX domain).
- `since_version` (`int`, optional): The minimum opset version for this handler. Default is 1.
### `unregister_op(op_type, domain="", since_version=None)`

Unregister an operator handler.

**Parameters:**
- `op_type` (`str`): The ONNX operator type name.
- `domain` (`str`, optional): The ONNX domain.
- `since_version` (`int`, optional): The specific opset handler to remove. If `None`, removes all versions.

**Returns:**
- `bool`: `True` if the operator was unregistered.

### `is_supported(op_type, domain="")`

Check whether an operator is supported.

### `get_supported_ops(domain="")`

Get the list of supported ONNX operators for a domain.

### `get_all_supported_ops()`

Get all supported operators across all domains.

### `get_registered_domains()`

Get the list of registered domains.

### `analyze_model(model)`

Analyze an ONNX model for operator support.

**Parameters:**
- `model` (`Union[onnx.ModelProto, str]`): Either an in-memory `onnx.ModelProto` or a file path.

**Returns:**
- `AnalysisResult`: Analysis results with supported and unsupported operators.

### `AnalysisResult`

Dataclass containing model analysis results.

**Attributes:**
- `total_nodes` (`int`): Total number of nodes in the model graph.
- `unique_ops` (`Set[Tuple[str, str]]`): Set of unique `(op_type, domain)` tuples.
- `supported_ops` (`List[Tuple[str, str]]`): List of supported `(op_type, domain)` tuples.
- `unsupported_ops` (`List[Tuple[str, str, int]]`): List of unsupported `(op_type, domain, opset_version)` tuples.
- `opset_versions` (`Dict[str, int]`): Mapping of domain to opset version.
- `op_counts` (`Dict[Tuple[str, str], int]`): Count of each `(op_type, domain)` in the model.

**Methods:**
- `is_fully_supported()`: Returns `True` if all operators are supported.
- `summary()`: Returns a human-readable summary string.
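The shape of such a result type is straightforward: full support simply means the unsupported list is empty. An illustrative stand-in (a plain dataclass named `Result` here; not the real `AnalysisResult` class, and with only a subset of its fields):

```python
from dataclasses import dataclass, field

@dataclass
class Result:
    total_nodes: int = 0
    supported_ops: list = field(default_factory=list)    # (op_type, domain)
    unsupported_ops: list = field(default_factory=list)  # (op_type, domain, opset)

    def is_fully_supported(self) -> bool:
        # Fully supported exactly when nothing is unsupported
        return not self.unsupported_ops

    def summary(self) -> str:
        return (f"{self.total_nodes} nodes, "
                f"{len(self.supported_ops)} supported op types, "
                f"{len(self.unsupported_ops)} unsupported")

r = Result(total_nodes=42,
           supported_ops=[("Conv", ""), ("Relu", "")],
           unsupported_ops=[("FancyOp", "com.example", 19)])
print(r.is_fully_supported())  # False
print(r.summary())
```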
### Exceptions

- `Onnx2FxError`: Base exception for all onnx2fx errors.
- `UnsupportedOpError`: Raised when an operator is not supported.
- `ConversionError`: Raised when conversion fails.
- `ValueNotFoundError`: Raised when a value is not found in the environment.
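Because the specific exceptions derive from `Onnx2FxError`, a single `except Onnx2FxError` catches all of them; list the more specific handlers first, as in the Exception Handling example above. The implied hierarchy, sketched with plain classes:

```python
class Onnx2FxError(Exception):
    """Base exception for all onnx2fx errors."""

class UnsupportedOpError(Onnx2FxError):
    """An operator is not supported."""

class ConversionError(Onnx2FxError):
    """Conversion failed."""

class ValueNotFoundError(Onnx2FxError):
    """A value was not found in the environment."""

try:
    raise UnsupportedOpError("FancyOp (com.example)")
except Onnx2FxError as e:  # catches every onnx2fx error
    print(f"onnx2fx error: {e}")
```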
## Development

### Running Tests

```bash
# Run all tests
pytest

# Run all tests in parallel for faster execution
pytest -n auto

# Run a specific test file
pytest tests/test_activation.py

# Skip slow tests
pytest -m "not slow"
```

### Code Formatting

```bash
# Format code with ruff
ruff format .

# Check linting
ruff check .
```

## License

This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details.

## Author

Masahiro Hiramori (contact@mshr-h.com)
@@ -0,0 +1,39 @@
onnx2fx/__init__.py,sha256=37fk-9-fKkWMQ4s1a5wypMyMiWGW8H7KO0lGqWPCbuk,2671
onnx2fx/converter.py,sha256=tIajqjlAdS7KvOJJqrMGvGeyTjKkEAWt9tg4BL0dcHw,1923
onnx2fx/exceptions.py,sha256=6XoXiRj3mUcQR3p7SCUvqCc2xfENTaJuUliP0W1Zido,4475
onnx2fx/graph_builder.py,sha256=68710FjVli6XiEBRC_9aHbGv_3yba2AH-43_2zCIlqg,22206
onnx2fx/op_registry.py,sha256=505TptHDdQheoxl5xYe1pKLOD8TCS8xMmkGoVKmJkjA,10474
onnx2fx/ops/__init__.py,sha256=ef-rzpJpzetH4q5hB9vbNIldYma8ibxP5Wzy2WQN1qc,2312
onnx2fx/ops/activation.py,sha256=pgSDFSlRe8AiAQTk4fGQPcBZPWO-prR8SIknm2HJSvU,8900
onnx2fx/ops/arithmetic.py,sha256=d3yvhjjeUiv7Szm3ez4IQ29yC5rqgCTsO00H12eCy90,10537
onnx2fx/ops/attention.py,sha256=yjmdDNResnHem4r50kZOYYIAtYXubdNXoo1655crh_U,39641
onnx2fx/ops/attention_msft.py,sha256=oRdOOR_k2O-sDK94hdMUZbJDgMAiW6fuYeoHNg-JKaw,24927
onnx2fx/ops/control_flow.py,sha256=dVMqLKmlTLzW-i172Bqk1Gy_efrGEE0OTGTnydPRrSg,34259
onnx2fx/ops/convolution.py,sha256=12jSzfxtTEOEg1D3lTwPs-JxY3GDqQdPMHl1UmK1rIk,14714
onnx2fx/ops/image.py,sha256=6C0wGI7E_wrO_zCHWEAXuQzdS0UTAzjLNUM-PzfHIbA,27027
onnx2fx/ops/linalg.py,sha256=wkPwthk-m38-1Wx6AZuD4XPY8X2bBILfgU4qlUIY-Y0,999
onnx2fx/ops/loss.py,sha256=9Ay0SmvUAqgOOlqdu1BeIvtLlfqMhSkaRbXwU0zp0HA,1718
onnx2fx/ops/nn.py,sha256=qmi2rsRei0wnro8ahKG7_jaCowNKxcmlhjpBUNdnYGo,3273
onnx2fx/ops/normalization.py,sha256=2DxrGNRimu2sFiNyJki_SazsKeZDIFOtBrjYpZdz7x4,9664
onnx2fx/ops/pooling.py,sha256=nw6KB_Fy94fk15U2EIEn2aGepZ2DrkQGAKJAVA4L0aE,34075
onnx2fx/ops/quantization.py,sha256=E6X2EYNQe7HZBlZmfFeconjJKCwQ_XX1ga8ZHfmx7tM,16227
onnx2fx/ops/random.py,sha256=_Cl1fX_rAwzb3F_7fXJnepgmpHLqGrYLMYjN5XasUlE,3470
onnx2fx/ops/recurrent.py,sha256=jXwYLVc0prOOHwEnplfzTPaTXu4UMZfklq6x1-WsTTc,23228
onnx2fx/ops/reduction.py,sha256=LfVd4CREg-aLctTW-j-h7vKf5zo2d_GcZOI1s2bAHns,19297
onnx2fx/ops/sequence.py,sha256=upmOKmaj1ijznVhZx22ZG4uHDdX8oNcxBs4JJUaQJr8,10965
onnx2fx/ops/signal.py,sha256=1RXlyWi0P_h7c3YE11rMRZ9xNJttGUjJv3A-odstEls,15684
onnx2fx/ops/string.py,sha256=b7TOA5F0_HxVnCIl9CS_EHeMT5NoCUN2P5Mo_dHwqrM,4310
onnx2fx/ops/tensor.py,sha256=Cnb-kxEn_Z4akYuynWYXhuf-rAhCozHoGteVYxZI2j8,41046
onnx2fx/ops/training.py,sha256=H4_S8UcuKB6T9JQuSeOcaoMX4ptNF4TFdwJAB8H6ftE,13505
onnx2fx/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
onnx2fx/utils/__init__.py,sha256=clsTEZCqSVqvCAuCgPspVeSsOy8drh6qEmB35IaSHFI,1108
onnx2fx/utils/analyze.py,sha256=9zgEnShu1NLjKM34kzIA89Dt-p3FOeJLkvD9MybUS_g,5012
onnx2fx/utils/attributes.py,sha256=wJxLzQ7Mw2ydOVCPDTkzX4l2H8iX6szCxXtJHz4vy88,4095
onnx2fx/utils/dtype.py,sha256=kr5B0q2rNBQDcSKWn7li2QPvPTHTij5Nv3rWcfWp22Q,2910
onnx2fx/utils/external_data.py,sha256=13Bfe34Pta79Hsf_ANilWL2N1WIXJoklTZelALNCTsA,7192
onnx2fx/utils/names.py,sha256=PYga5tQf6GqUuG78LVJoKvZgum-HYQUyAVjVfdNgcw0,1193
onnx2fx/utils/op_helpers.py,sha256=eMyxXWePTmT-MP61KIK4OwHulPpA3hpFmUmEJUazCzU,10950
onnx2fx/utils/training.py,sha256=tli5epji8O5ezDVxNwG-VB31LoK-IRuwlaGCEtFbJ98,1953
onnx2fx-0.0.0.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
onnx2fx-0.0.0.dist-info/METADATA,sha256=hUC0akg8XlIld3T1vkqcQfnlR1Zi40zwZsiONkyS5IM,11397
onnx2fx-0.0.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: uv 0.8.24
Root-Is-Purelib: true
Tag: py3-none-any