onnx-diagnostic 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +24 -3
- onnx_diagnostic/doc.py +46 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/model_builder_helper.py +3 -0
- onnx_diagnostic/helpers/onnx_helper.py +291 -7
- onnx_diagnostic/reference/torch_evaluator.py +141 -11
- onnx_diagnostic/reference/torch_ops/__init__.py +1 -1
- onnx_diagnostic/reference/torch_ops/_op_run.py +14 -5
- onnx_diagnostic/reference/torch_ops/access_ops.py +18 -8
- onnx_diagnostic/reference/torch_ops/binary_ops.py +2 -2
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +7 -4
- onnx_diagnostic/reference/torch_ops/generator_ops.py +4 -3
- onnx_diagnostic/reference/torch_ops/nn_ops.py +34 -14
- onnx_diagnostic/reference/torch_ops/other_ops.py +19 -19
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +6 -6
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +6 -6
- onnx_diagnostic/reference/torch_ops/shape_ops.py +16 -15
- onnx_diagnostic/reference/torch_ops/unary_ops.py +13 -13
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +1 -0
- onnx_diagnostic/torch_models/test_helper.py +34 -12
- {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/RECORD +26 -25
- {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -191,24 +191,45 @@ def get_parser_find() -> ArgumentParser:
|
|
|
191
191
|
"--names",
|
|
192
192
|
type=str,
|
|
193
193
|
required=False,
|
|
194
|
-
help="names to look at comma separated values"
|
|
194
|
+
help="names to look at comma separated values, if 'SHADOW', "
|
|
195
|
+
"search for shadowing names",
|
|
195
196
|
)
|
|
196
197
|
parser.add_argument(
|
|
197
198
|
"-v",
|
|
198
199
|
"--verbose",
|
|
199
200
|
default=0,
|
|
201
|
+
type=int,
|
|
200
202
|
required=False,
|
|
201
203
|
help="verbosity",
|
|
202
204
|
)
|
|
205
|
+
parser.add_argument(
|
|
206
|
+
"--v2",
|
|
207
|
+
default=False,
|
|
208
|
+
action=BooleanOptionalAction,
|
|
209
|
+
help="use enumerate_results instead of onnx_find",
|
|
210
|
+
)
|
|
203
211
|
return parser
|
|
204
212
|
|
|
205
213
|
|
|
206
214
|
def _cmd_find(argv: List[Any]):
    """
    Runs the ``find`` command.

    Three modes are available:

    * ``--names SHADOW``: prints the shadowing and post-shadowing names
      returned by ``shadowing_names``,
    * ``--v2``: lists every occurrence of the requested names with
      ``enumerate_results``,
    * otherwise: delegates to ``onnx_find``.

    :param argv: command line, the first element (the command name) is skipped
    """
    from .helpers.onnx_helper import onnx_find, enumerate_results, shadowing_names

    parser = get_parser_find()
    args = parser.parse_args(argv[1:])
    if args.names == "SHADOW":
        onx = onnx.load(args.input, load_external_data=False)
        s, ps = shadowing_names(onx)[:2]
        print(f"shadowing names: {s}")
        print(f"post-shadowing names: {ps}")
        return

    # --names is optional and has no default: when it is omitted args.names
    # is None and calling .split on it would raise AttributeError.
    names = set(args.names.split(",")) if args.names is not None else set()
    if args.v2:
        onx = onnx.load(args.input, load_external_data=False)
        res = list(enumerate_results(onx, name=names, verbose=args.verbose))
        if not args.verbose:
            # in verbose mode, enumerate_results already printed what it found
            print("\n".join(map(str, res)))
    else:
        onnx_find(args.input, verbose=args.verbose, watch=names)
|
|
212
233
|
|
|
213
234
|
|
|
214
235
|
def get_parser_config() -> ArgumentParser:
|
onnx_diagnostic/doc.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
1
5
|
def reset_torch_transformers(gallery_conf, fname):
|
|
2
6
|
"Resets torch dynamo for :epkg:`sphinx-gallery`."
|
|
3
7
|
import matplotlib.pyplot as plt
|
|
@@ -30,3 +34,45 @@ def plot_legend(
|
|
|
30
34
|
ax.grid(False)
|
|
31
35
|
ax.set_axis_off()
|
|
32
36
|
return ax
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def rotate_align(ax, angle=15, align="right"):
    """
    Rotates the x tick labels and aligns them (to the right by default).

    :param ax: axes whose x tick labels are modified
    :param angle: rotation applied to every label
    :param align: horizontal alignment applied to every label
    :return: ax
    """
    tick_labels = ax.get_xticklabels()
    for tick in tick_labels:
        tick.set_rotation(angle)
        tick.set_horizontalalignment(align)
    return ax
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def save_fig(ax, name: str):
    """
    Applies ``tight_layout`` to the figure owning *ax* and saves it.

    The layout is applied to the figure returned by ``ax.get_figure()``
    and not to pyplot's current figure, which may be a different one
    when several figures are open.

    :param ax: axes belonging to the figure to save
    :param name: destination, forwarded to ``savefig``
    :return: ax
    """
    fig = ax.get_figure()
    fig.tight_layout()
    fig.savefig(name)
    return ax
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
    """
    Sets the title of the axes.

    :param ax: axes to update
    :param title: text given to ``set_title``
    :return: ax, so that calls can be chained
    """
    ax.set_title(title)
    return ax
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def plot_histogram(
    tensor: np.ndarray,
    ax: Optional["plt.axes"] = None,  # noqa: F821
    bins: int = 30,
    color: str = "orange",
    alpha: float = 0.7,
) -> "plt.axes":  # noqa: F821
    """
    Plots the distribution of the values of a tensor on a log y-scale.

    :param tensor: values to plot, forwarded to ``ax.hist``
    :param ax: axes to draw on, if None, the current pyplot axes are
        cleared and used
    :param bins: number of bins
    :param color: color of the bars
    :param alpha: transparency of the bars
    :return: the axes the histogram was drawn on
    """
    if ax is None:
        import matplotlib.pyplot as plt

        ax = plt.gca()
        ax.cla()
    # bins, color and alpha are forwarded, they used to be silently replaced
    # by hard-coded values equal to the defaults.
    ax.hist(tensor, bins=bins, color=color, alpha=alpha)
    ax.set_yscale("log")
    return ax
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Dict, List, Optional, Tuple
|
|
3
|
+
import onnx
|
|
4
|
+
import onnx.helper as oh
|
|
5
|
+
import torch
|
|
6
|
+
from ..reference.torch_ops import OpRunKernel, OpRunTensor
|
|
7
|
+
from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
|
|
8
|
+
from .ort_session import InferenceSessionForTorch
|
|
9
|
+
|
|
10
|
+
_SAVED: List[str] = []
|
|
11
|
+
_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_model_name(op_name: str, provider: str) -> Optional[str]:
    """
    Returns the file name used to dump the optimized model of a kernel,
    or None when dumping is disabled (``_SAVE_OPTIMIZED_MODEL_`` is 0).

    Every returned name is appended to ``_SAVED``, whose length serves
    as a counter making names unique.

    :param op_name: short name of the operator, part of the file name
    :param provider: execution provider, part of the file name
    :return: a file name or None
    """
    if not _SAVE_OPTIMIZED_MODEL_:
        return None
    # op_name used to be ignored: every dump was named after layer_norm,
    # including those coming from other kernels.
    name = f"dump_doc_{op_name}_{provider}_{len(_SAVED)}.onnx"
    _SAVED.append(name)
    return name
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class LayerNormalizationOrt(OpRunKernel):
    """
    LayerNormalization with onnxruntime.

    The kernel builds a one-node ONNX model for every combination
    (element type, rank, bias or not) it meets and runs it through
    :class:`InferenceSessionForTorch`. Only the first output of
    LayerNormalization is supported.
    """

    @classmethod
    def device_dependent(cls) -> bool:
        "Needs device."
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
        self.device = device
        # NOTE(review): stash_type is read here but not forwarded to the
        # node created in _make_model -- confirm this is intended.
        self.stash_type = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )
        self.compute_std = len(node.output) > 1
        assert not self.compute_std, (
            f"This kernel implementation only work when only one output "
            f"is required but {node.output} were."
        )
        # sessions indexed by (element type, rank, has_bias)
        self._cache: Dict[Tuple[int, int, bool], InferenceSessionForTorch] = {}
        self.is_cpu = torch.device("cpu") == self.device

    def _make_model(
        self, itype: int, rank: int, has_bias: bool
    ) -> "InferenceSessionForTorch":
        """
        Creates the inference session running a single LayerNormalization node.

        :param itype: onnx element type of the inputs and the output
        :param rank: rank of the normalized tensor
        :param has_bias: adds input ``B`` to the model
        :return: the session
        """
        # Every leading dimension gets its own symbolic name (d0, d1, ...).
        # Without the f prefix, they all shared the literal name "d{i}".
        shape = [*[f"d{i}" for i in range(rank - 1)], "last"]
        input_names = ["X", "W", "B"] if has_bias else ["X", "W"]
        inputs = [
            oh.make_tensor_value_info("X", itype, shape),
            oh.make_tensor_value_info("W", itype, ["last"]),
        ]
        if has_bias:
            inputs.append(oh.make_tensor_value_info("B", itype, ["last"]))
        layer_model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "LayerNormalization",
                        input_names,
                        ["Z"],
                        axis=self.axis,
                        epsilon=self.epsilon,
                    )
                ],
                "dummy",
                inputs,
                [oh.make_tensor_value_info("Z", itype, shape)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        self._provider = provider
        return InferenceSessionForTorch(
            layer_model,
            optimized_model_filepath=_get_model_name("layer_norm", provider),
            providers=[provider],
        )

    def run(self, x, scale, bias=None):
        """
        Runs LayerNormalization through onnxruntime.

        :param x: tensor to normalize
        :param scale: scale
        :param bias: optional bias
        :return: the normalized tensor wrapped into :class:`OpRunTensor`
        """
        itype = torch_dtype_to_onnx_dtype(x.dtype)
        rank = len(x.shape)
        has_bias = bias is not None
        # has_bias is part of the key: a session built without input B
        # cannot be reused for a call providing a bias, and conversely.
        key = itype, rank, has_bias
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, rank, has_bias)
        sess = self._cache[key]
        if self.verbose:
            print(f"[LayerNormalizationOrt] running on {self._provider!r}")
        feeds = dict(X=x.tensor, W=scale.tensor)
        if has_bias:
            feeds["B"] = bias.tensor
        got = sess.run(None, feeds)[0]
        return OpRunTensor(got)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class MatMulOrt(OpRunKernel):
    """
    MatMul with onnxruntime.

    The kernel builds a one-node ONNX model for every combination
    (element type, rank of A, rank of B) it meets and runs it through
    :class:`InferenceSessionForTorch`.
    """

    @classmethod
    def device_dependent(cls) -> bool:
        "Needs device."
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.device = device
        # sessions indexed by (element type, rank of A, rank of B)
        self._cache: Dict[Tuple[int, int, int], InferenceSessionForTorch] = {}
        self.is_cpu = torch.device("cpu") == self.device

    def _make_model(
        self, itype: int, ranka: int, rankb: int
    ) -> "InferenceSessionForTorch":
        """
        Creates the inference session running a single MatMul node.

        :param itype: onnx element type of the inputs and the output
        :param ranka: rank of the first input
        :param rankb: rank of the second input
        :return: the session
        """
        # MatMul follows numpy.matmul: a 1-D operand is promoted to 2-D and
        # the added dimension is removed from the result.
        if ranka == 1 and rankb == 1:
            rankc = 0
        elif ranka == 1 or rankb == 1:
            rankc = max(ranka, rankb) - 1
        else:
            rankc = max(ranka, rankb)
        # Every dimension gets its own symbolic name (a0, a1, ...).
        # Without the f prefix, they all shared the literal name "a{i}".
        shapea = [f"a{i}" for i in range(ranka)]
        shapeb = [f"b{i}" for i in range(rankb)]
        shapec = [f"c{i}" for i in range(rankc)]
        model = oh.make_model(
            oh.make_graph(
                [oh.make_node("MatMul", ["A", "B"], ["C"])],
                "dummy",
                [
                    oh.make_tensor_value_info("A", itype, shapea),
                    oh.make_tensor_value_info("B", itype, shapeb),
                ],
                [oh.make_tensor_value_info("C", itype, shapec)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        self._provider = provider
        return InferenceSessionForTorch(
            model,
            optimized_model_filepath=_get_model_name("matmul", provider),
            providers=[provider],
        )

    def run(self, a, b):
        """
        Runs MatMul through onnxruntime.

        :param a: first operand
        :param b: second operand
        :return: the product wrapped into :class:`OpRunTensor`
        """
        itype = torch_dtype_to_onnx_dtype(a.dtype)
        ranka, rankb = len(a.shape), len(b.shape)
        key = itype, ranka, rankb
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, ranka, rankb)
        sess = self._cache[key]
        if self.verbose:
            print(f"[MatMulOrt] running on {self._provider!r}")
        feeds = dict(A=a.tensor, B=b.tensor)
        got = sess.run(None, feeds)[0]
        return OpRunTensor(got)
|
|
@@ -220,6 +220,9 @@ def create_model_builder(
|
|
|
220
220
|
"""
|
|
221
221
|
assert cache_dir, "create_model_builder does not work without cache_dir."
|
|
222
222
|
assert os.path.exists(cache_dir), f"cache_dir={cache_dir!r} does not exists"
|
|
223
|
+
precision = {"float32": "fp32", "float16": "fp16", "bfloat16": "bfp16"}.get(
|
|
224
|
+
precision, precision
|
|
225
|
+
)
|
|
223
226
|
download_model_builder_to_cache()
|
|
224
227
|
builder = import_model_builder()
|
|
225
228
|
io_dtype = builder.set_io_dtype(precision, execution_provider, extra_options)
|
|
@@ -316,7 +316,7 @@ def check_model_ort(
|
|
|
316
316
|
|
|
317
317
|
|
|
318
318
|
@functools.cache
|
|
319
|
-
def onnx_dtype_name(itype: int) -> str:
|
|
319
|
+
def onnx_dtype_name(itype: int, exc: bool = True) -> str:
|
|
320
320
|
"""
|
|
321
321
|
Returns the ONNX name for a specific element type.
|
|
322
322
|
|
|
@@ -335,7 +335,11 @@ def onnx_dtype_name(itype: int) -> str:
|
|
|
335
335
|
v = getattr(TensorProto, k)
|
|
336
336
|
if v == itype:
|
|
337
337
|
return k
|
|
338
|
-
|
|
338
|
+
if exc:
|
|
339
|
+
raise ValueError(f"Unexpected value itype: {itype}")
|
|
340
|
+
if itype == 0:
|
|
341
|
+
return "UNDEFINED"
|
|
342
|
+
return "UNEXPECTED"
|
|
339
343
|
|
|
340
344
|
|
|
341
345
|
def pretty_onnx(
|
|
@@ -365,7 +369,7 @@ def pretty_onnx(
|
|
|
365
369
|
itype = onx.type.tensor_type.elem_type
|
|
366
370
|
shape = tuple((d.dim_param or d.dim_value) for d in onx.type.tensor_type.shape.dim)
|
|
367
371
|
shape_str = ",".join(map(str, shape))
|
|
368
|
-
return f"{onnx_dtype_name(itype)}[{shape_str}] {name}"
|
|
372
|
+
return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"
|
|
369
373
|
|
|
370
374
|
if isinstance(onx, AttributeProto):
|
|
371
375
|
att = onx
|
|
@@ -767,7 +771,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
|
|
|
767
771
|
|
|
768
772
|
|
|
769
773
|
def iterator_initializer_constant(
|
|
770
|
-
model: Union[
|
|
774
|
+
model: Union[FunctionProto, GraphProto, ModelProto],
|
|
771
775
|
use_numpy: bool = True,
|
|
772
776
|
prefix: str = "",
|
|
773
777
|
) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]: # noqa: F821
|
|
@@ -779,8 +783,8 @@ def iterator_initializer_constant(
|
|
|
779
783
|
:param prefix: for subgraph
|
|
780
784
|
:return: iterator
|
|
781
785
|
"""
|
|
782
|
-
if not isinstance(model,
|
|
783
|
-
graph = model if isinstance(model,
|
|
786
|
+
if not isinstance(model, FunctionProto):
|
|
787
|
+
graph = model if isinstance(model, GraphProto) else model.graph
|
|
784
788
|
if not use_numpy:
|
|
785
789
|
from .torch_helper import to_tensor
|
|
786
790
|
if prefix:
|
|
@@ -791,7 +795,7 @@ def iterator_initializer_constant(
|
|
|
791
795
|
)
|
|
792
796
|
nodes = graph.node
|
|
793
797
|
name = graph.name
|
|
794
|
-
if isinstance(model,
|
|
798
|
+
if isinstance(model, ModelProto):
|
|
795
799
|
for f in model.functions:
|
|
796
800
|
yield from iterator_initializer_constant(
|
|
797
801
|
f, use_numpy=use_numpy, prefix=f"{prefix}{f.name}"
|
|
@@ -908,3 +912,283 @@ def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union
|
|
|
908
912
|
qu = np.quantile(tensor, ii)
|
|
909
913
|
stat.update({f"q{i}": float(q) for i, q in zip(ii, qu)})
|
|
910
914
|
return stat
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
class NodeCoordinates:
    """
    Localizes a node (or a name) in a model.

    *path* is a tuple of triplets ``(node index, node type, node name)``,
    one triplet per level to go through to reach the node.
    """

    __slots__ = ("node", "path")

    def __init__(
        self,
        node: Union[onnx.TensorProto, NodeProto, str],
        path: Tuple[Tuple[int, str, str], ...],
    ):
        assert isinstance(path, tuple), f"Unexpected type {type(path)} for path"
        assert all(isinstance(t, tuple) for t in path), f"Unexpected type in path={path}"
        self.node = node
        self.path = path

    def path_to_str(self) -> str:
        "Returns the path as a string, one ``(index:type:name)`` group per level."
        groups = []
        for triplet in self.path:
            groups.append(f"({':'.join(map(str, triplet))})")
        return "x".join(groups)

    def __str__(self) -> str:
        "usual"
        location = self.path_to_str()
        if isinstance(self.node, str):
            # a plain name is displayed with its repr
            return f"{location} :: {self.node!r}"
        return f"{location} :: {pretty_onnx(self.node)}"
|
|
944
|
+
|
|
945
|
+
|
|
946
|
+
class ResultFound:
    """
    Class returned by :func:`enumerate_results`.

    It stores the name which was found and where, either as
    *producer* or as *consumer*.
    """

    __slots__ = ("consumer", "name", "producer")

    def __init__(
        self,
        name: str,
        producer: Optional[NodeCoordinates],
        consumer: Optional[NodeCoordinates],
    ):
        assert isinstance(name, str), f"unexpected type {type(name)} for name"
        self.name = name
        self.producer = producer
        self.consumer = consumer

    def __str__(self) -> str:
        "usual"
        # '>>' shows the producer when there is one, '<<' shows the consumer
        if self.producer is not None:
            return f">> {self.name} - {self.producer}"
        return f"<< {self.name} - {self.consumer}"
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def enumerate_results(
    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
    name: Union[Set[str], str],
    verbose: int = 0,
    coordinates: Optional[List[Tuple[int, str, str]]] = None,
) -> Iterator[ResultFound]:
    """
    Iterates on all nodes, attributes to find where a name is used.

    Initializers, sparse initializers, inputs, node inputs, node outputs
    and outputs are looked into, as well as the subgraphs of nodes
    ``If``, ``Scan``, ``Loop``, ``SequenceMap``.

    :param proto: a proto
    :param name: name or names to find
    :param verbose: verbosity, results are printed if it is above 1
    :param coordinates: coordinates of a node, path leading to *proto*
        when it is a subgraph
    :return: iterator on :class:`ResultFound`
    """
    if not isinstance(name, set):
        name = {name}
    coordinates = coordinates or []
    assert all(
        isinstance(c, tuple) for c in coordinates
    ), f"Unexpected type in coordinates={coordinates}"
    indent = " " * len(coordinates)

    def _result(found: str, node, last: Tuple[int, str, str], first: bool) -> ResultFound:
        # Builds a result located at coordinates + last. The location is
        # stored as the producer if first is True, as the consumer otherwise.
        # NOTE(review): node inputs are stored as producer and node outputs
        # as consumer, which looks inverted -- kept as is, to be confirmed.
        where = NodeCoordinates(node, (*coordinates, last))
        r = ResultFound(found, where, None) if first else ResultFound(found, None, where)
        if verbose > 1:
            print(f"[enumerate_results] {indent}-- {r}")
        return r

    if isinstance(proto, ModelProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into ModelProto...")
        # coordinates are forwarded, they used to be dropped
        yield from enumerate_results(
            proto.graph, name, verbose=verbose, coordinates=coordinates
        )
    elif isinstance(proto, FunctionProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into FunctionProto...")
        # inputs and outputs of a function are plain strings
        for i in proto.input:
            if i in name:
                yield _result(i, i, (-1, "INPUT", ""), True)
        # coordinates are forwarded, they used to be dropped
        yield from enumerate_results(
            proto.node, name, verbose=verbose, coordinates=coordinates
        )
        for i in proto.output:
            if i in name:
                yield _result(i, i, (len(proto.node), "OUTPUT", ""), False)
    elif isinstance(proto, GraphProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into GraphProto...")
        for i in proto.initializer:
            if i.name in name:
                yield _result(i.name, i, (-1, "INIT", ""), True)
        for i in proto.sparse_initializer:
            # SparseTensorProto has no field name, the name of a sparse
            # initializer is stored in its values.
            if i.values.name in name:
                yield _result(i.values.name, i, (-1, "INIT", ""), True)
        for i in proto.input:
            if i.name in name:
                yield _result(i.name, i, (-1, "INPUT", ""), True)
        yield from enumerate_results(
            proto.node, name, verbose=verbose, coordinates=coordinates
        )
        for i in proto.output:
            if i.name in name:
                yield _result(i.name, i, (len(proto.node), "OUTPUT", ""), False)
    else:
        if verbose:
            print(
                f"[enumerate_results] {indent}searching for {name!r} into List[NodeProto]..."
            )
        for node_i, node in enumerate(proto):
            here = (node_i, node.op_type, node.name)
            for n in node.input:
                if n in name:
                    yield _result(n, node, here, True)
            if node.op_type in {"If", "Scan", "Loop", "SequenceMap"}:
                # subgraphs are explored between the inputs and the outputs
                for att in node.attribute:
                    if att.type == onnx.AttributeProto.GRAPH:
                        yield from enumerate_results(
                            att.g,
                            name,
                            verbose=verbose,
                            coordinates=[*coordinates, here],
                        )
            for n in node.output:
                if n in name:
                    yield _result(n, node, here, False)
    if verbose:
        print(f"[enumerate_results] {indent}done")
|
|
1122
|
+
|
|
1123
|
+
|
|
1124
|
+
def shadowing_names(
    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
    verbose: int = 0,
    existing: Optional[Set[str]] = None,
    shadow_context: Optional[Set[str]] = None,
    post_shadow_context: Optional[Set[str]] = None,
) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Returns the shadowing names, the names created in the main graph
    after they were created in a subgraphs and the names created by the nodes.

    :param proto: a model, a graph, a function or a sequence of nodes
    :param verbose: verbosity, forwarded to the recursive calls
    :param existing: names known before the first node, must be None
        unless *proto* is a sequence of nodes
    :param shadow_context: names a result is not allowed to redefine,
        must be None unless *proto* is a sequence of nodes
    :param post_shadow_context: not used by this function
    :return: shadowing names, post-shadowing names, created names
    """
    if isinstance(proto, ModelProto):
        return shadowing_names(proto.graph, verbose=verbose)
    if isinstance(proto, GraphProto):
        assert (
            existing is None and shadow_context is None
        ), "existing must be None if nodes is None"
        return shadowing_names(
            proto.node,
            verbose=verbose,
            # SparseTensorProto has no field name, the name of a sparse
            # initializer is stored in its values.
            existing={i.name for i in proto.initializer}
            | {i.values.name for i in proto.sparse_initializer}
            | {i.name for i in proto.input if i.name},
            shadow_context=set(),
            post_shadow_context=set(),
        )
    if isinstance(proto, FunctionProto):
        assert (
            existing is None and shadow_context is None
        ), "existing must be None if nodes is None"
        return shadowing_names(
            proto.node,
            verbose=verbose,
            existing={i for i in proto.input if i},
            shadow_context=set(),
            post_shadow_context=set(),
        )

    assert (
        existing is not None and shadow_context is not None
    ), "existing must not be None if nodes is not None"
    shadow = set()
    # copies, the sets given by the caller must not be modified
    shadow_context = shadow_context.copy()
    existing = existing.copy()
    created = set()
    post_shadow = set()
    for node in proto:
        not_empty = {n for n in node.input if n}
        intersection = not_empty & existing
        assert len(intersection) == len(not_empty), (
            f"One input in {not_empty}, node={pretty_onnx(node)} "
            f"was not found in {existing}"
        )
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                g = att.g
                # names defined by the subgraph itself
                local_names = (
                    {i.name for i in g.input if i.name}
                    | {i.name for i in g.initializer}
                    | {i.values.name for i in g.sparse_initializer}
                )
                shadow |= local_names & shadow_context
                # The nodes of the subgraph can consume its own inputs and
                # initializers (Loop or Scan bodies for example): they are
                # added to the known names, otherwise the assertion above
                # fails on a valid model.
                # NOTE(review): the post-shadowing names found in the
                # subgraph are not propagated -- confirm this is intended.
                s, _ps, c = shadowing_names(
                    g.node,
                    verbose=verbose,
                    existing=existing | local_names,
                    shadow_context=existing,
                )
                shadow |= s
                created |= c

        not_empty = {n for n in node.output if n}
        post_shadow |= not_empty & created
        shadow |= not_empty & shadow_context
        existing |= not_empty
        created |= not_empty
    return shadow, post_shadow, created
|