onnx-diagnostic 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +24 -3
  3. onnx_diagnostic/doc.py +46 -0
  4. onnx_diagnostic/helpers/doc_helper.py +163 -0
  5. onnx_diagnostic/helpers/model_builder_helper.py +3 -0
  6. onnx_diagnostic/helpers/onnx_helper.py +291 -7
  7. onnx_diagnostic/reference/torch_evaluator.py +141 -11
  8. onnx_diagnostic/reference/torch_ops/__init__.py +1 -1
  9. onnx_diagnostic/reference/torch_ops/_op_run.py +14 -5
  10. onnx_diagnostic/reference/torch_ops/access_ops.py +18 -8
  11. onnx_diagnostic/reference/torch_ops/binary_ops.py +2 -2
  12. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +7 -4
  13. onnx_diagnostic/reference/torch_ops/generator_ops.py +4 -3
  14. onnx_diagnostic/reference/torch_ops/nn_ops.py +34 -14
  15. onnx_diagnostic/reference/torch_ops/other_ops.py +19 -19
  16. onnx_diagnostic/reference/torch_ops/reduce_ops.py +6 -6
  17. onnx_diagnostic/reference/torch_ops/sequence_ops.py +6 -6
  18. onnx_diagnostic/reference/torch_ops/shape_ops.py +16 -15
  19. onnx_diagnostic/reference/torch_ops/unary_ops.py +13 -13
  20. onnx_diagnostic/torch_export_patches/patch_module_helper.py +1 -0
  21. onnx_diagnostic/torch_models/test_helper.py +34 -12
  22. {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/METADATA +1 -1
  23. {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/RECORD +26 -25
  24. {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/WHEEL +0 -0
  25. {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/licenses/LICENSE.txt +0 -0
  26. {onnx_diagnostic-0.6.1.dist-info → onnx_diagnostic-0.6.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py CHANGED
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.6.1"
+__version__ = "0.6.3"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py CHANGED
@@ -191,24 +191,45 @@ def get_parser_find() -> ArgumentParser:
         "--names",
         type=str,
         required=False,
-        help="names to look at, comma-separated values",
+        help="names to look at, comma-separated values; if 'SHADOW', "
+        "search for shadowing names",
     )
     parser.add_argument(
         "-v",
         "--verbose",
         default=0,
+        type=int,
         required=False,
         help="verbosity",
     )
+    parser.add_argument(
+        "--v2",
+        default=False,
+        action=BooleanOptionalAction,
+        help="use enumerate_results instead of onnx_find",
+    )
     return parser


 def _cmd_find(argv: List[Any]):
-    from .helpers.onnx_helper import onnx_find
+    from .helpers.onnx_helper import onnx_find, enumerate_results, shadowing_names

     parser = get_parser_find()
     args = parser.parse_args(argv[1:])
-    onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
+    if args.names == "SHADOW":
+        onx = onnx.load(args.input, load_external_data=False)
+        s, ps = shadowing_names(onx)[:2]
+        print(f"shadowing names: {s}")
+        print(f"post-shadowing names: {ps}")
+    elif args.v2:
+        onx = onnx.load(args.input, load_external_data=False)
+        res = list(
+            enumerate_results(onx, name=set(args.names.split(",")), verbose=args.verbose)
+        )
+        if not args.verbose:
+            print("\n".join(map(str, res)))
+    else:
+        onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))


 def get_parser_config() -> ArgumentParser:
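For orientation, the three code paths of the extended _cmd_find can be reproduced with the helpers it imports. A minimal sketch using only what is visible in this hunk; the path "model.onnx" is a placeholder, and the model is loaded directly since the CLI flag carrying args.input is not shown here:

# Sketch of what the new branches in _cmd_find do.
import onnx
from onnx_diagnostic.helpers.onnx_helper import (
    enumerate_results, onnx_find, shadowing_names,
)

onx = onnx.load("model.onnx", load_external_data=False)

# --names SHADOW: report shadowed and post-shadowed result names.
s, ps = shadowing_names(onx)[:2]
print(f"shadowing names: {s}")
print(f"post-shadowing names: {ps}")

# --v2: enumerate every place a result name appears, producers and consumers.
for r in enumerate_results(onx, name={"some_result"}, verbose=0):
    print(r)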
onnx_diagnostic/doc.py CHANGED
@@ -1,3 +1,7 @@
+from typing import Optional
+import numpy as np
+
+
 def reset_torch_transformers(gallery_conf, fname):
     "Resets torch dynamo for :epkg:`sphinx-gallery`."
     import matplotlib.pyplot as plt
@@ -30,3 +34,45 @@ def plot_legend(
     ax.grid(False)
     ax.set_axis_off()
     return ax
+
+
+def rotate_align(ax, angle=15, align="right"):
+    """Rotates the x-labels and aligns them to the right. Returns ax."""
+    for label in ax.get_xticklabels():
+        label.set_rotation(angle)
+        label.set_horizontalalignment(align)
+    return ax
+
+
+def save_fig(ax, name: str):
+    """Applies ``tight_layout`` and saves the figure. Returns ax."""
+    import matplotlib.pyplot as plt
+
+    plt.tight_layout()
+    fig = ax.get_figure()
+    fig.savefig(name)
+    return ax
+
+
+def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
+    "Adds a title to axes and returns them."
+    ax.set_title(title)
+    return ax
+
+
+def plot_histogram(
+    tensor: np.ndarray,
+    ax: Optional["plt.axes"] = None,  # noqa: F821
+    bins: int = 30,
+    color: str = "orange",
+    alpha: float = 0.7,
+) -> "plt.axes":  # noqa: F821
+    "Plots the distribution of a tensor as a histogram."
+    if ax is None:
+        import matplotlib.pyplot as plt
+
+        ax = plt.gca()
+        ax.cla()
+    ax.hist(tensor, bins=bins, color=color, alpha=alpha)
+    ax.set_yscale("log")
+    return ax
onnx_diagnostic/helpers/doc_helper.py ADDED
@@ -0,0 +1,163 @@
+import os
+from typing import Dict, List, Optional, Tuple
+import onnx
+import onnx.helper as oh
+import torch
+from ..reference.torch_ops import OpRunKernel, OpRunTensor
+from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
+from .ort_session import InferenceSessionForTorch
+
+_SAVED: List[str] = []
+_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
+
+
+def _get_model_name(op_name: str, provider: str) -> Optional[str]:
+    if _SAVE_OPTIMIZED_MODEL_:
+        name = f"dump_doc_{op_name}_{provider}_{len(_SAVED)}.onnx"
+        _SAVED.append(name)
+        return name
+    return None
+
+
+class LayerNormalizationOrt(OpRunKernel):
+    "LayerNormalization with onnxruntime"
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        "Needs device."
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version=None,
+        device: Optional[torch.device] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.axis = self.get_attribute_int(node, "axis", -1)
+        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+        self.device = device
+        self.stash_type = onnx_dtype_to_torch_dtype(
+            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+        )
+        self.compute_std = len(node.output) > 1
+        assert not self.compute_std, (
+            f"This kernel implementation only works when a single output "
+            f"is required but {node.output} were requested."
+        )
+        self._cache: Dict[Tuple[int, int], InferenceSessionForTorch] = {}
+        self.is_cpu = torch.device("cpu") == self.device
+
+    def _make_model(self, itype: int, rank: int, has_bias: bool) -> InferenceSessionForTorch:
+        shape = [*[f"d{i}" for i in range(rank - 1)], "last"]
+        layer_model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node(
+                        "LayerNormalization",
+                        ["X", "W", "B"] if has_bias else ["X", "W"],
+                        ["Z"],
+                        axis=self.axis,
+                        epsilon=self.epsilon,
+                    )
+                ],
+                "dummy",
+                (
+                    [
+                        oh.make_tensor_value_info("X", itype, shape),
+                        oh.make_tensor_value_info("W", itype, ["last"]),
+                        oh.make_tensor_value_info("B", itype, ["last"]),
+                    ]
+                    if has_bias
+                    else [
+                        oh.make_tensor_value_info("X", itype, shape),
+                        oh.make_tensor_value_info("W", itype, ["last"]),
+                    ]
+                ),
+                [oh.make_tensor_value_info("Z", itype, shape)],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
+        self._provider = provider
+        return InferenceSessionForTorch(
+            layer_model,
+            optimized_model_filepath=_get_model_name("layer_norm", provider),
+            providers=[provider],
+        )
+
+    def run(self, x, scale, bias=None):
+        itype = torch_dtype_to_onnx_dtype(x.dtype)
+        rank = len(x.shape)
+        key = itype, rank
+        if key not in self._cache:
+            self._cache[key] = self._make_model(itype, rank, bias is not None)
+        sess = self._cache[key]
+        if self.verbose:
+            print(f"[LayerNormalizationOrt] running on {self._provider!r}")
+        feeds = dict(X=x.tensor, W=scale.tensor)
+        if bias is not None:
+            feeds["B"] = bias.tensor
+        got = sess.run(None, feeds)[0]
+        return OpRunTensor(got)
+
+
+class MatMulOrt(OpRunKernel):
+    "MatMul with onnxruntime"
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        "Needs device."
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version=None,
+        device: Optional[torch.device] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.device = device
+        self._cache: Dict[Tuple[int, int, int], InferenceSessionForTorch] = {}
+        self.is_cpu = torch.device("cpu") == self.device
+
+    def _make_model(self, itype: int, ranka: int, rankb: int) -> InferenceSessionForTorch:
+        shapea = [f"a{i}" for i in range(ranka)]
+        shapeb = [f"b{i}" for i in range(rankb)]
+        shapec = [f"c{i}" for i in range(max(ranka, rankb))]
+        model = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("MatMul", ["A", "B"], ["C"])],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("A", itype, shapea),
+                    oh.make_tensor_value_info("B", itype, shapeb),
+                ],
+                [oh.make_tensor_value_info("C", itype, shapec)],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
+        self._provider = provider
+        return InferenceSessionForTorch(
+            model,
+            optimized_model_filepath=_get_model_name("matmul", provider),
+            providers=[provider],
+        )
+
+    def run(self, a, b):
+        itype = torch_dtype_to_onnx_dtype(a.dtype)
+        ranka, rankb = len(a.shape), len(b.shape)
+        key = itype, ranka, rankb
+        if key not in self._cache:
+            self._cache[key] = self._make_model(itype, ranka, rankb)
+        sess = self._cache[key]
+        if self.verbose:
+            print(f"[MatMulOrt] running on {self._provider!r}")
+        feeds = dict(A=a.tensor, B=b.tensor)
+        got = sess.run(None, feeds)[0]
+        return OpRunTensor(got)
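Both kernels build a one-node ONNX model on first use, cache one ONNX Runtime session per (dtype, rank) key, and reuse it on later calls. A rough sketch of driving a kernel directly; how OpRunTensor wraps a tensor and how the torch evaluator normally instantiates these kernels are assumptions based on this file's imports, not shown in this diff:

# Hypothetical direct use of MatMulOrt (normally instantiated by the evaluator).
import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.helpers.doc_helper import MatMulOrt

node = oh.make_node("MatMul", ["A", "B"], ["C"])
kernel = MatMulOrt(node, device=torch.device("cpu"), verbose=1)
a, b = OpRunTensor(torch.randn(2, 3)), OpRunTensor(torch.randn(3, 4))
c = kernel.run(a, b)   # builds and caches the one-node model, runs it with ORT
print(c.tensor.shape)  # torch.Size([2, 4])

Setting the DUMP_ONNX environment variable to 1 additionally writes each optimized one-node model to disk via _get_model_name.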
onnx_diagnostic/helpers/model_builder_helper.py CHANGED
@@ -220,6 +220,9 @@ def create_model_builder(
     """
     assert cache_dir, "create_model_builder does not work without cache_dir."
     assert os.path.exists(cache_dir), f"cache_dir={cache_dir!r} does not exist"
+    precision = {"float32": "fp32", "float16": "fp16", "bfloat16": "bfp16"}.get(
+        precision, precision
+    )
     download_model_builder_to_cache()
     builder = import_model_builder()
     io_dtype = builder.set_io_dtype(precision, execution_provider, extra_options)
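The added mapping normalizes torch-style dtype names to the short forms passed on to the model builder, leaving other values untouched; the dict.get(key, default) idiom is a no-op for keys outside the mapping:

# Behavior of the normalization above (values copied from the diff).
aliases = {"float32": "fp32", "float16": "fp16", "bfloat16": "bfp16"}
assert aliases.get("float16", "float16") == "fp16"  # long name is mapped
assert aliases.get("fp16", "fp16") == "fp16"        # short name passes through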
onnx_diagnostic/helpers/onnx_helper.py CHANGED
@@ -316,7 +316,7 @@ def check_model_ort(


 @functools.cache
-def onnx_dtype_name(itype: int) -> str:
+def onnx_dtype_name(itype: int, exc: bool = True) -> str:
     """
     Returns the ONNX name for a specific element type.

@@ -335,7 +335,11 @@ def onnx_dtype_name(itype: int) -> str:
         v = getattr(TensorProto, k)
         if v == itype:
             return k
-    raise ValueError(f"Unexpected value itype: {itype}")
+    if exc:
+        raise ValueError(f"Unexpected value itype: {itype}")
+    if itype == 0:
+        return "UNDEFINED"
+    return "UNEXPECTED"


 def pretty_onnx(
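The new exc flag turns a failed lookup into a sentinel string instead of an exception, which pretty_onnx relies on below so that printing a model never raises. A small behavior sketch inferred from the branch above (the exact set of TensorProto names scanned by the loop sits outside this hunk):

from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

print(onnx_dtype_name(1))               # 'FLOAT'
print(onnx_dtype_name(999, exc=False))  # 'UNEXPECTED' instead of a ValueError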
@@ -365,7 +369,7 @@ def pretty_onnx(
         itype = onx.type.tensor_type.elem_type
         shape = tuple((d.dim_param or d.dim_value) for d in onx.type.tensor_type.shape.dim)
         shape_str = ",".join(map(str, shape))
-        return f"{onnx_dtype_name(itype)}[{shape_str}] {name}"
+        return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"

     if isinstance(onx, AttributeProto):
         att = onx
@@ -767,7 +771,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:


 def iterator_initializer_constant(
-    model: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
+    model: Union[FunctionProto, GraphProto, ModelProto],
     use_numpy: bool = True,
     prefix: str = "",
 ) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]:  # noqa: F821
@@ -779,8 +783,8 @@ def iterator_initializer_constant(
     :param prefix: for subgraph
     :return: iterator
     """
-    if not isinstance(model, onnx.FunctionProto):
-        graph = model if isinstance(model, onnx.GraphProto) else model.graph
+    if not isinstance(model, FunctionProto):
+        graph = model if isinstance(model, GraphProto) else model.graph
         if not use_numpy:
             from .torch_helper import to_tensor
         if prefix:
@@ -791,7 +795,7 @@ def iterator_initializer_constant(
     )
     nodes = graph.node
     name = graph.name
-    if isinstance(model, onnx.ModelProto):
+    if isinstance(model, ModelProto):
         for f in model.functions:
             yield from iterator_initializer_constant(
                 f, use_numpy=use_numpy, prefix=f"{prefix}{f.name}"
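These edits only swap the qualified onnx.* names for the directly imported ones; the function still yields (name, tensor) pairs for initializers and constants, recursing into the functions of a ModelProto. A minimal usage sketch (the constant-handling body of the function is not shown in this diff):

import onnx
from onnx_diagnostic.helpers.onnx_helper import iterator_initializer_constant

model = onnx.load("model.onnx")  # placeholder path
for name, value in iterator_initializer_constant(model, use_numpy=True):
    print(name, value.shape)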
@@ -908,3 +912,283 @@ def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union
         qu = np.quantile(tensor, ii)
         stat.update({f"q{i}": float(q) for i, q in zip(ii, qu)})
     return stat
+
+
+class NodeCoordinates:
+    """
+    A way to localize a node:
+    ``path`` is a tuple of triplets (node index, node type, node name).
+    """
+
+    __slots__ = ("node", "path")
+
+    def __init__(
+        self,
+        node: Union[onnx.TensorProto, NodeProto, str],
+        path: Tuple[Tuple[int, str, str], ...],
+    ):
+        assert isinstance(path, tuple), f"Unexpected type {type(path)} for path"
+        assert all(isinstance(t, tuple) for t in path), f"Unexpected type in path={path}"
+        self.node = node
+        self.path = path
+
+    def __str__(self) -> str:
+        "usual"
+        if isinstance(self.node, str):
+            return f"{self.path_to_str()} :: {self.node!r}"
+        return f"{self.path_to_str()} :: {pretty_onnx(self.node)}"
+
+    def path_to_str(self) -> str:
+        "String representation of the coordinates."
+        return "x".join(f"({':'.join(map(str, t))})" for t in self.path)
+
+
+class ResultFound:
+    """
+    Class returned by :func:`enumerate_results`.
+    """
+
+    __slots__ = ("consumer", "name", "producer")
+
+    def __init__(
+        self,
+        name: str,
+        producer: Optional[NodeCoordinates],
+        consumer: Optional[NodeCoordinates],
+    ):
+        assert isinstance(name, str), f"unexpected type {type(name)} for name"
+        self.name = name
+        self.producer = producer
+        self.consumer = consumer
+
+    def __str__(self) -> str:
+        "usual"
+        return (
+            f"<< {self.name} - {self.consumer}"
+            if self.producer is None
+            else f">> {self.name} - {self.producer}"
+        )
+
+
+def enumerate_results(
+    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
+    name: Union[Set[str], str],
+    verbose: int = 0,
+    coordinates: Optional[List[Tuple[int, str, str]]] = None,
+) -> Iterator[ResultFound]:
+    """
+    Iterates over all nodes and attributes to find where a name is used.
+
+    :param proto: a proto
+    :param name: name or names to find
+    :param verbose: verbosity
+    :param coordinates: coordinates of a node
+    :return: iterator on :class:`ResultFound`
+    """
+    if not isinstance(name, set):
+        name = {name}
+    coordinates = coordinates or []
+    assert all(
+        isinstance(c, tuple) for c in coordinates
+    ), f"Unexpected type in coordinates={coordinates}"
+    indent = " " * len(coordinates)
+    if isinstance(proto, ModelProto):
+        if verbose:
+            print(f"[enumerate_results] {indent}searching for {name!r} into ModelProto...")
+        yield from enumerate_results(proto.graph, name, verbose=verbose)
+    elif isinstance(proto, FunctionProto):
+        if verbose:
+            print(f"[enumerate_results] {indent}searching for {name!r} into FunctionProto...")
+        for i in proto.input:
+            if i in name:
+                r = ResultFound(
+                    i,
+                    NodeCoordinates(i, tuple([*coordinates, (-1, "INPUT", "")])),  # noqa: C409
+                    None,
+                )
+                if verbose > 1:
+                    print(f"[enumerate_results] {indent}-- {r}")
+                yield r
+        yield from enumerate_results(proto.node, name, verbose=verbose)
+        for i in proto.output:
+            if i in name:
+                r = ResultFound(
+                    i,
+                    None,
+                    NodeCoordinates(
+                        i, tuple([*coordinates, (len(proto.node), "OUTPUT", "")])  # noqa: C409
+                    ),
+                )
+                if verbose > 1:
+                    print(f"[enumerate_results] {indent}-- {r}")
+                yield r
+    elif isinstance(proto, GraphProto):
+        if verbose:
+            print(f"[enumerate_results] {indent}searching for {name!r} into GraphProto...")
+        for i in proto.initializer:
+            if i.name in name:
+                r = ResultFound(
+                    i.name,
+                    NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])),  # noqa: C409
+                    None,
+                )
+                if verbose > 1:
+                    print(f"[enumerate_results] {indent}-- {r}")
+                yield r
+        for i in proto.sparse_initializer:
+            if i.name in name:
+                r = ResultFound(
+                    i.name,
+                    NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])),  # noqa: C409
+                    None,
+                )
+                if verbose > 1:
+                    print(f"[enumerate_results] {indent}-- {r}")
+                yield r
+        for i in proto.input:
+            if i.name in name:
+                r = ResultFound(
+                    i.name,
+                    NodeCoordinates(i, tuple([*coordinates, (-1, "INPUT", "")])),  # noqa: C409
+                    None,
+                )
+                if verbose > 1:
+                    print(f"[enumerate_results] {indent}-- {r}")
+                yield r
+        yield from enumerate_results(
+            proto.node, name, verbose=verbose, coordinates=coordinates
+        )
+        for i in proto.output:
+            if i.name in name:
+                r = ResultFound(
+                    i.name,
+                    None,
+                    NodeCoordinates(
+                        i, tuple([*coordinates, (len(proto.node), "OUTPUT", "")])  # noqa: C409
+                    ),
+                )
+                if verbose > 1:
+                    print(f"[enumerate_results] {indent}-- {r}")
+                yield r
+    else:
+        if verbose:
+            print(
+                f"[enumerate_results] {indent}searching for {name!r} into List[NodeProto]..."
+            )
+        for node_i, node in enumerate(proto):
+            if set(node.input) & name:
+                for n in node.input:
+                    if n in name:
+                        r = ResultFound(
+                            n,
+                            NodeCoordinates(
+                                node,
+                                tuple(  # noqa: C409
+                                    [*coordinates, (node_i, node.op_type, node.name)]
+                                ),
+                            ),
+                            None,
+                        )
+                        if verbose > 1:
+                            print(f"[enumerate_results] {indent}-- {r}")
+                        yield r
+            if node.op_type in {"If", "Scan", "Loop", "SequenceMap"}:
+                for att in node.attribute:
+                    if att.type == onnx.AttributeProto.GRAPH:
+                        yield from enumerate_results(
+                            att.g,
+                            name,
+                            verbose=verbose,
+                            coordinates=[*coordinates, (node_i, node.op_type, node.name)],
+                        )
+            if set(node.output) & name:
+                for n in node.output:
+                    if n in name:
+                        r = ResultFound(
+                            n,
+                            None,
+                            NodeCoordinates(
+                                node,
+                                tuple(  # noqa: C409
+                                    [*coordinates, (node_i, node.op_type, node.name)]
+                                ),
+                            ),
+                        )
+                        if verbose > 1:
+                            print(f"[enumerate_results] {indent}-- {r}")
+                        yield r
+    if verbose:
+        print(f"[enumerate_results] {indent}done")
+
+
+def shadowing_names(
+    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
+    verbose: int = 0,
+    existing: Optional[Set[str]] = None,
+    shadow_context: Optional[Set[str]] = None,
+    post_shadow_context: Optional[Set[str]] = None,
+) -> Tuple[Set[str], Set[str], Set[str]]:
+    """
+    Returns the shadowing names, the names created in the main graph
+    after they were created in a subgraph, and the names created by the nodes.
+    """
+    if isinstance(proto, ModelProto):
+        return shadowing_names(proto.graph, verbose=verbose)
+    if isinstance(proto, GraphProto):
+        assert (
+            existing is None and shadow_context is None
+        ), "existing and shadow_context must be None when a GraphProto is given"
+        return shadowing_names(
+            proto.node,
+            verbose=verbose,
+            existing=set(i.name for i in proto.initializer)
+            | set(i.name for i in proto.sparse_initializer)
+            | set(i.name for i in proto.input if i.name),
+            shadow_context=set(),
+            post_shadow_context=set(),
+        )
+    if isinstance(proto, FunctionProto):
+        assert (
+            existing is None and shadow_context is None
+        ), "existing and shadow_context must be None when a FunctionProto is given"
+        return shadowing_names(
+            proto.node,
+            verbose=verbose,
+            existing=set(i for i in proto.input if i),
+            shadow_context=set(),
+            post_shadow_context=set(),
+        )
+
+    assert (
+        existing is not None and shadow_context is not None
+    ), "existing must not be None when a sequence of nodes is given"
+    shadow = set()
+    shadow_context = shadow_context.copy()
+    existing = existing.copy()
+    created = set()
+    post_shadow = set()
+    for node in proto:
+        not_empty = set(n for n in node.input if n)
+        intersection = not_empty & existing
+        assert len(intersection) == len(not_empty), (
+            f"One input in {not_empty}, node={pretty_onnx(node)} "
+            f"was not found in {existing}"
+        )
+        for att in node.attribute:
+            if att.type == AttributeProto.GRAPH:
+                g = att.g
+                shadow |= set(i.name for i in g.input) & shadow_context
+                shadow |= set(i.name for i in g.initializer) & shadow_context
+                shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
+                s, ps, c = shadowing_names(
+                    g.node, verbose=verbose, existing=existing, shadow_context=existing
+                )
+                shadow |= s
+                created |= c
+
+        not_empty = set(n for n in node.output if n)
+        post_shadow |= not_empty & created
+        shadow |= not_empty & shadow_context
+        existing |= not_empty
+        created |= not_empty
+    return shadow, post_shadow, created
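To see both new helpers at work, here is a self-contained sketch: each If branch re-creates a result named "y" that already exists in the main graph, which is exactly the situation the find command's SHADOW mode reports. Expected values are inferred from the code above, not from running the released wheel:

import onnx
import onnx.helper as oh
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import enumerate_results, shadowing_names

def _branch(graph_name: str) -> onnx.GraphProto:
    # The branch re-creates "y", which the main graph also produces.
    return oh.make_graph(
        [
            oh.make_node("Identity", ["x"], ["y"]),
            oh.make_node("Identity", ["y"], ["z"]),
        ],
        graph_name,
        [],
        [oh.make_tensor_value_info("z", TensorProto.FLOAT, [1])],
    )

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Identity", ["x"], ["y"]),
            oh.make_node(
                "If", ["cond"], ["res"],
                then_branch=_branch("then"), else_branch=_branch("else"),
            ),
        ],
        "main",
        [
            oh.make_tensor_value_info("cond", TensorProto.BOOL, []),
            oh.make_tensor_value_info("x", TensorProto.FLOAT, [1]),
        ],
        [oh.make_tensor_value_info("res", TensorProto.FLOAT, [1])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

shadow, post_shadow, created = shadowing_names(model)
print(shadow)       # expected: {'y'} - created in the branches and outside
print(post_shadow)  # expected: set()

for r in enumerate_results(model, name={"y"}):
    print(r)        # occurrences of "y", inside and outside the If subgraphs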