onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2fx/__init__.py ADDED
@@ -0,0 +1,96 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """onnx2fx: Convert ONNX models to PyTorch FX GraphModules.
3
+
4
+ This library provides tools for converting ONNX models into PyTorch FX
5
+ GraphModules, enabling seamless integration with PyTorch's ecosystem for
6
+ optimization, analysis, and deployment.
7
+
8
+ Core Functions
9
+ --------------
10
+ convert : Convert an ONNX model to a PyTorch FX GraphModule.
11
+ make_trainable : Convert buffers to trainable parameters for training.
12
+
13
+ Model Analysis
14
+ --------------
15
+ analyze_model : Analyze an ONNX model for operator support.
16
+ AnalysisResult : Dataclass containing analysis results.
17
+
18
+ Operator Registration
19
+ ---------------------
20
+ register_op : Register a custom operator handler.
21
+ unregister_op : Unregister an operator handler.
22
+ is_supported : Check if an operator is supported.
23
+ get_supported_ops : List supported operators for a domain.
24
+ get_all_supported_ops : Get all supported operators across all domains.
25
+ get_registered_domains : Get list of registered domains.
26
+
27
+ Exceptions
28
+ ----------
29
+ Onnx2FxError : Base exception for all onnx2fx errors.
30
+ UnsupportedOpError : Raised when an operator is not supported.
31
+ ConversionError : Raised when conversion fails.
32
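+ ValueNotFoundError : Raised when a value is not found in the conversion environment.
+ UnsupportedDTypeError : Raised when an ONNX tensor dtype is not supported.
+ ExternalDataError : Raised when external data metadata is invalid or inaccessible.
+ InferenceOnlyError : Raised when an inference-only model is used for training.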
33
+
34
+ Example
35
+ -------
36
+ >>> import onnx
37
+ >>> from onnx2fx import convert
38
+ >>> model = onnx.load("model.onnx")
39
+ >>> fx_module = convert(model)
40
+ >>> # Use fx_module like any PyTorch module
41
+ >>> output = fx_module(input_tensor)
42
+ """
43
+
44
+ from importlib.metadata import PackageNotFoundError, version as _pkg_version
45
+
46
+ from .converter import convert
47
+ from .exceptions import (
48
+ Onnx2FxError,
49
+ UnsupportedOpError,
50
+ ConversionError,
51
+ ValueNotFoundError,
52
+ UnsupportedDTypeError,
53
+ ExternalDataError,
54
+ InferenceOnlyError,
55
+ )
56
+ from .op_registry import (
57
+ register_op,
58
+ unregister_op,
59
+ get_supported_ops,
60
+ get_all_supported_ops,
61
+ get_registered_domains,
62
+ is_supported,
63
+ )
64
+ from .utils.analyze import analyze_model, AnalysisResult
65
+ from .utils.training import make_trainable
66
+
67
# Resolve the installed distribution version. If the package metadata is
# unavailable (PackageNotFoundError), report a placeholder instead.
try:
    _resolved_version = _pkg_version("onnx2fx")
except PackageNotFoundError:  # pragma: no cover
    _resolved_version = "0.0.0"
__version__ = _resolved_version
del _resolved_version  # keep the temporary out of the package namespace

# Public API re-exported by ``from onnx2fx import *``.
__all__ = [
    "__version__",
    # Core API
    "convert",
    # Training utilities
    "make_trainable",
    # Model analysis
    "analyze_model", "AnalysisResult",
    # Operator registration
    "register_op", "unregister_op",
    "get_supported_ops", "get_all_supported_ops",
    "get_registered_domains", "is_supported",
    # Exceptions
    "Onnx2FxError", "UnsupportedOpError", "ConversionError",
    "ValueNotFoundError", "UnsupportedDTypeError",
    "ExternalDataError", "InferenceOnlyError",
]
onnx2fx/converter.py ADDED
@@ -0,0 +1,62 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Main entry point for converting ONNX models to PyTorch FX.
3
+
4
+ This module provides the primary `convert` function for transforming
5
+ ONNX models into equivalent PyTorch FX GraphModules.
6
+ """
7
+
8
+ import os
9
+ from typing import Optional, Union
10
+
11
+ import onnx
12
+ import torch
13
+
14
+ from .graph_builder import GraphBuilder
15
+ from .utils.external_data import validate_external_data_model
16
+
17
+
18
def convert(
    model: Union[onnx.ModelProto, str, os.PathLike],
    *,
    base_dir: Optional[str] = None,
    memmap_external_data: bool = False,
) -> torch.fx.GraphModule:
    """Convert an ONNX model into a ``torch.fx.GraphModule``.

    Parameters
    ----------
    model : Union[onnx.ModelProto, str, os.PathLike]
        Either an in-memory ``onnx.ModelProto`` or a filesystem path (``str``
        or a path-like object such as ``pathlib.Path``) to an ONNX model.
    base_dir : Optional[str], optional
        Base directory for resolving external data tensors. When ``model`` is
        a path and ``base_dir`` is not given, the directory containing the
        model file is used. Required when ``memmap_external_data=True`` and a
        relative external data path is used.
    memmap_external_data : bool, optional
        If True, do not load external data into memory. Instead, keep external
        data references for memmap-based loading during conversion. The
        references are checked with ``validate_external_data_model`` before
        the graph is built.

    Returns
    -------
    torch.fx.GraphModule
        A PyTorch FX Graph module.

    Raises
    ------
    TypeError
        If ``model`` is neither a path nor an ``onnx.ModelProto``.
    """
    if isinstance(model, (str, os.PathLike)):
        # Normalize str and path-like objects (e.g. pathlib.Path) to a str path.
        model_path = os.fsdecode(model)
        if base_dir is None:
            # Default the external-data base directory to the model's folder.
            base_dir = os.path.dirname(os.path.abspath(model_path))
        # In memmap mode, external tensors are not loaded eagerly here; they
        # stay as references and are handled during conversion instead.
        model = onnx.load_model(
            model_path, load_external_data=not memmap_external_data
        )
    elif not isinstance(model, onnx.ModelProto):
        raise TypeError(
            "model must be a path or onnx.ModelProto, "
            f"got {type(model).__name__}"
        )

    if memmap_external_data:
        # NOTE(review): presumably raises on missing/invalid external-data
        # references (strict=True) before any graph building starts — confirm
        # against utils.external_data.
        validate_external_data_model(model, base_dir=base_dir, strict=True)

    return GraphBuilder(
        model,
        base_dir=base_dir,
        memmap_external_data=memmap_external_data,
    ).build()
onnx2fx/exceptions.py ADDED
@@ -0,0 +1,155 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Custom exceptions for onnx2fx."""
3
+
4
+
5
def _rebuild_onnx2fx_error(cls, args):
    """Recreate a pickled/copied onnx2fx exception without calling ``__init__``.

    onnx2fx exception subclasses build a formatted message in ``__init__`` and
    pass only that string to ``Exception.__init__``, so ``self.args`` does not
    match their constructor signatures. Re-invoking ``cls(*args)`` (what the
    default ``BaseException.__reduce__`` does) would therefore raise or produce
    a garbled message. Allocating with ``__new__`` and restoring ``args``
    directly avoids that; instance attributes are restored afterwards from the
    saved ``__dict__`` by ``BaseException.__setstate__``.
    """
    error = cls.__new__(cls)
    error.args = args
    return error


class Onnx2FxError(Exception):
    """Base exception for all onnx2fx errors.

    This is the parent class for all custom exceptions in onnx2fx.
    Users can catch this exception to handle any onnx2fx-related error.

    Instances, including those of subclasses whose ``__init__`` signature
    differs from ``self.args``, can be pickled and copied (see ``__reduce__``),
    e.g. when an error crosses a multiprocessing boundary.
    """

    def __reduce__(self):
        """Return a reduce tuple that survives custom subclass ``__init__``s.

        The instance is rebuilt by ``_rebuild_onnx2fx_error`` from its class
        and ``args``, and then its attribute dict is restored as pickle state.
        """
        return (_rebuild_onnx2fx_error, (type(self), self.args), self.__dict__)
13
+
14
+
15
class UnsupportedOpError(Onnx2FxError):
    """Raised when an ONNX operator is not supported.

    This exception is raised during conversion when the converter
    encounters an ONNX operator that has no registered handler.

    Parameters
    ----------
    op_type : str
        The unsupported ONNX operator type.
    domain : str, optional
        The ONNX domain of the operator. The default (empty) domain is
        omitted from the message.
    opset_version : int, optional
        The opset version being used. Included in the message whenever it
        is not None.
    """

    def __init__(
        self,
        op_type: str,
        domain: str = "",
        opset_version: int | None = None,
    ):
        self.op_type = op_type
        self.domain = domain
        self.opset_version = opset_version

        domain_str = f" (domain: {domain})" if domain else ""
        # Compare against None (not truthiness) so any explicit version,
        # including 0, is reported.
        version_str = (
            f" at opset {opset_version}" if opset_version is not None else ""
        )
        message = f"Unsupported ONNX operator: {op_type}{domain_str}{version_str}"
        super().__init__(message)
45
+
46
+
47
class ConversionError(Onnx2FxError):
    """Raised when the conversion process fails.

    The final message has the form
    ``Conversion failed [node: <name>, op: <type>]: <message>``, where the
    bracketed context lists only the pieces that were provided (non-empty),
    and is dropped entirely when neither is given.

    Parameters
    ----------
    message : str
        A description of the conversion error.
    node_name : str, optional
        The name of the ONNX node where the error occurred.
    op_type : str, optional
        The ONNX operator type where the error occurred.
    """

    def __init__(
        self,
        message: str,
        node_name: str | None = None,
        op_type: str | None = None,
    ):
        self.node_name = node_name
        self.op_type = op_type

        # Keep only the context fragments whose source value is truthy.
        fragments = [
            text
            for text, present in (
                (f"node: {node_name}", node_name),
                (f"op: {op_type}", op_type),
            )
            if present
        ]
        context = f" [{', '.join(fragments)}]" if fragments else ""
        super().__init__(f"Conversion failed{context}: {message}")
85
+
86
+
87
class ValueNotFoundError(Onnx2FxError):
    """Raised when a value is not found in the environment.

    Signals a lookup of a tensor name that has not been defined in the
    conversion environment.

    Parameters
    ----------
    name : str
        The name of the value that was not found.
    available : list[str], optional
        Names that are defined; appended to the message for debugging when
        the list is non-empty.
    """

    def __init__(self, name: str, available: list[str] | None = None):
        self.name = name
        self.available = available

        hint = f". Available: {available}" if available else ""
        super().__init__(f"Value '{name}' not found in environment" + hint)
109
+
110
+
111
class UnsupportedDTypeError(Onnx2FxError):
    """Raised when an ONNX tensor dtype is not supported.

    Parameters
    ----------
    onnx_dtype : int
        ONNX TensorProto data type enum value.
    tensor_name : str
        Name of the tensor.
    details : str, optional
        Additional details about the failure.
    """

    def __init__(self, onnx_dtype: int, tensor_name: str, details: str | None = None):
        self.onnx_dtype = onnx_dtype
        self.tensor_name = tensor_name
        self.details = details

        dtype_name = self._dtype_display_name(onnx_dtype)
        message = f"Unsupported dtype for tensor '{tensor_name}': {dtype_name}"
        if details:
            message += f" ({details})"
        super().__init__(message)

    @staticmethod
    def _dtype_display_name(onnx_dtype: int) -> str:
        """Return the symbolic ``TensorProto.DataType`` name, or the raw value.

        ``onnx`` is imported lazily here, so this module has no import-time
        dependency on it. The lookup is best-effort: if onnx is missing or the
        value is not a known enum member, the raw value is returned so that
        constructing the error never raises a secondary exception.
        """
        try:
            import onnx

            return onnx.TensorProto.DataType.Name(onnx_dtype)
        except (ImportError, AttributeError, TypeError, ValueError):
            # ValueError: unknown enum value; TypeError: non-integer input;
            # ImportError/AttributeError: onnx unavailable or unexpected API.
            return f"{onnx_dtype}"
141
+
142
+
143
class ExternalDataError(Onnx2FxError):
    """Raised when external data metadata is invalid or inaccessible.

    Parameters
    ----------
    tensor_name : str
        Name of the tensor whose external data could not be used.
    message : str
        Description of the problem; embedded in the final error message.
    """

    def __init__(self, tensor_name: str, message: str):
        self.tensor_name = tensor_name
        full_message = f"External data error for '{tensor_name}': {message}"
        super().__init__(full_message)
149
+
150
+
151
class InferenceOnlyError(Onnx2FxError):
    """Raised when an inference-only model is used for training.

    Parameters
    ----------
    message : str
        Explanation of why the model cannot be used for training; stored
        unchanged as the exception message.
    """

    def __init__(self, message: str):
        # The explicit signature requires exactly one message argument.
        super().__init__(message)