onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/__init__.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
"""onnx2fx: Convert ONNX models to PyTorch FX GraphModules.

This library provides tools for converting ONNX models into PyTorch FX
GraphModules, enabling seamless integration with PyTorch's ecosystem for
optimization, analysis, and deployment.

Core Functions
--------------
convert : Convert an ONNX model (``ModelProto`` or file path) to a PyTorch
    FX GraphModule.
make_trainable : Convert buffers to trainable parameters for training.

Model Analysis
--------------
analyze_model : Analyze an ONNX model for operator support.
AnalysisResult : Dataclass containing analysis results.

Operator Registration
---------------------
register_op : Register a custom operator handler.
unregister_op : Unregister an operator handler.
is_supported : Check if an operator is supported.
get_supported_ops : List supported operators for a domain.
get_all_supported_ops : Get all supported operators across all domains.
get_registered_domains : Get list of registered domains.

Exceptions
----------
Onnx2FxError : Base exception for all onnx2fx errors.
UnsupportedOpError : Raised when an operator is not supported.
ConversionError : Raised when conversion fails.
ValueNotFoundError : Raised when a value is not found in environment.
UnsupportedDTypeError : Raised when an ONNX tensor dtype is not supported.
ExternalDataError : Raised when external data metadata is invalid or
    inaccessible.
InferenceOnlyError : Raised when an inference-only model is used for training.

Attributes
----------
__version__ : Installed package version, or ``"0.0.0"`` when the package
    metadata cannot be found.

Example
-------
>>> import onnx
>>> from onnx2fx import convert
>>> model = onnx.load("model.onnx")  # doctest: +SKIP
>>> fx_module = convert(model)  # doctest: +SKIP
>>> fx_module = convert("model.onnx")  # a file path works too  # doctest: +SKIP
>>> # Use fx_module like any PyTorch module; ``input_tensor`` stands for a
>>> # tensor matching the model's graph inputs.
>>> output = fx_module(input_tensor)  # doctest: +SKIP
"""
|
|
43
|
+
|
|
44
|
+
from importlib.metadata import PackageNotFoundError, version as _pkg_version
|
|
45
|
+
|
|
46
|
+
from .converter import convert
|
|
47
|
+
from .exceptions import (
|
|
48
|
+
Onnx2FxError,
|
|
49
|
+
UnsupportedOpError,
|
|
50
|
+
ConversionError,
|
|
51
|
+
ValueNotFoundError,
|
|
52
|
+
UnsupportedDTypeError,
|
|
53
|
+
ExternalDataError,
|
|
54
|
+
InferenceOnlyError,
|
|
55
|
+
)
|
|
56
|
+
from .op_registry import (
|
|
57
|
+
register_op,
|
|
58
|
+
unregister_op,
|
|
59
|
+
get_supported_ops,
|
|
60
|
+
get_all_supported_ops,
|
|
61
|
+
get_registered_domains,
|
|
62
|
+
is_supported,
|
|
63
|
+
)
|
|
64
|
+
from .utils.analyze import analyze_model, AnalysisResult
|
|
65
|
+
from .utils.training import make_trainable
|
|
66
|
+
|
|
67
|
+
def _resolve_version() -> str:
    """Return the installed ``onnx2fx`` distribution version.

    Falls back to the placeholder ``"0.0.0"`` when the distribution metadata
    cannot be found (``PackageNotFoundError``), e.g. for an uninstalled
    source checkout.
    """
    try:
        return _pkg_version("onnx2fx")
    except PackageNotFoundError:  # pragma: no cover
        return "0.0.0"


__version__ = _resolve_version()
# The helper is only needed at import time; keep the package namespace clean.
del _resolve_version

# Public API re-exported by ``from onnx2fx import *`` (order is intentional).
__all__ = [
    "__version__",
    # --- conversion entry point ---
    "convert",
    # --- training helpers ---
    "make_trainable",
    # --- model inspection ---
    "analyze_model",
    "AnalysisResult",
    # --- operator registry management ---
    "register_op",
    "unregister_op",
    "get_supported_ops",
    "get_all_supported_ops",
    "get_registered_domains",
    "is_supported",
    # --- error types ---
    "Onnx2FxError",
    "UnsupportedOpError",
    "ConversionError",
    "ValueNotFoundError",
    "UnsupportedDTypeError",
    "ExternalDataError",
    "InferenceOnlyError",
]
|
onnx2fx/converter.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Main entry point for converting ONNX models to PyTorch FX.
|
|
3
|
+
|
|
4
|
+
This module provides the primary `convert` function for transforming
|
|
5
|
+
ONNX models into equivalent PyTorch FX GraphModules.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
from typing import Optional, Union
|
|
10
|
+
|
|
11
|
+
import onnx
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from .graph_builder import GraphBuilder
|
|
15
|
+
from .utils.external_data import validate_external_data_model
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def convert(
    model: Union[onnx.ModelProto, str, os.PathLike],
    *,
    base_dir: Optional[str] = None,
    memmap_external_data: bool = False,
) -> torch.fx.GraphModule:
    """Convert an ONNX model into a ``torch.fx.GraphModule``.

    Parameters
    ----------
    model : Union[onnx.ModelProto, str, os.PathLike]
        Either an in-memory ``onnx.ModelProto`` or a filesystem path to an
        ONNX model (``str`` or any ``os.PathLike``, e.g. ``pathlib.Path``).
    base_dir : Optional[str], optional
        Base directory for resolving external data tensors. When ``model`` is
        a path and this is omitted, the directory containing the model file
        is used. Required when ``memmap_external_data=True`` and a relative
        external data path is used with an in-memory model.
    memmap_external_data : bool, optional
        If True, do not load external data into memory. Instead, keep external
        data references for memmap-based loading during conversion.

    Returns
    -------
    torch.fx.GraphModule
        A PyTorch FX Graph module.

    Raises
    ------
    TypeError
        If ``model`` is neither a path nor an ``onnx.ModelProto``.

    Notes
    -----
    Errors raised by external-data validation (only performed when
    ``memmap_external_data=True``) and by graph construction propagate
    unchanged.
    """
    if isinstance(model, (str, os.PathLike)):
        # Normalise path-like inputs (e.g. pathlib.Path) to ``str`` so the
        # directory derivation and ``onnx.load_model`` see a single type.
        model_path = os.fsdecode(model)
        if base_dir is None:
            # Default: resolve relative external-data locations against the
            # directory holding the model file.
            base_dir = os.path.dirname(os.path.abspath(model_path))
        # With memmap enabled only the external-data references are loaded;
        # the tensor payloads are read during conversion instead.
        model = onnx.load_model(
            model_path, load_external_data=not memmap_external_data
        )
    elif not isinstance(model, onnx.ModelProto):
        raise TypeError("model must be a path or onnx.ModelProto")

    if memmap_external_data:
        # Fail early if the external-data references cannot be resolved.
        validate_external_data_model(model, base_dir=base_dir, strict=True)

    return GraphBuilder(
        model,
        base_dir=base_dir,
        memmap_external_data=memmap_external_data,
    ).build()
|
onnx2fx/exceptions.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Custom exceptions for onnx2fx."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Onnx2FxError(Exception):
|
|
6
|
+
"""Base exception for all onnx2fx errors.
|
|
7
|
+
|
|
8
|
+
This is the parent class for all custom exceptions in onnx2fx.
|
|
9
|
+
Users can catch this exception to handle any onnx2fx-related error.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
pass
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class UnsupportedOpError(Onnx2FxError):
|
|
16
|
+
"""Raised when an ONNX operator is not supported.
|
|
17
|
+
|
|
18
|
+
This exception is raised during conversion when the converter
|
|
19
|
+
encounters an ONNX operator that has no registered handler.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
op_type : str
|
|
24
|
+
The unsupported ONNX operator type.
|
|
25
|
+
domain : str, optional
|
|
26
|
+
The ONNX domain of the operator.
|
|
27
|
+
opset_version : int, optional
|
|
28
|
+
The opset version being used.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
op_type: str,
|
|
34
|
+
domain: str = "",
|
|
35
|
+
opset_version: int | None = None,
|
|
36
|
+
):
|
|
37
|
+
self.op_type = op_type
|
|
38
|
+
self.domain = domain
|
|
39
|
+
self.opset_version = opset_version
|
|
40
|
+
|
|
41
|
+
domain_str = f" (domain: {domain})" if domain else ""
|
|
42
|
+
version_str = f" at opset {opset_version}" if opset_version else ""
|
|
43
|
+
message = f"Unsupported ONNX operator: {op_type}{domain_str}{version_str}"
|
|
44
|
+
super().__init__(message)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ConversionError(Onnx2FxError):
|
|
48
|
+
"""Raised when conversion fails due to an error in the conversion process.
|
|
49
|
+
|
|
50
|
+
This exception is raised when the conversion process encounters
|
|
51
|
+
an error that prevents successful completion.
|
|
52
|
+
|
|
53
|
+
Parameters
|
|
54
|
+
----------
|
|
55
|
+
message : str
|
|
56
|
+
A description of the conversion error.
|
|
57
|
+
node_name : str, optional
|
|
58
|
+
The name of the ONNX node where the error occurred.
|
|
59
|
+
op_type : str, optional
|
|
60
|
+
The ONNX operator type where the error occurred.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
def __init__(
|
|
64
|
+
self,
|
|
65
|
+
message: str,
|
|
66
|
+
node_name: str | None = None,
|
|
67
|
+
op_type: str | None = None,
|
|
68
|
+
):
|
|
69
|
+
self.node_name = node_name
|
|
70
|
+
self.op_type = op_type
|
|
71
|
+
|
|
72
|
+
context_parts = []
|
|
73
|
+
if node_name:
|
|
74
|
+
context_parts.append(f"node: {node_name}")
|
|
75
|
+
if op_type:
|
|
76
|
+
context_parts.append(f"op: {op_type}")
|
|
77
|
+
|
|
78
|
+
if context_parts:
|
|
79
|
+
context = f" [{', '.join(context_parts)}]"
|
|
80
|
+
else:
|
|
81
|
+
context = ""
|
|
82
|
+
|
|
83
|
+
full_message = f"Conversion failed{context}: {message}"
|
|
84
|
+
super().__init__(full_message)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class ValueNotFoundError(Onnx2FxError):
|
|
88
|
+
"""Raised when a value is not found in the environment.
|
|
89
|
+
|
|
90
|
+
This exception is raised when trying to access a tensor value
|
|
91
|
+
that has not been defined in the conversion environment.
|
|
92
|
+
|
|
93
|
+
Parameters
|
|
94
|
+
----------
|
|
95
|
+
name : str
|
|
96
|
+
The name of the value that was not found.
|
|
97
|
+
available : list[str], optional
|
|
98
|
+
List of available value names for debugging.
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
def __init__(self, name: str, available: list[str] | None = None):
|
|
102
|
+
self.name = name
|
|
103
|
+
self.available = available
|
|
104
|
+
|
|
105
|
+
message = f"Value '{name}' not found in environment"
|
|
106
|
+
if available:
|
|
107
|
+
message += f". Available: {available}"
|
|
108
|
+
super().__init__(message)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class UnsupportedDTypeError(Onnx2FxError):
|
|
112
|
+
"""Raised when an ONNX tensor dtype is not supported.
|
|
113
|
+
|
|
114
|
+
Parameters
|
|
115
|
+
----------
|
|
116
|
+
onnx_dtype : int
|
|
117
|
+
ONNX TensorProto data type enum value.
|
|
118
|
+
tensor_name : str
|
|
119
|
+
Name of the tensor.
|
|
120
|
+
details : str, optional
|
|
121
|
+
Additional details about the failure.
|
|
122
|
+
"""
|
|
123
|
+
|
|
124
|
+
def __init__(self, onnx_dtype: int, tensor_name: str, details: str | None = None):
|
|
125
|
+
self.onnx_dtype = onnx_dtype
|
|
126
|
+
self.tensor_name = tensor_name
|
|
127
|
+
self.details = details
|
|
128
|
+
|
|
129
|
+
dtype_name = f"{onnx_dtype}"
|
|
130
|
+
try:
|
|
131
|
+
import onnx
|
|
132
|
+
|
|
133
|
+
dtype_name = onnx.TensorProto.DataType.Name(onnx_dtype)
|
|
134
|
+
except Exception:
|
|
135
|
+
pass
|
|
136
|
+
|
|
137
|
+
message = f"Unsupported dtype for tensor '{tensor_name}': {dtype_name}"
|
|
138
|
+
if details:
|
|
139
|
+
message += f" ({details})"
|
|
140
|
+
super().__init__(message)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class ExternalDataError(Onnx2FxError):
|
|
144
|
+
"""Raised when external data metadata is invalid or inaccessible."""
|
|
145
|
+
|
|
146
|
+
def __init__(self, tensor_name: str, message: str):
|
|
147
|
+
self.tensor_name = tensor_name
|
|
148
|
+
super().__init__(f"External data error for '{tensor_name}': {message}")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class InferenceOnlyError(Onnx2FxError):
|
|
152
|
+
"""Raised when an inference-only model is used for training."""
|
|
153
|
+
|
|
154
|
+
def __init__(self, message: str):
|
|
155
|
+
super().__init__(message)
|