onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,150 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Utilities for parsing ONNX node attributes."""
3
+
4
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
5
+
6
+ import onnx
7
+
8
+ from ..exceptions import UnsupportedDTypeError
9
+ from .dtype import DTYPE_MAP
10
+
11
+ if TYPE_CHECKING:
12
+ import torch
13
+
14
+
15
def get_attribute(
    node: onnx.NodeProto,
    name: str,
    default: Optional[Any] = None,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> Any:
    """Look up one attribute of an ONNX node by name and parse its value.

    Parameters
    ----------
    node : onnx.NodeProto
        Node whose ``attribute`` list is searched.
    name : str
        Attribute name to look for; the first attribute with this name wins.
    default : Optional[Any]
        Value returned when no attribute named ``name`` exists.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Optional converter used for TENSOR / TENSORS attributes instead of
        the built-in NumPy-based conversion.

    Returns
    -------
    Any
        The parsed attribute value, or ``default`` when the attribute is absent.
    """
    found = next((attr for attr in node.attribute if attr.name == name), None)
    if found is None:
        return default
    return _parse_attribute_value(found, tensor_loader=tensor_loader)
42
+
43
+
44
def get_attributes(
    node: onnx.NodeProto,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> Dict[str, Any]:
    """Get all attributes from an ONNX node as a dictionary.

    Parameters
    ----------
    node : onnx.NodeProto
        The ONNX node.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Optional converter for TENSOR / TENSORS attributes, mirroring the
        keyword of the same name on ``get_attribute``. When omitted, tensors
        are converted with the default NumPy-based path.

    Returns
    -------
    Dict[str, Any]
        Dictionary mapping attribute names to their parsed values. If a name
        appears more than once, the last occurrence wins.
    """
    return {
        attr.name: _parse_attribute_value(attr, tensor_loader=tensor_loader)
        for attr in node.attribute
    }
58
+
59
+
60
def _decode_text(value: Any) -> Any:
    """Decode ``bytes`` as UTF-8; pass any other value through unchanged."""
    return value.decode("utf-8") if isinstance(value, bytes) else value


def _parse_attribute_value(
    attr: onnx.AttributeProto,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> Any:
    """Translate an ONNX ``AttributeProto`` into a plain Python value.

    Scalars are returned as-is, repeated fields become lists, strings are
    decoded from UTF-8, and tensors are converted to PyTorch tensors. Graph,
    sparse-tensor and type-proto payloads are returned as protobuf objects.

    Parameters
    ----------
    attr : onnx.AttributeProto
        The attribute to convert.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Optional converter forwarded to tensor parsing for TENSOR / TENSORS
        attributes.

    Returns
    -------
    Any
        The converted value.

    Raises
    ------
    ValueError
        If ``attr.type`` is not one of the handled attribute kinds.
    """
    kind = attr.type
    proto = onnx.AttributeProto

    if kind == proto.FLOAT:
        return attr.f
    if kind == proto.INT:
        return attr.i
    if kind == proto.STRING:
        return _decode_text(attr.s)
    if kind == proto.TENSOR:
        return _parse_tensor(attr.t, tensor_loader=tensor_loader)
    if kind == proto.GRAPH:
        return attr.g
    if kind == proto.FLOATS:
        return list(attr.floats)
    if kind == proto.INTS:
        return list(attr.ints)
    if kind == proto.STRINGS:
        return [_decode_text(item) for item in attr.strings]
    if kind == proto.TENSORS:
        return [
            _parse_tensor(item, tensor_loader=tensor_loader) for item in attr.tensors
        ]
    if kind == proto.GRAPHS:
        return list(attr.graphs)
    if kind == proto.SPARSE_TENSOR:
        return attr.sparse_tensor
    if kind == proto.SPARSE_TENSORS:
        return list(attr.sparse_tensors)
    if kind == proto.TYPE_PROTO:
        return attr.tp
    if kind == proto.TYPE_PROTOS:
        return list(attr.type_protos)
    raise ValueError(f"Unsupported attribute type: {attr.type}")
110
+
111
+
112
def _parse_tensor(
    tensor: onnx.TensorProto,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> "torch.Tensor":
    """Convert an ONNX TensorProto to a PyTorch tensor.

    Parameters
    ----------
    tensor : onnx.TensorProto
        The ONNX tensor.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        When given, the conversion is delegated entirely to this callable and
        its result is returned unchanged.

    Returns
    -------
    torch.Tensor
        The converted PyTorch tensor. Its dtype is the one ``DTYPE_MAP``
        associates with ``tensor.data_type``.

    Raises
    ------
    UnsupportedDTypeError
        If ``tensor.data_type`` has no PyTorch equivalent in ``DTYPE_MAP``.
    """
    if tensor_loader is not None:
        return tensor_loader(tensor)

    import torch
    from onnx import numpy_helper

    onnx_dtype = tensor.data_type
    torch_dtype = DTYPE_MAP.get(onnx_dtype)
    if torch_dtype is None:
        raise UnsupportedDTypeError(
            onnx_dtype=onnx_dtype,
            tensor_name=tensor.name or "<unnamed>",
            details="attribute tensor dtype not supported",
        )

    np_array = numpy_helper.to_array(tensor)
    # Copy so torch owns a writable buffer independent of the NumPy array
    # returned by ``to_array`` (torch.from_numpy shares memory with its input).
    result = torch.from_numpy(np_array.copy())
    # NOTE(review): numpy_helper may widen some ONNX types (e.g. BFLOAT16 is
    # returned as float32 by some onnx versions). Cast back so the result
    # always carries the dtype declared by the ONNX tensor; this is a no-op
    # when the dtypes already agree.
    if result.dtype != torch_dtype:
        result = result.to(torch_dtype)
    return result
145
+
146
+
147
# Public API of this module; the parsing helpers stay private.
__all__ = [
    "get_attribute",
    "get_attributes",
]
onnx2fx/utils/dtype.py ADDED
@@ -0,0 +1,107 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Data type mapping between ONNX and PyTorch."""
3
+
4
+ from typing import Dict, Optional
5
+
6
+ import onnx
7
+ import torch
8
+
9
# ONNX TensorProto data type to PyTorch dtype mapping.
# A value of None marks an ONNX type with no PyTorch equivalent.
# The unsigned 16/32/64-bit torch dtypes only exist in newer PyTorch releases;
# looking them up with getattr keeps this module importable on older builds,
# where those ONNX types are then reported as unsupported (None).
DTYPE_MAP: Dict[int, Optional[torch.dtype]] = {
    onnx.TensorProto.FLOAT: torch.float32,
    onnx.TensorProto.FLOAT16: torch.float16,
    onnx.TensorProto.BFLOAT16: torch.bfloat16,
    onnx.TensorProto.DOUBLE: torch.float64,
    onnx.TensorProto.INT8: torch.int8,
    onnx.TensorProto.INT16: torch.int16,
    onnx.TensorProto.INT32: torch.int32,
    onnx.TensorProto.INT64: torch.int64,
    onnx.TensorProto.UINT8: torch.uint8,
    onnx.TensorProto.UINT16: getattr(torch, "uint16", None),
    onnx.TensorProto.UINT32: getattr(torch, "uint32", None),
    onnx.TensorProto.UINT64: getattr(torch, "uint64", None),
    onnx.TensorProto.BOOL: torch.bool,
    onnx.TensorProto.COMPLEX64: torch.complex64,
    onnx.TensorProto.COMPLEX128: torch.complex128,
    onnx.TensorProto.STRING: None,
}

# Reverse mapping: PyTorch dtype to ONNX TensorProto data type.
# Entries whose PyTorch side is None (unsupported) are skipped.
TORCH_TO_ONNX_DTYPE: Dict[torch.dtype, int] = {
    torch_dtype: onnx_dtype
    for onnx_dtype, torch_dtype in DTYPE_MAP.items()
    if torch_dtype is not None
}
35
+
36
+
37
def onnx_dtype_to_torch(onnx_dtype: int) -> Optional[torch.dtype]:
    """Look up the PyTorch dtype for an ONNX ``TensorProto`` element type.

    Parameters
    ----------
    onnx_dtype : int
        ONNX ``TensorProto`` data type enum value.

    Returns
    -------
    Optional[torch.dtype]
        The matching PyTorch dtype. None is returned both for codes missing
        from ``DTYPE_MAP`` and for codes mapped explicitly to None (such as
        STRING).
    """
    try:
        return DTYPE_MAP[onnx_dtype]
    except KeyError:
        return None
51
+
52
+
53
def torch_dtype_to_onnx(torch_dtype: torch.dtype) -> Optional[int]:
    """Look up the ONNX ``TensorProto`` element type for a PyTorch dtype.

    Parameters
    ----------
    torch_dtype : torch.dtype
        PyTorch dtype to translate.

    Returns
    -------
    Optional[int]
        The ONNX ``TensorProto`` data type code from ``TORCH_TO_ONNX_DTYPE``,
        or None when the dtype has no entry there.
    """
    try:
        return TORCH_TO_ONNX_DTYPE[torch_dtype]
    except KeyError:
        return None
67
+
68
+
69
# ONNX stash_type to PyTorch dtype mapping (used in normalization ops).
# Keys are ONNX TensorProto element-type codes.
_STASH_TYPE_MAP: Dict[int, torch.dtype] = {
    1: torch.float32,  # TensorProto.FLOAT
    10: torch.float16,  # TensorProto.FLOAT16
    11: torch.float64,  # TensorProto.DOUBLE
    16: torch.bfloat16,  # TensorProto.BFLOAT16
}


def stash_type_to_torch_dtype(stash_type: int) -> torch.dtype:
    """Map an ONNX ``stash_type`` attribute value to a PyTorch dtype.

    Normalization operators use ``stash_type`` to choose the floating-point
    precision of their intermediate computations.

    Parameters
    ----------
    stash_type : int
        ONNX stash_type value:
        - 1: float32
        - 10: float16
        - 11: float64
        - 16: bfloat16

    Returns
    -------
    torch.dtype
        The matching PyTorch dtype. Any code not listed above falls back to
        float32.
    """
    if stash_type in _STASH_TYPE_MAP:
        return _STASH_TYPE_MAP[stash_type]
    # Unknown codes silently fall back to float32 rather than raising.
    return torch.float32
99
+
100
+
101
# Public API of this module: the dtype tables and the conversion helpers.
# The stash-type table itself (_STASH_TYPE_MAP) is intentionally private.
__all__ = [
    "DTYPE_MAP",
    "TORCH_TO_ONNX_DTYPE",
    "onnx_dtype_to_torch",
    "stash_type_to_torch_dtype",
    "torch_dtype_to_onnx",
]
@@ -0,0 +1,233 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Helpers for ONNX external data tensors."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ import importlib
8
+ import math
9
+ import os
10
+ from typing import Any, Dict, Iterable, Tuple
11
+
12
+ import numpy as np
13
+ import onnx
14
+
15
+ from ..exceptions import ExternalDataError, UnsupportedDTypeError
16
+ from .dtype import DTYPE_MAP
17
+
18
+
19
def _load_onnx_mapping() -> Any:  # pragma: no cover - version-dependent
    """Import ONNX's dtype-mapping module, tolerating its relocation.

    Prefers the ``onnx.mapping`` module and falls back to the private
    ``onnx._mapping`` module when the former cannot be imported.

    Returns
    -------
    module
        Whichever mapping module could be imported.

    Raises
    ------
    ImportError
        If neither module can be imported.
    """
    try:
        return importlib.import_module("onnx.mapping")
    except ImportError:
        # Only a missing module should trigger the fallback; any other error
        # raised while importing ``onnx.mapping`` is a genuine failure and
        # must not be masked.
        return importlib.import_module("onnx._mapping")


# Resolved once at import time.
onnx_mapping: Any = _load_onnx_mapping()
27
+
28
+
29
@dataclass(frozen=True)
class ExternalDataInfo:
    """Resolved external data metadata.

    Immutable record built by ``resolve_external_data`` that describes where
    a tensor's raw bytes live and how to interpret them.
    """

    # Path of the external data file: an absolute ``location`` as given, or
    # ``base_dir`` joined with a relative ``location`` and normalized.
    path: str
    # Byte offset at which the tensor payload starts within ``path``.
    offset: int
    # Number of payload bytes starting at ``offset``.
    length: int
    # Tensor dimensions, copied from ``TensorProto.dims`` as Python ints.
    shape: Tuple[int, ...]
    # NumPy dtype used to interpret the payload bytes.
    numpy_dtype: np.dtype
38
+
39
+
40
+ def _tensor_name(tensor: onnx.TensorProto) -> str:
41
+ return tensor.name or "<unnamed>"
42
+
43
+
44
def _lookup_np_type(onnx_dtype: int) -> Any:
    """Return the NumPy type the loaded ONNX mapping module gives ``onnx_dtype``.

    Returns None when the mapping module has no entry for the type.
    """
    # ``onnx.mapping`` exposes ``TENSOR_TYPE_TO_NP_TYPE``.
    # NOTE(review): the private ``onnx._mapping`` fallback may only expose
    # ``TENSOR_TYPE_MAP`` (entries carrying an ``np_dtype`` attribute) in some
    # onnx versions, so both shapes are handled instead of assuming the first.
    table = getattr(onnx_mapping, "TENSOR_TYPE_TO_NP_TYPE", None)
    if table is not None:
        return table.get(onnx_dtype)
    entries = getattr(onnx_mapping, "TENSOR_TYPE_MAP", None)
    if entries is None:
        return None
    return getattr(entries.get(onnx_dtype), "np_dtype", None)


def _get_numpy_dtype(onnx_dtype: int, *, tensor_name: str) -> np.dtype:
    """Resolve the NumPy dtype used to read a tensor's external bytes.

    Parameters
    ----------
    onnx_dtype : int
        ONNX ``TensorProto`` data type code.
    tensor_name : str
        Name used in error messages.

    Returns
    -------
    np.dtype
        The NumPy dtype ONNX associates with ``onnx_dtype``.

    Raises
    ------
    UnsupportedDTypeError
        If onnx2fx has no PyTorch mapping for the type, if ONNX provides no
        NumPy mapping for it, or if NumPy cannot build the dtype.
    """
    if DTYPE_MAP.get(onnx_dtype) is None:
        raise UnsupportedDTypeError(
            onnx_dtype=onnx_dtype,
            tensor_name=tensor_name,
            details="dtype not supported by onnx2fx",
        )
    np_type = _lookup_np_type(onnx_dtype)
    if np_type is None:
        raise UnsupportedDTypeError(
            onnx_dtype=onnx_dtype,
            tensor_name=tensor_name,
            details="dtype has no NumPy mapping",
        )
    try:
        return np.dtype(np_type)
    except Exception as exc:  # pragma: no cover - defensive
        raise UnsupportedDTypeError(
            onnx_dtype=onnx_dtype,
            tensor_name=tensor_name,
            details=f"dtype not supported by NumPy memmap ({exc})",
        ) from exc
66
+
67
+
68
+ def _parse_external_data_kv(
69
+ tensor: onnx.TensorProto, *, tensor_name: str
70
+ ) -> Dict[str, str]:
71
+ data = {entry.key: entry.value for entry in tensor.external_data}
72
+ if not data:
73
+ raise ExternalDataError(
74
+ tensor_name=tensor_name,
75
+ message="external data metadata is empty",
76
+ )
77
+ return data
78
+
79
+
80
+ def _require_int_field(
81
+ data: Dict[str, str],
82
+ *,
83
+ field: str,
84
+ tensor_name: str,
85
+ ) -> int:
86
+ if field not in data:
87
+ raise ExternalDataError(
88
+ tensor_name=tensor_name,
89
+ message=f"missing external data field '{field}'",
90
+ )
91
+ try:
92
+ return int(data[field])
93
+ except ValueError as exc:
94
+ raise ExternalDataError(
95
+ tensor_name=tensor_name,
96
+ message=f"invalid external data field '{field}': {data[field]}",
97
+ ) from exc
98
+
99
+
100
+ def _resolve_external_path(
101
+ location: str, base_dir: str | None, *, tensor_name: str
102
+ ) -> str:
103
+ if os.path.isabs(location):
104
+ return location
105
+ if not base_dir:
106
+ raise ExternalDataError(
107
+ tensor_name=tensor_name,
108
+ message="base_dir is required for relative external data paths",
109
+ )
110
+ return os.path.normpath(os.path.join(base_dir, location))
111
+
112
+
113
+ def _expected_nbytes(shape: Iterable[int], np_dtype: np.dtype) -> int:
114
+ dims = list(shape)
115
+ if len(dims) == 0:
116
+ element_count = 1
117
+ else:
118
+ element_count = math.prod(dims)
119
+ return int(element_count) * int(np_dtype.itemsize)
120
+
121
+
122
def resolve_external_data(
    tensor: onnx.TensorProto,
    *,
    base_dir: str | None,
    strict: bool = True,
) -> ExternalDataInfo:
    """Resolve and validate the external-data metadata of ``tensor``.

    Parameters
    ----------
    tensor : onnx.TensorProto
        Tensor whose ``external_data`` key/value entries describe where its
        bytes are stored.
    base_dir : str | None
        Directory used to resolve a relative ``location``. It is required
        for relative locations and ignored for absolute ones.
    strict : bool, default True
        When True, the byte length must equal the size implied by the
        tensor's ``dims`` and data type.

    Returns
    -------
    ExternalDataInfo
        Path, byte offset, byte length, shape and NumPy dtype of the data.

    Raises
    ------
    ExternalDataError
        If metadata is missing or malformed, if the file does not exist, if
        the byte range does not fit in the file, or (when ``strict``) if the
        length does not match the expected size.
    UnsupportedDTypeError
        If the tensor's data type cannot be mapped.

    Notes
    -----
    Per the ONNX external data specification, only ``location`` is required.
    A missing ``offset`` means the data starts at byte 0. A missing
    ``length`` means the data extends to the end of the file, matching the
    reference ONNX loader.
    """
    tensor_name = _tensor_name(tensor)
    data = _parse_external_data_kv(tensor, tensor_name=tensor_name)
    if "location" not in data:
        raise ExternalDataError(
            tensor_name=tensor_name,
            message="missing external data field 'location'",
        )
    location = data["location"]

    # "offset" and "length" are optional in the ONNX spec; validate them only
    # when present.
    offset = (
        _require_int_field(data, field="offset", tensor_name=tensor_name)
        if "offset" in data
        else 0
    )
    length = (
        _require_int_field(data, field="length", tensor_name=tensor_name)
        if "length" in data
        else None
    )
    if offset < 0 or (length is not None and length < 0):
        raise ExternalDataError(
            tensor_name=tensor_name,
            message="external data offset/length must be non-negative",
        )

    path = _resolve_external_path(location, base_dir, tensor_name=tensor_name)
    if not os.path.isfile(path):
        raise ExternalDataError(
            tensor_name=tensor_name,
            message=f"external data file not found: {path}",
        )

    np_dtype = _get_numpy_dtype(tensor.data_type, tensor_name=tensor_name)
    expected = _expected_nbytes(tensor.dims, np_dtype)
    file_size = os.path.getsize(path)

    if length is None:
        # No explicit length: the payload runs from `offset` to end of file.
        # Clamp at 0 so an out-of-range offset is reported by the range check.
        length = max(file_size - offset, 0)

    if offset + length > file_size:
        raise ExternalDataError(
            tensor_name=tensor_name,
            message=(
                "external data range exceeds file size "
                f"(offset={offset}, length={length}, file_size={file_size})"
            ),
        )

    if strict and length != expected:
        raise ExternalDataError(
            tensor_name=tensor_name,
            message=(
                f"external data length mismatch (length={length}, expected={expected})"
            ),
        )

    return ExternalDataInfo(
        path=path,
        offset=offset,
        length=length,
        shape=tuple(int(dim) for dim in tensor.dims),
        numpy_dtype=np_dtype,
    )
179
+
180
+
181
def iter_all_graphs(model: onnx.ModelProto) -> Iterable[onnx.GraphProto]:
    """Yield ``model.graph`` followed by every nested subgraph.

    Subgraphs are discovered through GRAPH and GRAPHS node attributes and
    visited depth-first using a LIFO stack, with ``model.graph`` yielded
    first. Only graphs reachable from ``model.graph`` are visited.
    """
    pending = [model.graph]
    while pending:
        current = pending.pop()
        yield current
        for node in current.node:
            for attribute in node.attribute:
                if attribute.type == onnx.AttributeProto.GRAPH:
                    pending.append(attribute.g)
                elif attribute.type == onnx.AttributeProto.GRAPHS:
                    pending.extend(attribute.graphs)
193
+
194
+
195
def _is_external_tensor(tensor: onnx.TensorProto) -> bool:
    """Return True if the tensor is marked EXTERNAL or carries external_data entries."""
    return (
        tensor.data_location == onnx.TensorProto.EXTERNAL
        or bool(tensor.external_data)
    )


def _graph_tensors(graph: onnx.GraphProto) -> Iterable[onnx.TensorProto]:
    """Yield a graph's initializers, then tensors held in TENSOR/TENSORS node attributes."""
    yield from graph.initializer
    for node in graph.node:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.TENSOR:
                yield attr.t
            elif attr.type == onnx.AttributeProto.TENSORS:
                yield from attr.tensors


def validate_external_data_model(
    model: onnx.ModelProto,
    *,
    base_dir: str | None,
    strict: bool = True,
) -> None:
    """Validate external data metadata for all tensors in a model.

    Walks every graph yielded by ``iter_all_graphs``. For each graph, it
    checks the initializers and then the tensors in node attributes, in
    order, and runs ``resolve_external_data`` on each tensor that uses
    external storage. The first invalid tensor raises the error produced
    by ``resolve_external_data``.
    """
    for graph in iter_all_graphs(model):
        for tensor in _graph_tensors(graph):
            if _is_external_tensor(tensor):
                resolve_external_data(tensor, base_dir=base_dir, strict=strict)
227
+
228
+
229
# Public API of this module.
# NOTE(review): iter_all_graphs has a public name but is not exported here;
# confirm whether that is intentional.
__all__ = [
    "ExternalDataInfo",
    "resolve_external_data",
    "validate_external_data_model",
]
onnx2fx/utils/names.py ADDED
@@ -0,0 +1,43 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Utilities for sanitizing ONNX names to valid Python identifiers."""
3
+
4
+ import keyword
5
+ import re
6
+
7
+
8
def sanitize_name(name: str) -> str:
    """Turn an arbitrary ONNX tensor name into a valid Python identifier.

    Every character outside ``[a-zA-Z0-9_]`` (for example '.', '/', '-',
    spaces or non-ASCII letters) becomes an underscore. Names that end up
    empty become ``"_unnamed"``. Names starting with a digit get a leading
    underscore. Python keywords get a trailing underscore.

    Parameters
    ----------
    name : str
        The ONNX tensor name to sanitize.

    Returns
    -------
    str
        A valid Python identifier derived from ``name``.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9_]", "_", name)
    if not cleaned:
        return "_unnamed"
    if cleaned[0].isdigit():
        # Identifiers cannot start with a digit.
        return "_" + cleaned
    if keyword.iskeyword(cleaned):
        return cleaned + "_"
    return cleaned