onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Utilities for parsing ONNX node attributes."""
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
|
|
8
|
+
from ..exceptions import UnsupportedDTypeError
|
|
9
|
+
from .dtype import DTYPE_MAP
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_attribute(
    node: onnx.NodeProto,
    name: str,
    default: Optional[Any] = None,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> Any:
    """Look up one attribute of an ONNX node by name.

    Parameters
    ----------
    node : onnx.NodeProto
        The ONNX node whose attributes are searched.
    name : str
        Name of the attribute to fetch.
    default : Optional[Any]
        Value returned when the node has no attribute called ``name``.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Optional callable forwarded to the attribute parser; when given it is
        used to convert tensor-valued attributes instead of the built-in
        conversion.

    Returns
    -------
    Any
        The parsed value of the first attribute named ``name``, or
        ``default`` when there is none.
    """
    # Only the first attribute with a matching name is considered.
    found = next(
        (candidate for candidate in node.attribute if candidate.name == name),
        None,
    )
    if found is None:
        return default
    return _parse_attribute_value(found, tensor_loader=tensor_loader)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_attributes(
    node: onnx.NodeProto,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> Dict[str, Any]:
    """Get all attributes from an ONNX node as a dictionary.

    Parameters
    ----------
    node : onnx.NodeProto
        The ONNX node.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Optional callable used to convert tensor-valued attributes, mirroring
        the keyword accepted by ``get_attribute``. When omitted, tensors are
        converted with the built-in conversion, which is the behaviour this
        function always had.

    Returns
    -------
    Dict[str, Any]
        Dictionary mapping attribute names to their parsed values. If a node
        carries several attributes with the same name, the last one wins.
    """
    return {
        attr.name: _parse_attribute_value(attr, tensor_loader=tensor_loader)
        for attr in node.attribute
    }
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _parse_attribute_value(
    attr: onnx.AttributeProto,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> Any:
    """Turn an ONNX attribute into the matching Python value.

    Parameters
    ----------
    attr : onnx.AttributeProto
        The attribute to convert.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Optional callable handed to ``_parse_tensor`` for ``TENSOR`` and
        ``TENSORS`` attributes.

    Returns
    -------
    Any
        Scalars are returned as-is, byte strings are decoded as UTF-8,
        repeated fields become lists, tensors are converted through
        ``_parse_tensor``; graph, sparse-tensor and type protos are returned
        without conversion.

    Raises
    ------
    ValueError
        If the attribute type is none of the handled kinds.
    """
    kinds = onnx.AttributeProto
    attr_type = attr.type

    if attr_type == kinds.FLOAT:
        return attr.f
    if attr_type == kinds.INT:
        return attr.i
    if attr_type == kinds.STRING:
        raw = attr.s
        if isinstance(raw, bytes):
            return raw.decode("utf-8")
        return raw
    if attr_type == kinds.TENSOR:
        return _parse_tensor(attr.t, tensor_loader=tensor_loader)
    if attr_type == kinds.GRAPH:
        return attr.g
    if attr_type == kinds.FLOATS:
        return list(attr.floats)
    if attr_type == kinds.INTS:
        return list(attr.ints)
    if attr_type == kinds.STRINGS:
        decoded = []
        for item in attr.strings:
            decoded.append(item.decode("utf-8") if isinstance(item, bytes) else item)
        return decoded
    if attr_type == kinds.TENSORS:
        return [
            _parse_tensor(item, tensor_loader=tensor_loader) for item in attr.tensors
        ]
    if attr_type == kinds.GRAPHS:
        return list(attr.graphs)
    if attr_type == kinds.SPARSE_TENSOR:
        return attr.sparse_tensor
    if attr_type == kinds.SPARSE_TENSORS:
        return list(attr.sparse_tensors)
    if attr_type == kinds.TYPE_PROTO:
        return attr.tp
    if attr_type == kinds.TYPE_PROTOS:
        return list(attr.type_protos)

    raise ValueError(f"Unsupported attribute type: {attr.type}")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _parse_tensor(
    tensor: onnx.TensorProto,
    *,
    tensor_loader: Optional[Callable[[onnx.TensorProto], "torch.Tensor"]] = None,
) -> "torch.Tensor":
    """Build a PyTorch tensor from an ONNX TensorProto.

    Parameters
    ----------
    tensor : onnx.TensorProto
        The ONNX tensor to convert.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        When given, the conversion is delegated entirely to this callable and
        none of the checks below run.

    Returns
    -------
    torch.Tensor
        The converted tensor. On the built-in path it is created from a copy
        of the NumPy array produced by ``onnx.numpy_helper.to_array``.

    Raises
    ------
    UnsupportedDTypeError
        On the built-in path, if ``DTYPE_MAP`` has no torch dtype for the
        tensor's data type.
    """
    if tensor_loader is not None:
        return tensor_loader(tensor)

    # Imported lazily so that importing this module does not pull in torch.
    import torch
    from onnx import numpy_helper

    dtype_code = tensor.data_type
    torch_dtype = DTYPE_MAP.get(dtype_code)
    if torch_dtype is None:
        raise UnsupportedDTypeError(
            onnx_dtype=dtype_code,
            tensor_name=tensor.name or "<unnamed>",
            details="attribute tensor dtype not supported",
        )

    # Copy before wrapping: torch.from_numpy shares memory with its argument.
    array_copy = numpy_helper.to_array(tensor).copy()
    return torch.from_numpy(array_copy)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
__all__ = [
|
|
148
|
+
"get_attribute",
|
|
149
|
+
"get_attributes",
|
|
150
|
+
]
|
onnx2fx/utils/dtype.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Data type mapping between ONNX and PyTorch."""
|
|
3
|
+
|
|
4
|
+
from typing import Dict, Optional
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
# ONNX TensorProto data type to PyTorch dtype mapping.
# A value of None marks an ONNX type that is listed here but has no torch
# dtype (STRING). Code that checks ``DTYPE_MAP.get(x) is None`` therefore
# treats both "mapped to None" and "absent from the table" as unsupported.
# NOTE(review): torch.uint16 / torch.uint32 / torch.uint64 only exist in
# recent PyTorch releases — confirm the minimum supported torch version.
DTYPE_MAP: Dict[int, Optional[torch.dtype]] = {
    onnx.TensorProto.FLOAT: torch.float32,
    onnx.TensorProto.FLOAT16: torch.float16,
    onnx.TensorProto.BFLOAT16: torch.bfloat16,
    onnx.TensorProto.DOUBLE: torch.float64,
    onnx.TensorProto.INT8: torch.int8,
    onnx.TensorProto.INT16: torch.int16,
    onnx.TensorProto.INT32: torch.int32,
    onnx.TensorProto.INT64: torch.int64,
    onnx.TensorProto.UINT8: torch.uint8,
    onnx.TensorProto.UINT16: torch.uint16,
    onnx.TensorProto.UINT32: torch.uint32,
    onnx.TensorProto.UINT64: torch.uint64,
    onnx.TensorProto.BOOL: torch.bool,
    onnx.TensorProto.COMPLEX64: torch.complex64,
    onnx.TensorProto.COMPLEX128: torch.complex128,
    onnx.TensorProto.STRING: None,
}
|
|
28
|
+
|
|
29
|
+
# Reverse mapping: PyTorch dtype to ONNX TensorProto data type.
# Built by inverting DTYPE_MAP; entries whose torch dtype is None (ONNX types
# without a torch equivalent) are skipped because they cannot serve as keys.
TORCH_TO_ONNX_DTYPE: Dict[torch.dtype, int] = {
    torch_dtype: onnx_dtype
    for onnx_dtype, torch_dtype in DTYPE_MAP.items()
    if torch_dtype is not None
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def onnx_dtype_to_torch(onnx_dtype: int) -> Optional[torch.dtype]:
    """Translate an ONNX TensorProto data type into a PyTorch dtype.

    Parameters
    ----------
    onnx_dtype : int
        ONNX TensorProto data type enum value.

    Returns
    -------
    Optional[torch.dtype]
        The dtype recorded in ``DTYPE_MAP`` for ``onnx_dtype``. ``None`` is
        returned both for values missing from the table and for values the
        table maps to ``None``.
    """
    if onnx_dtype in DTYPE_MAP:
        return DTYPE_MAP[onnx_dtype]
    return None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def torch_dtype_to_onnx(torch_dtype: torch.dtype) -> Optional[int]:
    """Translate a PyTorch dtype into an ONNX TensorProto data type.

    Parameters
    ----------
    torch_dtype : torch.dtype
        PyTorch dtype to look up.

    Returns
    -------
    Optional[int]
        The ONNX TensorProto data type recorded in ``TORCH_TO_ONNX_DTYPE``,
        or ``None`` when the dtype has no entry there.
    """
    if torch_dtype in TORCH_TO_ONNX_DTYPE:
        return TORCH_TO_ONNX_DTYPE[torch_dtype]
    return None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# ONNX stash_type to PyTorch dtype mapping (used in normalization ops).
# Keys are the integer stash_type codes accepted by stash_type_to_torch_dtype.
_STASH_TYPE_MAP: Dict[int, torch.dtype] = {
    1: torch.float32,
    10: torch.float16,
    11: torch.float64,
    16: torch.bfloat16,
}


def stash_type_to_torch_dtype(stash_type: int) -> torch.dtype:
    """Map an ONNX ``stash_type`` attribute value to a PyTorch dtype.

    ``stash_type`` selects the floating-point precision used for the
    intermediate computations of normalization operations.

    Parameters
    ----------
    stash_type : int
        ONNX stash_type value:
        - 1: float32
        - 10: float16
        - 11: float64
        - 16: bfloat16

    Returns
    -------
    torch.dtype
        The matching PyTorch dtype; any value not listed above falls back to
        ``torch.float32``.
    """
    try:
        return _STASH_TYPE_MAP[stash_type]
    except KeyError:
        # Unknown codes are not an error: float32 is the fallback.
        return torch.float32
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
__all__ = [
|
|
102
|
+
"DTYPE_MAP",
|
|
103
|
+
"TORCH_TO_ONNX_DTYPE",
|
|
104
|
+
"onnx_dtype_to_torch",
|
|
105
|
+
"stash_type_to_torch_dtype",
|
|
106
|
+
"torch_dtype_to_onnx",
|
|
107
|
+
]
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Helpers for ONNX external data tensors."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
import importlib
|
|
8
|
+
import math
|
|
9
|
+
import os
|
|
10
|
+
from typing import Any, Dict, Iterable, Tuple
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import onnx
|
|
14
|
+
|
|
15
|
+
from ..exceptions import ExternalDataError, UnsupportedDTypeError
|
|
16
|
+
from .dtype import DTYPE_MAP
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _load_onnx_mapping() -> Any:  # pragma: no cover - version-dependent
    """Import and return the ONNX dtype-mapping module.

    ``onnx.mapping`` is tried first; if importing it fails for any reason,
    ``onnx._mapping`` is imported instead. A failure of that second import
    propagates to the caller.
    """
    try:
        mapping_module = importlib.import_module("onnx.mapping")
    except Exception:
        mapping_module = importlib.import_module("onnx._mapping")
    return mapping_module


# Resolved once at import time and shared by the helpers in this module.
onnx_mapping: Any = _load_onnx_mapping()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True)
class ExternalDataInfo:
    """Resolved external data metadata.

    Immutable record produced by ``resolve_external_data`` once a tensor's
    external-data entries have been parsed and checked against the file.
    """

    # Resolved path of the file that stores the tensor bytes.
    path: str
    # Byte offset inside ``path`` at which the tensor data begins.
    offset: int
    # Number of bytes that belong to the tensor.
    length: int
    # Tensor dimensions, copied from the TensorProto ``dims`` field.
    shape: Tuple[int, ...]
    # NumPy dtype corresponding to the tensor's ONNX data type.
    numpy_dtype: np.dtype
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _tensor_name(tensor: onnx.TensorProto) -> str:
    """Return the tensor's name, or ``"<unnamed>"`` when the name is empty."""
    label = tensor.name
    if label:
        return label
    return "<unnamed>"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _get_numpy_dtype(onnx_dtype: int, *, tensor_name: str) -> np.dtype:
    """Return the NumPy dtype that corresponds to an ONNX data type.

    Parameters
    ----------
    onnx_dtype : int
        ONNX TensorProto data type enum value.
    tensor_name : str
        Name of the tensor being processed; only used in error reports.

    Returns
    -------
    np.dtype
        NumPy dtype for ``onnx_dtype``.

    Raises
    ------
    UnsupportedDTypeError
        If ``DTYPE_MAP`` has no torch dtype for ``onnx_dtype``, if the ONNX
        mapping module has no NumPy type for it, or if NumPy rejects the
        mapped type.
    """
    if DTYPE_MAP.get(onnx_dtype) is None:
        raise UnsupportedDTypeError(
            onnx_dtype=onnx_dtype,
            tensor_name=tensor_name,
            details="dtype not supported by onnx2fx",
        )

    # ``onnx_mapping`` may be either ``onnx.mapping`` or ``onnx._mapping``.
    # The legacy ``TENSOR_TYPE_TO_NP_TYPE`` table is used when present;
    # otherwise fall back to ``TENSOR_TYPE_MAP``, whose entries carry the
    # NumPy type in ``np_dtype``.
    # NOTE(review): ``onnx._mapping`` appears to expose only
    # ``TENSOR_TYPE_MAP`` — confirm against the supported onnx versions.
    legacy_table = getattr(onnx_mapping, "TENSOR_TYPE_TO_NP_TYPE", None)
    if legacy_table is not None:
        np_type = legacy_table.get(onnx_dtype)
    else:
        entry = getattr(onnx_mapping, "TENSOR_TYPE_MAP", {}).get(onnx_dtype)
        np_type = None if entry is None else entry.np_dtype

    if np_type is None:
        raise UnsupportedDTypeError(
            onnx_dtype=onnx_dtype,
            tensor_name=tensor_name,
            details="dtype has no NumPy mapping",
        )
    try:
        return np.dtype(np_type)
    except Exception as exc:  # pragma: no cover - defensive
        raise UnsupportedDTypeError(
            onnx_dtype=onnx_dtype,
            tensor_name=tensor_name,
            details=f"dtype not supported by NumPy memmap ({exc})",
        ) from exc
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _parse_external_data_kv(
    tensor: onnx.TensorProto, *, tensor_name: str
) -> Dict[str, str]:
    """Collect a tensor's external-data entries into a key/value dictionary.

    Parameters
    ----------
    tensor : onnx.TensorProto
        Tensor whose ``external_data`` entries are read.
    tensor_name : str
        Name used when reporting an error.

    Returns
    -------
    Dict[str, str]
        Mapping from entry key to entry value; a repeated key keeps the value
        of its last occurrence.

    Raises
    ------
    ExternalDataError
        If the tensor has no external-data entries.
    """
    metadata: Dict[str, str] = {}
    for entry in tensor.external_data:
        metadata[entry.key] = entry.value

    if metadata:
        return metadata
    raise ExternalDataError(
        tensor_name=tensor_name,
        message="external data metadata is empty",
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _require_int_field(
    data: Dict[str, str],
    *,
    field: str,
    tensor_name: str,
) -> int:
    """Read a mandatory integer entry from external-data metadata.

    Parameters
    ----------
    data : Dict[str, str]
        External-data metadata as key/value strings.
    field : str
        Key of the entry to read.
    tensor_name : str
        Name used when reporting an error.

    Returns
    -------
    int
        The entry converted with ``int()``.

    Raises
    ------
    ExternalDataError
        If ``field`` is missing from ``data`` or its value is not a valid
        integer literal.
    """
    if field in data:
        try:
            parsed = int(data[field])
        except ValueError as exc:
            raise ExternalDataError(
                tensor_name=tensor_name,
                message=f"invalid external data field '{field}': {data[field]}",
            ) from exc
        return parsed

    raise ExternalDataError(
        tensor_name=tensor_name,
        message=f"missing external data field '{field}'",
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _resolve_external_path(
|
|
101
|
+
location: str, base_dir: str | None, *, tensor_name: str
|
|
102
|
+
) -> str:
|
|
103
|
+
if os.path.isabs(location):
|
|
104
|
+
return location
|
|
105
|
+
if not base_dir:
|
|
106
|
+
raise ExternalDataError(
|
|
107
|
+
tensor_name=tensor_name,
|
|
108
|
+
message="base_dir is required for relative external data paths",
|
|
109
|
+
)
|
|
110
|
+
return os.path.normpath(os.path.join(base_dir, location))
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _expected_nbytes(shape: Iterable[int], np_dtype: np.dtype) -> int:
    """Compute the byte size of a dense tensor with the given shape and dtype.

    Parameters
    ----------
    shape : Iterable[int]
        Tensor dimensions; an empty shape counts as a single element.
    np_dtype : np.dtype
        Element dtype, whose ``itemsize`` gives the bytes per element.

    Returns
    -------
    int
        Number of elements multiplied by the item size.
    """
    # math.prod of an empty iterable is 1, which covers the scalar case.
    n_elements = math.prod(shape)
    bytes_per_element = int(np_dtype.itemsize)
    return int(n_elements) * bytes_per_element
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def resolve_external_data(
    tensor: onnx.TensorProto,
    *,
    base_dir: str | None,
    strict: bool = True,
) -> ExternalDataInfo:
    """Parse and validate the external-data metadata of a tensor.

    Parameters
    ----------
    tensor : onnx.TensorProto
        Tensor whose data is stored outside the model file.
    base_dir : str | None
        Directory used to resolve a relative ``location``.
    strict : bool
        When true, the byte length must equal the size implied by the
        tensor's dims and data type.

    Returns
    -------
    ExternalDataInfo
        Resolved path, offset, length, shape and NumPy dtype.

    Raises
    ------
    ExternalDataError
        If the metadata is empty, ``location`` is missing, ``offset`` or
        ``length`` is not a non-negative integer, the file does not exist,
        the byte range exceeds the file, or (in strict mode) the length does
        not match the expected size.
    UnsupportedDTypeError
        If the tensor's data type cannot be mapped to a NumPy dtype.
    """
    tensor_name = _tensor_name(tensor)
    data = _parse_external_data_kv(tensor, tensor_name=tensor_name)
    if "location" not in data:
        raise ExternalDataError(
            tensor_name=tensor_name,
            message="missing external data field 'location'",
        )
    location = data["location"]

    # The ONNX external-data format makes only ``location`` mandatory.
    # ``offset`` defaults to the start of the file; a missing ``length`` is
    # filled in below once the file size is known.
    offset = 0
    if "offset" in data:
        offset = _require_int_field(data, field="offset", tensor_name=tensor_name)
    length = None
    if "length" in data:
        length = _require_int_field(data, field="length", tensor_name=tensor_name)
    if offset < 0 or (length is not None and length < 0):
        raise ExternalDataError(
            tensor_name=tensor_name,
            message="external data offset/length must be non-negative",
        )

    path = _resolve_external_path(location, base_dir, tensor_name=tensor_name)
    if not os.path.isfile(path):
        raise ExternalDataError(
            tensor_name=tensor_name,
            message=f"external data file not found: {path}",
        )

    np_dtype = _get_numpy_dtype(tensor.data_type, tensor_name=tensor_name)
    expected = _expected_nbytes(tensor.dims, np_dtype)
    file_size = os.path.getsize(path)

    if length is None:
        # No explicit length: the data runs from ``offset`` to end of file.
        # Clamp at zero so an offset past the end is reported by the range
        # check below rather than producing a negative length.
        length = max(file_size - offset, 0)

    if offset + length > file_size:
        raise ExternalDataError(
            tensor_name=tensor_name,
            message=(
                "external data range exceeds file size "
                f"(offset={offset}, length={length}, file_size={file_size})"
            ),
        )

    if strict and length != expected:
        raise ExternalDataError(
            tensor_name=tensor_name,
            message=(
                f"external data length mismatch (length={length}, expected={expected})"
            ),
        )

    return ExternalDataInfo(
        path=path,
        offset=offset,
        length=length,
        shape=tuple(int(dim) for dim in tensor.dims),
        numpy_dtype=np_dtype,
    )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def iter_all_graphs(model: onnx.ModelProto) -> Iterable[onnx.GraphProto]:
    """Yield the model's main graph and every subgraph nested inside it.

    Subgraphs are discovered through ``GRAPH`` and ``GRAPHS`` node
    attributes. Pending graphs are kept on a stack, so the most recently
    discovered graph is yielded next.
    """
    pending = [model.graph]
    while pending:
        current = pending.pop()
        yield current
        for node in current.node:
            for attr in node.attribute:
                kind = attr.type
                if kind == onnx.AttributeProto.GRAPH:
                    pending.append(attr.g)
                elif kind == onnx.AttributeProto.GRAPHS:
                    pending.extend(attr.graphs)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def validate_external_data_model(
    model: onnx.ModelProto,
    *,
    base_dir: str | None,
    strict: bool = True,
) -> None:
    """Check the external-data metadata of every tensor in a model.

    Walks all graphs (including nested subgraphs) and, for each initializer
    and each ``TENSOR`` / ``TENSORS`` node attribute that is flagged as
    external or carries external-data entries, runs
    ``resolve_external_data``. Any error raised there propagates; nothing is
    returned on success.

    Parameters
    ----------
    model : onnx.ModelProto
        The model to validate.
    base_dir : str | None
        Directory used to resolve relative external-data locations.
    strict : bool
        Forwarded to ``resolve_external_data``.
    """

    def _check(candidate: onnx.TensorProto) -> None:
        # A tensor counts as external if it is flagged so or has metadata.
        is_external = (
            candidate.data_location == onnx.TensorProto.EXTERNAL
            or candidate.external_data
        )
        if is_external:
            resolve_external_data(candidate, base_dir=base_dir, strict=strict)

    for graph in iter_all_graphs(model):
        for initializer in graph.initializer:
            _check(initializer)
        for node in graph.node:
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.TENSOR:
                    _check(attr.t)
                elif attr.type == onnx.AttributeProto.TENSORS:
                    for item in attr.tensors:
                        _check(item)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
__all__ = [
|
|
230
|
+
"ExternalDataInfo",
|
|
231
|
+
"resolve_external_data",
|
|
232
|
+
"validate_external_data_model",
|
|
233
|
+
]
|
onnx2fx/utils/names.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Utilities for sanitizing ONNX names to valid Python identifiers."""
|
|
3
|
+
|
|
4
|
+
import keyword
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def sanitize_name(name: str) -> str:
    """Convert an arbitrary name into a valid Python identifier.

    ONNX tensor names may contain characters that Python identifiers do not
    allow (for example '.', '/' or '-') and may begin with a digit. Such
    names are rewritten so they can be used as identifiers.

    Parameters
    ----------
    name : str
        The ONNX tensor name to sanitize.

    Returns
    -------
    str
        A valid Python identifier: disallowed characters become underscores,
        a leading digit gets an underscore prefix, an empty name becomes
        ``"_unnamed"``, and a Python keyword gets a trailing underscore.
    """
    # Everything outside ASCII letters, digits and underscore becomes "_";
    # this already covers '.', '/' and '-'.
    cleaned = re.sub(r"[^a-zA-Z0-9_]", "_", name)

    if not cleaned:
        return "_unnamed"

    if cleaned[0].isdigit():
        # The underscore prefix also rules out a keyword clash.
        return "_" + cleaned

    if keyword.iskeyword(cleaned):
        return cleaned + "_"

    return cleaned
|