onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1200 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import warnings
|
|
6
|
+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
|
|
7
|
+
import numpy as np
|
|
8
|
+
import numpy.typing as npt
|
|
9
|
+
import onnx
|
|
10
|
+
import onnx.helper as oh
|
|
11
|
+
import onnx.numpy_helper as onh
|
|
12
|
+
from onnx import (
|
|
13
|
+
AttributeProto,
|
|
14
|
+
FunctionProto,
|
|
15
|
+
GraphProto,
|
|
16
|
+
ModelProto,
|
|
17
|
+
NodeProto,
|
|
18
|
+
TensorProto,
|
|
19
|
+
ValueInfoProto,
|
|
20
|
+
load as onnx_load,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _make_stat(init: TensorProto) -> Dict[str, float]:
    """
    Summarizes the content of a tensor with a few statistics.

    :param init: tensor
    :return: dictionary with keys ``mean``, ``std``, ``shape``, ``itype``,
        ``min``, ``max``
    """
    values = onh.to_array(init)
    stat = {
        "mean": float(values.mean()),
        "std": float(values.std()),
        "shape": values.shape,
        "itype": np_dtype_to_tensor_dtype(values.dtype),
        "min": float(values.min()),
        "max": float(values.max()),
    }
    return stat
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def onnx_lighten(
    onx: Union[str, ModelProto],
    verbose: int = 0,
) -> Tuple[ModelProto, Dict[str, Dict[str, float]]]:
    """
    Removes the big initializers of a model and replaces them
    by statistics stored in a dictionary.
    The operation can be reverted with
    :func:`onnx_diagnostic.helpers.onnx_helper.onnx_unlighten`.
    The model is modified inplace.

    :param onx: model or filename of a model
    :param verbose: verbosity
    :return: the lightened model, the statistics of every removed
        initializer indexed by its name
    """
    if not isinstance(onx, str):
        assert isinstance(onx, ModelProto), f"Unexpected type {type(onx)}"
        model = onx
    else:
        if verbose:
            print(f"[onnx_lighten] load {onx!r}")
        model = onnx.load(onx)

    # initializers holding more elements than this threshold are removed
    threshold = 2**12
    small = []
    removed = {}
    for tensor in model.graph.initializer:
        if np.prod(tensor.dims) <= threshold:
            small.append(tensor)
            continue
        stat = _make_stat(tensor)
        removed[tensor.name] = stat
        if verbose:
            print(f"[onnx_lighten] remove initializer {tensor.name!r} stat={stat}")

    del model.graph.initializer[:]
    model.graph.initializer.extend(small)
    return model, removed
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _get_tensor(min=None, max=None, mean=None, std=None, shape=None, itype=None):
    """
    Creates a random tensor following the given statistics.

    :param min: minimum value, 0 if not specified (uniform case only)
    :param max: maximum value, 0 if not specified (uniform case only)
    :param mean: expected mean of the values
    :param std: expected standard deviation of the values
    :param shape: shape of the tensor, mandatory
    :param itype: onnx element type, mandatory
    :return: numpy array

    The values are drawn from a uniform law over ``[min, max]`` if *mean* or
    *std* is missing or if ``max - min`` is close to 1, otherwise from a normal
    law of mean *mean* and standard deviation *std*.
    """
    assert itype is not None, "itype must be specified."
    assert shape is not None, "shape must be specified."
    dtype = tensor_dtype_to_np_dtype(itype)
    no_moments = mean is None or std is None
    unit_range = min is not None and max is not None and abs(max - min - 1) < 0.01
    if no_moments or unit_range:
        low = 0 if min is None else min
        high = 0 if max is None else max
        return (np.random.random(shape) * (high - low) + low).astype(dtype)
    # randn draws from N(0, 1): it is scaled and shifted to follow the
    # statistics, which were previously ignored.
    return (np.random.randn(*shape) * std + mean).astype(dtype)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def onnx_unlighten(
    onx: Union[str, ModelProto],
    stats: Optional[Dict[str, Dict[str, float]]] = None,
    verbose: int = 0,
) -> ModelProto:
    """
    Restores initializers into a model produced by function
    :func:`onnx_diagnostic.helpers.onnx_helper.onnx_lighten`.
    Every initializer is filled with random values built from its statistics.
    The model is modified inplace.

    :param onx: model or filename of a model
    :param stats: statistics, can be None if onx is a file,
        then it loads the file ``<filename>.stats``,
        it assumes it is json format
    :param verbose: verbosity
    :return: the model with the additional initializers
    """
    if not isinstance(onx, str):
        assert isinstance(onx, ModelProto), f"Unexpected type {type(onx)}"
        model = onx
        assert stats is not None, "stats is missing"
    else:
        if stats is None:
            stats_file = f"{onx}.stats"
            assert os.path.exists(stats_file), f"File {stats_file!r} is missing."
            if verbose:
                print(f"[onnx_unlighten] load {stats_file!r}")
            with open(stats_file, "r") as f:
                stats = json.load(f)
        if verbose:
            print(f"[onnx_unlighten] load {onx!r}")
        model = onnx.load(onx)

    restored = [
        from_array_extended(_get_tensor(**stat), name=name) for name, stat in stats.items()
    ]
    model.graph.initializer.extend(restored)
    return model
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _validate_graph(
    g: GraphProto,
    existing: Set[str],
    verbose: int = 0,
    watch: Optional[Set[str]] = None,
    path: Optional[Sequence[str]] = None,
):
    """
    Checks that every result used in a graph is defined before it is consumed
    and collects the objects related to the watched names.
    Subgraphs stored in node attributes are checked recursively.

    :param g: graph to check
    :param existing: names already defined when entering the graph,
        the set is updated inplace with the names this graph defines
    :param verbose: verbosity, a progress bar is displayed if tqdm is installed
    :param watch: names to look for
    :param path: location of the graph, used in messages, ``["root"]`` by default
    :return: initializers, inputs and nodes involving one of the watched names
    :raises AssertionError: if a node input or a graph output is not defined
    """
    found = []
    path = path or ["root"]
    set_init = {i.name for i in g.initializer}
    set_input = {i.name for i in g.input}
    existing |= set_init | set_input
    if watch and set_init & watch:
        if verbose:
            print(f"-- found init {set_init & watch} in {path}")
        found.extend([i for i in g.initializer if i.name in watch])
    if watch and set_input & watch:
        if verbose:
            print(f"-- found input {set_input & watch} in {path}")
        found.extend([i for i in g.input if i.name in watch])
    try:
        import tqdm

        loop = tqdm.tqdm(g.node) if verbose else g.node
    except ImportError:
        loop = g.node

    for node in loop:
        location = f"{node.op_type}[{node.name}]"
        # An empty name stands for an omitted optional input, it is not a result.
        needed = {name for name in node.input if name}
        ins = needed & existing
        if ins != needed:
            raise AssertionError(
                f"One input is missing from node.input={node.input}, "
                f"existing={ins}, path={'/'.join(path)}, "
                f"node: {location}"
            )
        if watch and ins & watch:
            if verbose:
                print(f"-- found input {ins & watch} in {'/'.join(path)}/{location}")
            found.append(node)
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                # A copy is given to the subgraph: the names it defines
                # are not visible from the outer graph.
                found.extend(
                    _validate_graph(
                        att.g,
                        existing.copy(),
                        watch=watch,
                        path=[*path, location],
                        verbose=verbose,
                    )
                )
        produced = {name for name in node.output if name}
        existing |= produced
        if watch and produced & watch:
            if verbose:
                print(f"-- found output {produced & watch} in {'/'.join(path)}/{location}")
            found.append(node)

    out = {o.name for o in g.output}
    ins = out & existing
    if ins != out:
        # The message reports the graph outputs, it used to display the inputs
        # of the last node and failed on a graph without any node.
        raise AssertionError(f"One output is missing, out={out}, existing={ins}, path={path}")
    return found
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[str]] = None):
    """
    Checks that every result used in a function is defined before it is consumed
    and collects the nodes related to the watched names.
    Subgraphs stored in node attributes are checked with ``_validate_graph``.

    :param g: function to check
    :param verbose: verbosity
    :param watch: names to look for
    :return: nodes consuming or producing one of the watched names
    :raises AssertionError: if a node input or a function output is not defined
    """
    existing = set(g.input)
    found = []
    for node in g.node:
        location = f"{node.op_type}[{node.name}]"
        # An empty name stands for an omitted optional input, it is not a result.
        needed = {name for name in node.input if name}
        ins = needed & existing
        if ins != needed:
            raise AssertionError(
                f"One input is missing from node.input={node.input}, existing={ins}"
            )
        if watch and ins & watch:
            if verbose:
                print(f"-- found input {ins & watch} in {location}")
            found.append(node)
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                # The subgraph is validated (att.g), not the function itself,
                # and the watched names are forwarded.
                found.extend(
                    _validate_graph(
                        att.g,
                        existing.copy(),
                        watch=watch,
                        path=[g.name, location],
                        verbose=verbose,
                    )
                )
        produced = {name for name in node.output if name}
        existing |= produced
        if watch and produced & watch:
            if verbose:
                print(f"-- found output {produced & watch} in {location}")
            # same behaviour as _validate_graph: producers are returned as well
            found.append(node)

    out = set(g.output)
    ins = out & existing
    if ins != out:
        raise AssertionError(f"One output is missing, out={out}, existing={ins}, path={g.name}")
    return found
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def onnx_find(
    onx: Union[str, ModelProto], verbose: int = 0, watch: Optional[Set[str]] = None
) -> List[Union[NodeProto, TensorProto]]:
    """
    Searches a model, main graph and local functions, for the nodes
    producing or consuming some results.

    :param onx: model or filename, a file is loaded without its external data
    :param verbose: verbosity
    :param watch: names to search for
    :return: list of nodes
    """
    model = onnx.load(onx, load_external_data=False) if isinstance(onx, str) else onx
    found = list(_validate_graph(model.graph, set(), verbose=verbose, watch=watch))
    for fct in model.functions:
        found += _validate_function(fct, watch=watch, verbose=verbose)
    if found and verbose:
        print(f"-- found {len(found)} nodes")
    return found
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def check_model_ort(
    onx: Union[str, ModelProto],
    providers: Optional[Union[str, List[Any]]] = None,
    dump_file: Optional[str] = None,
) -> "onnxruntime.InferenceSession":  # noqa: F821
    """
    Loads a model with onnxruntime.

    :param onx: ModelProto or filename of a model
    :param providers: list of providers, None for CPU, ``"cpu"`` for CPU,
        ``"cuda"`` or ``"cuda:<device_id>"`` for CUDA
    :param dump_file: if not empty, dumps the model into this file if
        an error happened
    :return: InferenceSession
    :raises AssertionError: if onnxruntime fails to load the model,
        the message includes a text rendering of the model
    """
    from onnxruntime import InferenceSession

    if providers is None or providers == "cpu":
        providers = ["CPUExecutionProvider"]
    elif not isinstance(providers, list) and providers.startswith("cuda"):
        device_id = 0 if ":" not in providers else int(providers.split(":")[1])
        providers = [
            ("CUDAExecutionProvider", {"device_id": device_id}),
            ("CPUExecutionProvider", {}),
        ]

    try:
        if isinstance(onx, str):
            return InferenceSession(onx, providers=providers)
        return InferenceSession(onx.SerializeToString(), providers=providers)
    except Exception as e:
        # No local ``import onnx`` here: it would turn ``onnx`` into a local
        # name and break every other use of the module in this function.
        # onnx.save needs a proto, a filename is loaded first.
        proto = onnx.load(onx) if isinstance(onx, str) else onx
        if dump_file:
            onnx.save(proto, dump_file)
        raise AssertionError(
            f"onnxruntime cannot load the model due to {e}\n{pretty_onnx(proto)}"
        ) from e
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@functools.cache
def onnx_dtype_name(itype: int, exc: bool = True) -> str:
    """
    Returns the ONNX name for a specific element type.

    :param itype: onnx element type
    :param exc: raises an exception if the type is undefined or unknown,
        otherwise returns ``"UNDEFINED"`` for 0 and ``"UNEXPECTED"`` for any
        other unknown value
    :return: name of the element type
    :raises ValueError: if *exc* is True and the type is undefined or unknown

    .. runpython::
        :showcode:

        import onnx
        from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

        itype = onnx.TensorProto.BFLOAT16
        print(onnx_dtype_name(itype))
        print(onnx_dtype_name(7))
    """
    if itype != TensorProto.UNDEFINED:
        # The enumeration is queried directly: filtering dir(TensorProto) on
        # substrings missed DOUBLE, STRING, COMPLEX64, COMPLEX128 and could
        # pick up unrelated constants such as *_FIELD_NUMBER.
        try:
            return TensorProto.DataType.Name(itype)
        except (TypeError, ValueError):
            pass
    if exc:
        raise ValueError(f"Unexpected value itype: {itype}")
    if itype == 0:
        return "UNDEFINED"
    return "UNEXPECTED"
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def pretty_onnx(
    onx: Union[FunctionProto, GraphProto, ModelProto, ValueInfoProto, str],
    with_attributes: bool = False,
    highlight: Optional[Set[str]] = None,
    shape_inference: bool = False,
) -> str:
    """
    Renders an onnx proto as a readable text.

    :param onx: proto to display, a string is interpreted as a filename
        and loaded without its external data
    :param with_attributes: displays attributes as well, if only a node is printed
    :param highlight: to highlight some names
    :param shape_inference: run shape inference before printing the model
    :return: text
    """
    assert onx is not None, "onx cannot be None"
    if isinstance(onx, str):
        onx = onnx_load(onx, load_external_data=False)
    assert onx is not None, "onx cannot be None"

    if shape_inference:
        onx = onnx.shape_inference.infer_shapes(onx)

    if isinstance(onx, ValueInfoProto):
        tensor_type = onx.type.tensor_type
        dims = [str(d.dim_param or d.dim_value) for d in tensor_type.shape.dim]
        type_name = onnx_dtype_name(tensor_type.elem_type, exc=False)
        return f"{type_name}[{','.join(dims)}] {onx.name}"

    if isinstance(onx, AttributeProto):
        att = onx
        kind = att.type
        # attribute types displayed with the raw content of one field
        plain_fields = {
            AttributeProto.INT: "i",
            AttributeProto.INTS: "ints",
            AttributeProto.FLOAT: "f",
            AttributeProto.FLOATS: "floats",
        }
        if kind in plain_fields:
            return f"{att.name}={getattr(att, plain_fields[kind])}"
        if kind == AttributeProto.STRING:
            return f"{att.name}={att.s!r}"
        if kind == AttributeProto.TENSOR:
            array = to_array_extended(att.t)
            assert hasattr(array, "reshape"), f"not a tensor {type(array)}"
            assert hasattr(array, "shape"), f"not a tensor {type(array)}"
            flat = array.reshape((-1,))
            # only the first ten values are displayed
            ellipsis = "" if flat.size < 10 else ", ..."
            tt = f"[{', '.join(map(str, flat[:10]))}{ellipsis}]"
            reshape = "" if len(array.shape) == 1 else f".reshape({array.shape})"
            return f"{att.name}=tensor({tt}, dtype={array.dtype}){reshape}"
        raise NotImplementedError(
            f"pretty_onnx not implemented yet for AttributeProto={att!r}"
        )

    if isinstance(onx, NodeProto):

        def _high(n):
            return f"**{n}**" if highlight and n in highlight else n

        inputs = ", ".join(map(_high, onx.input))
        outputs = ", ".join(map(_high, onx.output))
        text = f"{onx.op_type}({inputs}) -> {outputs}"
        if onx.domain:
            text = f"{onx.domain}.{text}"
        if not (with_attributes and onx.attribute):
            return text
        rows = [pretty_onnx(att) for att in onx.attribute]
        if len(rows) <= 1:
            return f"{text} --- {rows[0]}"
        suffix = "\n".join(f" {s}" for s in rows)
        return f"{text}\n{suffix}"

    if isinstance(onx, TensorProto):
        dims = "x".join(map(str, onx.dims))
        return f"TensorProto:{onx.data_type}:{dims}:{onx.name}"

    # falls back on onnx own printer if onnx_array_api is not available
    try:
        from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

        if not isinstance(onx, FunctionProto):
            return onnx_simple_text_plot(onx, recursive=True)
        return (
            f"function: {onx.name}[{onx.domain}]\n"
            f"{onnx_simple_text_plot(onx, recursive=True)}"
        )
    except ImportError:
        from onnx.printer import to_text

        return to_text(onx)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def get_onnx_signature(model: ModelProto) -> Tuple[Tuple[str, Any], ...]:
    """
    Describes the inputs of a model with a tuple of tuples.
    A tensor input becomes ``(name, element type, shape)``,
    a sequence input becomes ``(name, [(name, element type, shape)])``.

    :param model: model
    :return: signature
    """

    def _dims(tensor_type):
        # a dimension is either a name (dynamic) or a value
        return tuple(d.dim_param or d.dim_value for d in tensor_type.shape.dim)

    sig: List[Any] = []
    for inp in model.graph.input:
        kind = inp.type
        if kind.HasField("sequence_type"):
            elem = kind.sequence_type.elem_type.tensor_type
            sig.append((inp.name, [(inp.name, elem.elem_type, _dims(elem))]))
            continue
        if kind.HasField("tensor_type"):
            tensor_type = kind.tensor_type
            sig.append((inp.name, tensor_type.elem_type, _dims(tensor_type)))
            continue
        raise AssertionError(f"Unable to interpret dt={kind!r} in {inp!r}")
    return tuple(sig)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def convert_endian(tensor: TensorProto) -> None:
    """Swaps the byte order of the raw data of a tensor, inplace.

    Args:
        tensor: TensorProto to be converted.
    """
    np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
    swapped = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap()
    tensor.raw_data = swapped.tobytes()
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
    """
    Builds a TensorProto from a numpy array whose dtype
    is one of the types defined in ml_dtypes.

    Args:
        arr: a numpy array.
        name: (optional) the name of the tensor.

    Returns:
        TensorProto: the converted tensor def.
    """
    import ml_dtypes

    assert isinstance(arr, np.ndarray), f"arr must be of type numpy.ndarray, got {type(arr)}"

    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name

    # (ml_dtypes attribute, TensorProto attribute), resolved one at a time
    conversions = (
        ("bfloat16", "BFLOAT16"),
        ("float8_e4m3fn", "FLOAT8E4M3FN"),
        ("float8_e4m3fnuz", "FLOAT8E4M3FNUZ"),
        ("float8_e5m2", "FLOAT8E5M2"),
        ("float8_e5m2fnuz", "FLOAT8E5M2FNUZ"),
    )
    for np_name, onnx_name in conversions:
        if arr.dtype == getattr(ml_dtypes, np_name):
            tensor.data_type = getattr(TensorProto, onnx_name)
            break
    else:
        raise NotImplementedError(f"No conversion from {arr.dtype}")

    tensor.raw_data = arr.tobytes()  # note: tobytes() is only after 1.9.
    if sys.byteorder == "big":
        convert_endian(tensor)
    return tensor
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
# Maps the 16-bit float element types to np.int16.
# NOTE(review): presumably the integer type used to store their raw bits,
# the code using this table is not part of this excerpt -- confirm.
_STORAGE_TYPE = dict.fromkeys((TensorProto.FLOAT16, TensorProto.BFLOAT16), np.int16)
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
    """
    Builds a :class:`onnx.TensorProto` from an array, including the
    bfloat16 and float 8 types.

    :param tensor: numpy array or torch tensor
    :param name: name
    :return: TensorProto
    """
    if not isinstance(tensor, np.ndarray):
        import torch
        from .torch_helper import proto_from_tensor

        assert isinstance(
            tensor, torch.Tensor
        ), f"Unable to convert type {type(tensor)} into TensorProto."
        return proto_from_tensor(tensor, name=name)

    try:
        from onnx.reference.ops.op_cast import (
            bfloat16,
            float8e4m3fn,
            float8e4m3fnuz,
            float8e5m2,
            float8e5m2fnuz,
        )
    except ImportError:
        bfloat16 = None

    if bfloat16 is None:
        # onnx does not expose its custom types, numpy_helper handles the array
        return onh.from_array(tensor, name)

    dt = tensor.dtype
    # (custom dtype, name in its description, onnx type, storage type)
    custom_types = (
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN, np.uint8),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ, np.uint8),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2, np.uint8),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ, np.uint8),
        (bfloat16, "bfloat16", TensorProto.BFLOAT16, np.uint16),
    )
    for custom, tag, onnx_type, storage in custom_types:
        if dt == custom and dt.descr[0][0] == tag:
            proto = onh.from_array(tensor.astype(storage), name)
            proto.data_type = onnx_type
            return proto

    try:
        import ml_dtypes
    except ImportError:
        ml_dtypes = None
    ml_names = (
        "bfloat16",
        "float8_e4m3fn",
        "float8_e4m3fnuz",
        "float8_e5m2",
        "float8_e5m2fnuz",
    )
    if ml_dtypes is not None and any(
        tensor.dtype == getattr(ml_dtypes, ml_name) for ml_name in ml_names
    ):
        return from_array_ml_dtypes(tensor, name)
    return onh.from_array(tensor, name)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
    """
    Converts :class:`onnx.TensorProto` into a numpy array.
    Element types from BFLOAT16 onward are viewed with the dtype
    returned by ``onnx_dtype_to_np_dtype``.
    """
    arr = onh.to_array(proto)
    if proto.data_type < onnx.TensorProto.BFLOAT16:
        return arr
    # Types not supported by numpy
    return arr.view(onnx_dtype_to_np_dtype(proto.data_type))
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def onnx_dtype_to_np_dtype(itype: int) -> Any:
    """
    Converts an onnx type into a numpy dtype.
    That includes :epkg:`ml_dtypes` dtypes (bfloat16, float 8),
    ml_dtypes is only imported when one of them is requested.

    :param itype: onnx dtype
    :return: numpy dtype
    :raises NotImplementedError: if the type has no numpy equivalent
    """
    numpy_types = {
        TensorProto.FLOAT: np.float32,
        TensorProto.FLOAT16: np.float16,
        TensorProto.DOUBLE: np.float64,
        TensorProto.INT32: np.int32,
        TensorProto.INT64: np.int64,
        TensorProto.UINT32: np.uint32,
        TensorProto.UINT64: np.uint64,
        # np.bool_ exists in every numpy version, np.bool does not
        TensorProto.BOOL: np.bool_,
        TensorProto.INT16: np.int16,
        TensorProto.UINT16: np.uint16,
        # 8-bit integers used to be mapped to 16-bit types
        TensorProto.INT8: np.int8,
        TensorProto.UINT8: np.uint8,
        TensorProto.COMPLEX64: np.complex64,
        TensorProto.COMPLEX128: np.complex128,
    }
    if itype in numpy_types:
        return numpy_types[itype]

    ml_names = {
        TensorProto.BFLOAT16: "bfloat16",
        TensorProto.FLOAT8E4M3FN: "float8_e4m3fn",
        TensorProto.FLOAT8E4M3FNUZ: "float8_e4m3fnuz",
        TensorProto.FLOAT8E5M2: "float8_e5m2",
        TensorProto.FLOAT8E5M2FNUZ: "float8_e5m2fnuz",
    }
    if itype in ml_names:
        import ml_dtypes

        return getattr(ml_dtypes, ml_names[itype])

    # exc=False: an unknown type must raise NotImplementedError,
    # not a ValueError coming from onnx_dtype_name
    raise NotImplementedError(
        f"Unable to convert onnx type {onnx_dtype_name(itype, exc=False)} "
        f"to a numpy dtype."
    )
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F821
    """
    Converts a torch dtype or numpy dtype into a onnx element type.

    The numpy conversion is tried first. If it fails with ``KeyError``,
    ``TypeError`` or ``ValueError``, the dtype is handed over to the torch
    conversion.

    :param dt: numpy or torch dtype
    :return: onnx type
    """
    try:
        onnx_type = np_dtype_to_tensor_dtype(dt)
    except (KeyError, TypeError, ValueError):
        # Not something the numpy conversion understands,
        # the torch conversion below takes over.
        pass
    else:
        return onnx_type

    # Imported lazily, only when the numpy conversion failed.
    from .torch_helper import torch_dtype_to_onnx_dtype as torch_conversion

    return torch_conversion(dt)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def np_dtype_to_tensor_dtype(dt: np.dtype) -> int:  # noqa: F821
    """
    Converts a numpy dtype into a onnx element type.

    ``onnx.helper.np_dtype_to_tensor_dtype`` is tried first. If it raises
    ``ValueError``, the dtype is compared to the :epkg:`ml_dtypes` types
    (when the package can be imported) and then to the usual numpy types.

    :param dt: numpy dtype
    :return: onnx type
    :raises ValueError: if no onnx type matches
    """
    try:
        return oh.np_dtype_to_tensor_dtype(dt)
    except ValueError:
        # onnx does not know this dtype. ml_dtypes is optional.
        try:
            import ml_dtypes as extended
        except ImportError:
            extended = None  # type: ignore

    # Attributes are resolved one at a time (getattr) so that a type missing
    # from the installed package only matters if the comparison reaches it.
    if extended is not None:
        for attribute, onnx_name in (
            ("bfloat16", "BFLOAT16"),
            ("float8_e4m3fn", "FLOAT8E4M3FN"),
            ("float8_e4m3fnuz", "FLOAT8E4M3FNUZ"),
            ("float8_e5m2", "FLOAT8E5M2"),
            ("float8_e5m2fnuz", "FLOAT8E5M2FNUZ"),
        ):
            if dt == getattr(extended, attribute):
                return getattr(TensorProto, onnx_name)

    for attribute, onnx_name in (
        ("float32", "FLOAT"),
        ("float16", "FLOAT16"),
        ("float64", "DOUBLE"),
        ("int64", "INT64"),
        ("uint64", "UINT64"),
        ("int16", "INT16"),
        ("uint16", "UINT16"),
        ("int32", "INT32"),
        ("int8", "INT8"),
        ("uint8", "UINT8"),
        ("uint32", "UINT32"),
        ("bool", "BOOL"),
        ("complex64", "COMPLEX64"),
        ("complex128", "COMPLEX128"),
    ):
        if dt == getattr(np, attribute):
            return getattr(TensorProto, onnx_name)
    raise ValueError(f"Unable to convert type {dt}")
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
def type_info(itype: int, att: str):
    """
    Returns the minimum or maximum value for a type.

    ``numpy.finfo`` is used for FLOAT, FLOAT16 and DOUBLE,
    ``ml_dtypes.finfo`` for BFLOAT16 and ``numpy.iinfo`` for every other type.

    :param itype: onnx type
    :param att: 'min' or 'max'
    :return: value
    :raises ValueError: if *att* is neither 'min' nor 'max'
    """
    if itype in {TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE}:
        describe = np.finfo
    elif itype == TensorProto.BFLOAT16:
        import ml_dtypes

        describe = ml_dtypes.finfo  # type: ignore
    else:
        describe = np.iinfo  # type: ignore

    limits = describe(tensor_dtype_to_np_dtype(itype))
    if att not in ("min", "max"):
        raise ValueError(f"Unexpected value {att!r}")
    return getattr(limits, att)
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """
    Converts a TensorProto's data_type to corresponding numpy dtype.
    It can be used while making tensor.

    Values below 16 are delegated to ``onnx.helper.tensor_dtype_to_np_dtype``.
    Values from 16 upwards are looked up among the :epkg:`ml_dtypes` types
    (bfloat16 and the float 8 variants).

    :param tensor_dtype: TensorProto's data_type
    :return: numpy's data_type
    :raises ValueError: if :epkg:`ml_dtypes` is needed but cannot be imported
    """
    if tensor_dtype < 16:
        return oh.tensor_dtype_to_np_dtype(tensor_dtype)

    try:
        import ml_dtypes  # noqa: F401
    except ImportError as e:
        raise ValueError(
            f"Unsupported value for tensor_dtype, "
            f"numpy does not support onnx type {tensor_dtype}. "
            f"ml_dtypes can be used."
        ) from e

    # Only these types are supported above the threshold.
    mapping: Dict[int, np.dtype] = {
        TensorProto.BFLOAT16: ml_dtypes.bfloat16,
        TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn,
        TensorProto.FLOAT8E4M3FNUZ: ml_dtypes.float8_e4m3fnuz,
        TensorProto.FLOAT8E5M2: ml_dtypes.float8_e5m2,
        TensorProto.FLOAT8E5M2FNUZ: ml_dtypes.float8_e5m2fnuz,
    }
    assert (
        tensor_dtype in mapping
    ), f"Unable to find tensor_dtype={tensor_dtype!r} in mapping={mapping}"
    return mapping[tensor_dtype]
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
def iterator_initializer_constant(
    model: Union[FunctionProto, GraphProto, ModelProto],
    use_numpy: bool = True,
    prefix: str = "",
) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]:  # noqa: F821
    """
    Iterates on initializers and constants in an onnx model.

    The function yields, in this order, the initializers of the graph,
    the content of the local functions (for a ModelProto), then for every node
    the value of ``Constant`` nodes of the default domain and the content of
    the subgraphs attached to control flow nodes.

    :param model: model
    :param use_numpy: use numpy or pytorch
    :param prefix: for subgraph
    :return: iterator on pairs ``(name, value)``
    """
    if not isinstance(model, FunctionProto):
        graph = model if isinstance(model, GraphProto) else model.graph
        if not use_numpy:
            # torch is only needed when tensors are requested
            from .torch_helper import to_tensor
        if prefix:
            prefix += "."
        for init in graph.initializer:
            yield f"{prefix}{init.name}", (
                to_array_extended(init) if use_numpy else to_tensor(init)
            )
        nodes = graph.node
        name = graph.name
        if isinstance(model, ModelProto):
            for f in model.functions:
                yield from iterator_initializer_constant(
                    f, use_numpy=use_numpy, prefix=f"{prefix}{f.name}"
                )
    else:
        # A function has no initializer, only nodes.
        nodes = model.node
        name = model.name

    for node in nodes:
        if node.op_type == "Constant" and node.domain == "":
            from ..reference import ExtendedReferenceEvaluator as Inference

            if not use_numpy:
                import torch
            # The node is evaluated alone to retrieve the constant value.
            sess = Inference(node)
            value = sess.run(None, {})[0]
            yield f"{prefix}{node.output[0]}", (
                value if use_numpy else torch.from_numpy(value)
            )

        # "Body" is kept for backward compatibility, "If" and "SequenceMap"
        # also hold subgraphs and were previously skipped.
        if node.op_type in {"Loop", "Body", "Scan", "If", "SequenceMap"}:
            for att in node.attribute:
                assert (
                    att.type != onnx.AttributeProto.GRAPHS
                ), "Not implemented for type AttributeProto.GRAPHS."
                if att.type == onnx.AttributeProto.GRAPH:
                    yield from iterator_initializer_constant(
                        att.g, use_numpy=use_numpy, prefix=f"{prefix}{name}"
                    )
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union[float, str]]:
    """
    Produces statistics on a tensor.

    The result always contains ``mean``, ``std``, ``shape``, ``numel``,
    ``size``, ``itype``, ``stype``, ``min``, ``max``, ``nnan``.
    When the tensor holds at least 8 elements, it also contains
    the number of elements whose absolute value is greater than a series of
    thresholds (keys ``>threshold``) and the deciles (keys ``q0.1`` to ``q0.9``).

    :param tensor: tensor
    :return: statistics

    .. runpython::
        :showcode:

        import pprint
        import numpy as np
        from onnx_diagnostic.helpers.onnx_helper import tensor_statistics

        t = np.random.rand(40, 50).astype(np.float16)
        pprint.pprint(tensor_statistics(t))
    """
    from .helper import size_type

    if isinstance(tensor, TensorProto):
        tensor = to_array_extended(tensor)
    itype = np_dtype_to_tensor_dtype(tensor.dtype)
    stat = dict(
        mean=float(tensor.mean()),
        std=float(tensor.std()),
        shape="x".join(map(str, tensor.shape)),
        numel=tensor.size,
        size=tensor.size * size_type(tensor.dtype),
        itype=itype,
        stype=onnx_dtype_name(itype),
        min=float(tensor.min()),
        max=float(tensor.max()),
        nnan=float(np.isnan(tensor).sum()),
    )

    if tensor.size < 8:
        # Too few values for an histogram or quantiles.
        return stat

    with warnings.catch_warnings():
        # Warnings raised while casting the thresholds into
        # the tensor dtype are silenced.
        warnings.simplefilter("ignore")
        try:
            hist = np.array(
                [
                    0,
                    1e-10,
                    1e-8,
                    1e-7,
                    1e-6,
                    1e-5,
                    0.0001,
                    0.001,
                    0.01,
                    0.1,
                    0.5,
                    1,
                    1.96,
                    10,
                    1e2,
                    1e3,
                    1e4,
                    1e5,
                    1e6,
                    1e7,
                    1e8,
                    1e10,
                    1e50,
                ],
                dtype=tensor.dtype,
            )
        except OverflowError as e:
            from .helper import string_type

            raise ValueError(
                f"Unable to convert one value into {tensor.dtype}, "
                f"tensor={string_type(tensor, with_shape=True)}"
            ) from e
    # Thresholds becoming infinite or equal after the cast are dropped.
    hist = np.array(sorted(set(hist[~np.isinf(hist)])), dtype=tensor.dtype)
    ind = np.digitize(np.abs(tensor).reshape((-1,)), hist, right=True)
    # digitize returns indices in [0, len(hist)]: the number of bins depends
    # on the thresholds, not on the number of elements in the tensor.
    # This guarantees one count per threshold without allocating an array
    # as big as the tensor.
    cou = np.bincount(ind, minlength=hist.shape[0] + 1)
    # cou.sum() - cumsum(cou)[k] is the number of values above hist[k]
    stat.update(
        dict(zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))]))
    )
    ii = (np.arange(9) + 1) / 10
    qu = np.quantile(tensor, ii)
    stat.update({f"q{i}": float(q) for i, q in zip(ii, qu)})
    return stat
|
|
921
|
+
|
|
922
|
+
|
|
923
|
+
class NodeCoordinates:
    """
    Localizes a node (or a tensor, or a name) in a model.
    *path* is a tuple of steps, every step is itself a tuple of
    three pieces of information: node index, node type, node name.
    """

    __slots__ = ("node", "path")

    def __init__(
        self,
        node: Union[onnx.TensorProto, NodeProto, str],
        path: Tuple[Tuple[int, str, str], ...],
    ):
        assert isinstance(path, tuple), f"Unexpected type {type(path)} for path"
        for step in path:
            assert isinstance(step, tuple), f"Unexpected type in path={path}"
        self.path = path
        self.node = node

    def __str__(self) -> str:
        "Returns the path followed by a description of the node."
        location = self.path_to_str()
        if isinstance(self.node, str):
            description = repr(self.node)
        else:
            description = pretty_onnx(self.node)
        return f"{location} :: {description}"

    def path_to_str(self) -> str:
        "Returns the coordinates as a string, steps are separated by ``x``."
        steps = ["(" + ":".join(str(item) for item in step) + ")" for step in self.path]
        return "x".join(steps)
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
class ResultFound:
    """
    Class returned by :func:`enumerate_results`.
    It stores a result name and the coordinates of its producer
    or its consumer.
    """

    __slots__ = ("consumer", "name", "producer")

    def __init__(
        self,
        name: str,
        producer: Optional[NodeCoordinates],
        consumer: Optional[NodeCoordinates],
    ):
        assert isinstance(name, str), f"unexpected type {type(name)} for name"
        self.name, self.producer, self.consumer = name, producer, consumer

    def __str__(self) -> str:
        "Shows the consumer when there is no producer, the producer otherwise."
        if self.producer is not None:
            return f">> {self.name} - {self.producer}"
        return f"<< {self.name} - {self.consumer}"
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
def enumerate_results(
    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
    name: Union[Set[str], str],
    verbose: int = 0,
    coordinates: Optional[List[Tuple[int, str, str]]] = None,
) -> Iterator[ResultFound]:
    """
    Iterates on all nodes, attributes to find where a name is used.

    Graph or function inputs, initializers and node inputs are returned in
    the *producer* slot of :class:`ResultFound`, graph or function outputs
    and node outputs in the *consumer* slot. Subgraphs of nodes
    ``If``, ``Scan``, ``Loop``, ``SequenceMap`` are explored as well.

    :param proto: a proto
    :param name: name or names to find
    :param verbose: verbosity
    :param coordinates: coordinates of a node
    :return: iterator on :class:`ResultFound`
    """
    if not isinstance(name, set):
        name = {name}
    coordinates = coordinates or []
    assert all(
        isinstance(c, tuple) for c in coordinates
    ), f"Unexpected type in coordinates={coordinates}"
    indent = " " * len(coordinates)

    def located(obj, step):
        # Coordinates of *obj*: the current path extended with one step.
        return NodeCoordinates(obj, (*coordinates, step))

    def found(result_name, producer, consumer):
        # Builds one result and displays it if the verbosity requires it.
        r = ResultFound(result_name, producer, consumer)
        if verbose > 1:
            print(f"[enumerate_results] {indent}-- {r}")
        return r

    if isinstance(proto, ModelProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into ModelProto...")
        yield from enumerate_results(
            proto.graph, name, verbose=verbose, coordinates=coordinates
        )
    elif isinstance(proto, FunctionProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into FunctionProto...")
        # Inputs and outputs of a function are plain strings.
        for i in proto.input:
            if i in name:
                yield found(i, located(i, (-1, "INPUT", "")), None)
        yield from enumerate_results(
            proto.node, name, verbose=verbose, coordinates=coordinates
        )
        for i in proto.output:
            if i in name:
                yield found(i, None, located(i, (len(proto.node), "OUTPUT", "")))
    elif isinstance(proto, GraphProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into GraphProto...")
        for i in proto.initializer:
            if i.name in name:
                yield found(i.name, located(i, (-1, "INIT", "")), None)
        for i in proto.sparse_initializer:
            # A SparseTensorProto has no field name,
            # its name is stored in its values.
            sparse_name = i.values.name
            if sparse_name in name:
                yield found(sparse_name, located(i, (-1, "INIT", "")), None)
        for i in proto.input:
            if i.name in name:
                yield found(i.name, located(i, (-1, "INPUT", "")), None)
        yield from enumerate_results(
            proto.node, name, verbose=verbose, coordinates=coordinates
        )
        for i in proto.output:
            if i.name in name:
                yield found(i.name, None, located(i, (len(proto.node), "OUTPUT", "")))
    else:
        if verbose:
            print(
                f"[enumerate_results] {indent}searching for {name!r} into List[NodeProto]..."
            )
        for node_i, node in enumerate(proto):
            step = (node_i, node.op_type, node.name)
            # One result per occurrence of a searched name among the inputs.
            for n in node.input:
                if n in name:
                    yield found(n, located(node, step), None)
            if node.op_type in {"If", "Scan", "Loop", "SequenceMap"}:
                for att in node.attribute:
                    if att.type == onnx.AttributeProto.GRAPH:
                        yield from enumerate_results(
                            att.g,
                            name,
                            verbose=verbose,
                            coordinates=[*coordinates, step],
                        )
            for n in node.output:
                if n in name:
                    yield found(n, None, located(node, step))
    if verbose:
        print(f"[enumerate_results] {indent}done")
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
def shadowing_names(
    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
    verbose: int = 0,
    existing: Optional[Set[str]] = None,
    shadow_context: Optional[Set[str]] = None,
    post_shadow_context: Optional[Set[str]] = None,
) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Returns the shadowing names, the names created in the main graph
    after they were created in a subgraphs and the names created by the nodes.

    :param proto: model, graph, function or sequence of nodes
    :param verbose: verbosity, forwarded to the recursive calls
    :param existing: names already known before the first node, must be None
        for a model, a graph or a function, must be specified for a sequence
        of nodes
    :param shadow_context: names a result would shadow if it reuses them,
        same constraints as *existing*
    :param post_shadow_context: not used by the function
    :return: three sets, shadowing names, names created by a node after
        a subgraph created them, all the names created by the nodes
    """
    if isinstance(proto, ModelProto):
        return shadowing_names(proto.graph, verbose=verbose)
    if isinstance(proto, GraphProto):
        assert (
            existing is None and shadow_context is None
        ), "existing must be None if nodes is None"
        return shadowing_names(
            proto.node,
            verbose=verbose,
            # A SparseTensorProto has no field name,
            # its name is stored in its values.
            existing={i.name for i in proto.initializer}
            | {i.values.name for i in proto.sparse_initializer}
            | {i.name for i in proto.input if i.name},
            shadow_context=set(),
            post_shadow_context=set(),
        )
    if isinstance(proto, FunctionProto):
        assert (
            existing is None and shadow_context is None
        ), "existing must be None if nodes is None"
        return shadowing_names(
            proto.node,
            verbose=verbose,
            existing={i for i in proto.input if i},
            shadow_context=set(),
            post_shadow_context=set(),
        )

    assert (
        existing is not None and shadow_context is not None
    ), "existing must not be None if nodes is not None"
    shadow = set()
    # Copies, the sets given by the caller are never modified.
    shadow_context = shadow_context.copy()
    existing = existing.copy()
    created = set()
    post_shadow = set()
    for node in proto:
        not_empty = {n for n in node.input if n}
        intersection = not_empty & existing
        assert len(intersection) == len(not_empty), (
            f"One input in {not_empty}, node={pretty_onnx(node)} "
            f"was not found in {existing}"
        )
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                g = att.g
                # Names defined by the subgraph itself.
                local_names = (
                    {i.name for i in g.input}
                    | {i.name for i in g.initializer}
                    | {i.values.name for i in g.sparse_initializer}
                )
                shadow |= local_names & shadow_context
                # The nodes of the subgraph can use the names of the outer
                # scope and the names the subgraph defines (inputs,
                # initializers). Only the outer names can be shadowed.
                s, _ps, c = shadowing_names(
                    g.node,
                    verbose=verbose,
                    existing=existing | local_names,
                    shadow_context=existing,
                )
                shadow |= s
                created |= c

        not_empty = {n for n in node.output if n}
        # Names a subgraph already created and this node creates again.
        post_shadow |= not_empty & created
        shadow |= not_empty & shadow_context
        existing |= not_empty
        created |= not_empty
    return shadow, post_shadow, created
|