onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,535 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import textwrap
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
from typing import Self
|
|
8
|
+
except ImportError:
|
|
9
|
+
# python <= 3.10
|
|
10
|
+
Self = "Self" # type: ignore[assignment]
|
|
11
|
+
import onnx
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
from ..helpers.onnx_helper import (
|
|
15
|
+
extract_subset_of_nodes,
|
|
16
|
+
make_submodel,
|
|
17
|
+
from_array_extended,
|
|
18
|
+
select_model_inputs_outputs,
|
|
19
|
+
)
|
|
20
|
+
from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def make_torch_inputs(
    input_names: List[str],
    onnx_name_to_ep_name: Dict[str, str],
    onnx_results: Dict[str, torch.Tensor],
    torch_results: Dict[str, torch.Tensor],
    submodel: Optional[onnx.ModelProto],
) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
    """
    Collects, for every requested input, the tensor produced by the exported
    program rather than the one produced by the onnx model.

    :param input_names: names of the tensors to gather
    :param onnx_name_to_ep_name: maps an onnx result name to the corresponding
        name in the exported program
    :param onnx_results: every result computed by the onnx model
    :param torch_results: every result computed by the exported program
    :param submodel: onnx model; an input with no exported-program counterpart
        is appended to this model as an initializer (it is probably a constant)
    :return: a pair ``(tensors, removed)`` where ``tensors`` maps names to
        torch tensors and ``removed`` is the set of inputs for which no tensor
        coming from the exported program was found
    """
    gathered: Dict[str, torch.Tensor] = {}
    missing: Set[str] = set()
    for input_name in input_names:
        if input_name in onnx_name_to_ep_name:
            # The exported program computed this value: prefer it.
            gathered[input_name] = torch_results[onnx_name_to_ep_name[input_name]]
            continue
        missing.add(input_name)
        if submodel is None:
            # No model to patch: fall back to the onnx value.
            gathered[input_name] = onnx_results[input_name]
        else:
            # Probably a constant: bake it into the submodel as an initializer.
            submodel.graph.initializer.append(
                from_array_extended(onnx_results[input_name], name=input_name)
            )
    return gathered, missing
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class ReplayConfiguration:
    """
    Configuration specifying how to replay or dump pieces of
    onnx graph in order to replay them later and investigate
    later possible sources of discrepancies.

    :param dump_folder: where to dump the onnx model corresponding to the
        pieces to investigate
    :param selected_names: list of results names to dump
    :param selected_op_types: list of onnx operators to dump
    :param threshold: only keep those whose discrepancies is greater than that threshold
    :param dump_prefix_model: after dumping the smallest model able to replicate
        one given output, if also dumps the models producing the inputs
        and the outputs truncated from the big one
    """

    dump_folder: str
    selected_names: Optional[Set[str]] = None
    selected_op_types: Optional[Set[str]] = None
    threshold: float = 0.1
    dump_prefix_model: bool = False

    def __post_init__(self):
        # A destination folder is mandatory: dump() writes models and tensors there.
        assert self.dump_folder, "dump_folder is empty and this is not allowed for the replay"

    def select(
        self,
        name: Optional[str] = None,
        op_type: Optional[str] = None,
        err_abs: Optional[float] = None,
    ) -> bool:
        """
        Returns true or false whether or not a piece of the onnx model should be dumped,
        around a particular node. The result is True if one of the conditions is true:

        * ``name in self.selected_names``
        * ``op_type in self.selected_op_types``
        * ``err_abs >= self.threshold``

        :param name: result name
        :param op_type: operator type
        :param err_abs: measured discrepancy
        :return: True if this should be dumped
        """
        # Each criterion is only applied when both the candidate value and the
        # corresponding filter were provided; any single match is enough.
        if name and self.selected_names and name in self.selected_names:
            return True
        if op_type and self.selected_op_types and op_type in self.selected_op_types:
            return True
        if err_abs is not None and self.threshold is not None and err_abs >= self.threshold:
            return True
        return False

    def get_replay_code(self) -> str:
        """
        Returns a code letting the user replay the onnx model.
        It looks like the following. It may have to be adapted.

        .. runpython::
            :showcode:

            from onnx_diagnostic.torch_onnx.sbs_dataclasses import ReplayConfiguration

            rc = ReplayConfiguration(dump_folder="unused")
            print(rc.get_replay_code())
        """
        # NOTE: the returned string is a standalone script; it expects to run
        # from the dump folder created by :meth:`dump` (model.onnx, *.pt files).
        return textwrap.dedent(
            """
            import onnx
            import torch
            from onnx_diagnostic.helpers import max_diff, string_diff, string_type
            from onnx_diagnostic.helpers.torch_helper import study_discrepancies
            from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
            from onnx_diagnostic.reference import OnnxruntimeEvaluator

            skws = dict(with_shape=True, with_device=True)

            torch_inputs = torch.load("torch_inputs.pt")
            onnx_inputs = torch.load("onnx_inputs.pt")
            expected_outputs_and_mapping = torch.load("torch_outputs_and_mapping.pt")
            expected = expected_outputs_and_mapping["expected"]
            mapping = expected_outputs_and_mapping["mapping"]

            print(f"-- torch_inputs={string_type(torch_inputs, **skws)}")
            print(f"-- onnx_inputs={string_type(onnx_inputs, **skws)}")
            print(f"-- expected={string_type(expected, **skws)}")
            print(f"-- mapping={mapping}")

            print()
            print("-- model.onnx")
            print()

            model = onnx.load("model.onnx")
            print(pretty_onnx(model))

            print()
            print("-- range of inputs --")
            print()

            for k, v in onnx_inputs.items():
                print(f"-- {k}: {string_type(v, **skws, with_min_max=True)}")

            print()
            print("-- discrepancies of inputs --")
            print()

            ep_feeds = {}
            for k, v in onnx_inputs.items():
                tk = mapping.get(k, k)
                tkv = torch_inputs[k] if k in torch_inputs else torch_inputs[tk]
                ep_feeds[k] = tkv
                diff = max_diff(v, tkv)
                print(
                    f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} "
                    f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}"
                )

            print()
            print("-- SVD --")
            print()

            for k, v in onnx_inputs.items():
                if len(v.shape) == 2:
                    U, S, Vt = torch.linalg.svd(v.to(torch.float32))
                    print(f" -- {k}: {S[:5]}")

            print()
            print("-- run with onnx_inputs --")
            print()

            sess = OnnxruntimeEvaluator(model, whole=True)
            feeds = onnx_inputs
            obtained = sess.run(None, feeds)
            print(f"-- obtained={string_type(obtained, **skws)}")
            diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
            print(f"-- diff: {string_diff(diff)}")
            print()
            print("-- plots --")

            for i in range(len(expected)):
                study_discrepancies(
                    expected[i],
                    obtained[i],
                    title=f"study output {i}",
                    name=f"disc{i}.png",
                    bins=50,
                )

            print()
            print("-- run with torch_inputs --")
            print()

            obtained = sess.run(None, ep_feeds)
            print(f"-- obtained={string_type(obtained, **skws)}")
            diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
            print(f"-- diff: {string_diff(diff)}")

            print()
            print("-- end --")
            print()

            if False:
                # CUDA profiling
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CUDA],
                    record_shapes=True,
                    with_stack=True,
                ) as prof:
                    sess.run(None, ep_feeds)
                obj = prof.key_averages()
                print(obj.table())
            """
        )

    def dump(
        self,
        name: str,
        onnx_id_node: int,
        model: onnx.ModelProto,
        onnx_results: Dict[str, Any],
        torch_results: Dict[str, torch.Tensor],
        onnx_name_to_ep_name: Dict[str, str],
        verbose: int = 0,
    ) -> Optional[str]:
        """
        Dumps the minimal graph which can be replayed outside the model.

        :param name: name of the result to look into
        :param onnx_id_node: index of the node which produces it in model `model`
        :param model: onnx model
        :param onnx_results: all known onnx results
        :param torch_results: all known torch results
        :param onnx_name_to_ep_name: correspondence between onnx_node name
            and exported program name
        :param verbose: verbosity level
        :return: the folder created to dump everything, or None when no
            subgraph producing `name` could be extracted
        """
        if verbose:
            print(
                f"[ReplayConfiguration.dump] extract subset of nodes for "
                f"{name!r} (onnx_id_node={onnx_id_node})"
            )
        if verbose >= 10:
            print(f"[ReplayConfiguration.dump] onnx_results={sorted(onnx_results)}")
            print(f"[ReplayConfiguration.dump] torch_results={sorted(torch_results)}")
            print(
                f"[ReplayConfiguration.dump] onnx_name_to_ep_name="
                f"{sorted(onnx_name_to_ep_name)}"
            )
        # Walk backwards from `name` and keep only the nodes needed to compute
        # it, stopping at results the exported program also produces.
        nodes = extract_subset_of_nodes(
            model=model,
            name=name,
            node_index=onnx_id_node,
            cut_points=set(onnx_name_to_ep_name),
        )
        if not nodes:
            if verbose:
                print(
                    f"[ReplayConfiguration.dump] could not extract subset of "
                    f"nodes for {name!r}"
                )
            return None
        if verbose:
            print(f"[ReplayConfiguration.dump] make model with {len(nodes)} nodes")
        # Build a standalone model around the extracted nodes; types and ranks
        # of intermediate results are recovered from the recorded onnx tensors.
        submodel = make_submodel(
            nodes,
            ir_version=model.ir_version,
            opset_imports=model.opset_import,
            output_names=[name],
            type_rank_fn=lambda name: (
                torch_dtype_to_onnx_dtype(onnx_results[name].dtype),
                len(onnx_results[name].shape),
            ),
        )
        input_names = [n.name for n in submodel.graph.input]
        if verbose:
            print(f"[ReplayConfiguration.dump] model inputs {input_names}")
        # One sub-folder per result; ':' and '/' are not portable in paths.
        folder = os.path.join(self.dump_folder, name.replace(":", "_").replace("/", "_"))
        os.makedirs(folder, exist_ok=True)
        if verbose:
            print(f"[ReplayConfiguration.dump] dumps into folder {folder!r}")

        # Inputs with no torch counterpart are moved into the submodel as
        # initializers by make_torch_inputs and reported in removed_inputs.
        torch_inputs, removed_inputs = make_torch_inputs(
            input_names, onnx_name_to_ep_name, onnx_results, torch_results, submodel
        )

        if removed_inputs:
            # Those inputs became initializers: drop them from the graph inputs.
            input_names = [i for i in input_names if i not in removed_inputs]
            new_inputs = [i for i in submodel.graph.input if i.name not in removed_inputs]
            del submodel.graph.input[:]
            submodel.graph.input.extend(new_inputs)
            if verbose:
                print(f"[ReplayConfiguration.dump] removed inputs {removed_inputs}")
                print(f"[ReplayConfiguration.dump] final model inputs {input_names}")

        onnx.save(submodel, os.path.join(folder, "model.onnx"))
        onnx_inputs = {n: onnx_results[n] for n in input_names}
        assert (
            name in onnx_name_to_ep_name
        ), f"Unable to find {name!r} in {onnx_name_to_ep_name}"
        # Expected output comes from the exported program; the mapping lets the
        # replay script translate onnx input names back to torch result names.
        expected_outputs_and_mapping = dict(
            expected=(torch_results[onnx_name_to_ep_name[name]],),
            mapping={
                k: onnx_name_to_ep_name[k] for k in input_names if k in onnx_name_to_ep_name
            },
        )
        torch.save(torch_inputs, os.path.join(folder, "torch_inputs.pt"))
        torch.save(onnx_inputs, os.path.join(folder, "onnx_inputs.pt"))
        torch.save(
            expected_outputs_and_mapping, os.path.join(folder, "torch_outputs_and_mapping.pt")
        )
        with open(os.path.join(folder, "replay.py"), "w") as f:
            f.write(self.get_replay_code())

        if self.dump_prefix_model:
            # Also dump the truncated big model producing the submodel inputs
            # (model.inputs.onnx) and outputs (model.outputs.onnx).
            main_inputs = {
                i.name: onnx_inputs.get(i.name, torch_inputs.get(i.name, None))
                for i in model.graph.input
            }
            # only saving onnx inputs, torch should be the same
            torch.save(main_inputs, os.path.join(folder, "onnx_main_inputs.pt"))

            model_inputs_file = os.path.join(folder, "model.inputs.onnx")
            exclude = {i.name for i in model.graph.input} | {
                i.name for i in model.graph.initializer
            }
            model_inputs = select_model_inputs_outputs(
                model, outputs=[i.name for i in submodel.graph.input if i.name not in exclude]
            )
            onnx.save(model_inputs, model_inputs_file)

            model_outputs_file = os.path.join(folder, "model.outputs.onnx")
            model_outputs = select_model_inputs_outputs(
                model, outputs=[i.name for i in submodel.graph.output]
            )
            onnx.save(model_outputs, model_outputs_file)

        if verbose:
            print(f"[ReplayConfiguration.dump] done {folder!r}")
        return folder
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
@dataclass
class RunAlignedRecord:
    """
    The side-by-side ran by function :func:`run_aligned
    <onnx_diagnostic.torch_onnx.sbs.run_aligned>`
    yields instances of this type. If both `ep_name`
    and `onnx_name` are specified, then both results
    appear in the exported program (torch) and the onnx model.

    :param ep_id_node: node index in the exported program
    :param onnx_id_node: node index in the onnx model, -1 for an initializer
    :param ep_name: result name in the exported program
    :param onnx_name: result name in the onnx model, usually same as `ep_name`
        except for initializer
    :param ep_target: target name in the exported program producing the result
    :param onnx_op_type: operator type in the onnx model producing the result
    :param onnx_id_output: usually 0 unless this node has multiple output,
        in that case, it is the output index
    :param ep_shape_type: shape and type of the results in the exported program
    :param onnx_shape_type: shape and type of the results in the onnx mode,
        it should be the same as `ep_shape_type`, anything different probably
        means a bug
    :param err_abs: maximum absolute error for the considered result
        between the exported program and the onnx model
    :param err_rel: maximum relative error
    :param err_dev: 0 if the device is the same, 1 if not
    :param err_nan: number of nan values disagreeing
    :param err_h01: number of values for which the discrepancy is above 0.1
    :param err_h001: number of values for which the discrepancy is above 0.01
    :param ep_time_run: execution time for the exported program
    :param onnx_time_run: execution time for the onnx model, that includes
        the creation of the onnx model so that's probably not very usable
    :param err_abs2: same as `err_abs` if onnx kernel is run with torch results
    :param err_rel2: same as `err_rel` if onnx kernel is run with torch results
    :param err_dev2: same as `err_dev` if onnx kernel is run with torch results
    :param err_nan2: same as `err_nan` if onnx kernel is run with torch results
    :param err_h012: same as `err_h01` if onnx kernel is run with torch results
    :param err_h0012: same as `err_h001` if onnx kernel is run with torch results
    :param comment: any additional information
    """

    ep_id_node: Optional[int] = None
    onnx_id_node: Optional[int] = None
    ep_name: Optional[str] = None
    onnx_name: Optional[str] = None
    ep_target: Optional[str] = None
    onnx_op_type: Optional[str] = None
    onnx_id_output: Optional[int] = None
    ep_shape_type: Optional[str] = None
    onnx_shape_type: Optional[str] = None
    err_abs: Optional[float] = None
    err_rel: Optional[float] = None
    err_dev: Optional[float] = None
    err_nan: Optional[float] = None
    err_h01: Optional[float] = None
    err_h001: Optional[float] = None
    ep_time_run: Optional[float] = None
    onnx_time_run: Optional[float] = None
    err_abs2: Optional[float] = None
    err_rel2: Optional[float] = None
    err_dev2: Optional[float] = None
    err_nan2: Optional[float] = None
    err_h012: Optional[float] = None
    err_h0012: Optional[float] = None
    comment: Optional[str] = None

    def __post_init__(self):
        "Validation."
        assert self.ep_id_node is None or self.ep_id_node >= 0, (
            f"Node id are always positive in the exported program but "
            f"ep_id_node={self.ep_id_node}"
        )

    def set_diff(self, diff: Dict[str, Any]) -> Self:
        """
        Sets the error measures (``err_abs``, ``err_rel``, ``err_dev``,
        ``err_nan``, ``err_h01``, ``err_h001``) from a discrepancy dictionary
        as produced by ``max_diff``. Returns ``self`` so calls can be chained.
        """
        if diff is None:
            # BUG FIX: previously returned None, breaking the fluent
            # interface promised by the `-> Self` annotation.
            return self
        if "abs" in diff:
            self.err_abs = diff["abs"]
        if "rel" in diff:
            self.err_rel = diff["rel"]
        if "dev" in diff:
            self.err_dev = diff["dev"]
        if "nan" in diff:
            self.err_nan = diff["nan"]
        if "rep" in diff:
            self.err_h01 = diff["rep"][">0.1"]
            self.err_h001 = diff["rep"][">0.01"]
        return self

    def set_diff2(self, diff: Dict[str, Any]) -> Self:
        """
        Same as :meth:`set_diff` but fills the ``*2`` attributes (onnx kernel
        run with torch inputs). Returns ``self`` so calls can be chained.
        """
        if diff is None:
            # BUG FIX: previously returned None, breaking the fluent
            # interface promised by the `-> Self` annotation.
            return self
        if "abs" in diff:
            self.err_abs2 = diff["abs"]
        if "rel" in diff:
            self.err_rel2 = diff["rel"]
        if "dev" in diff:
            self.err_dev2 = diff["dev"]
        if "nan" in diff:
            self.err_nan2 = diff["nan"]
        if "rep" in diff:
            self.err_h012 = diff["rep"][">0.1"]
            self.err_h0012 = diff["rep"][">0.01"]
        return self

    @property
    def key(
        self,
    ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]]:
        "Creates a unique identifier."
        return (
            self.ep_id_node,
            self.onnx_id_node,
            self.onnx_id_output,
            self.ep_name,
            self.onnx_name,
        )

    def check(
        self,
        already_yielded: Dict[
            Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]],
            int,
        ],
    ) -> Self:
        "Checks a record was not already yielded and registers it."
        if self.onnx_op_type == "reset":
            # no record for this one
            return self
        key = self.key
        assert key not in already_yielded, (
            f"Record with key={key} was already yielded, "
            f"number of records={len(already_yielded)} and previous "
            f"record at position {already_yielded[key]} (self={self})"
        )
        # Remember the position at which this record was yielded.
        already_yielded[key] = len(already_yielded)
        return self
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
@dataclass
class StatusRunAligned:
    """
    Information to display while running the side-by-side

    :param max_abs: maximum absolute error seen so far (finite values only)
    :param n_inf: number of infinite or overflowing (> 1e6) values seen so far
    :param n_nan: number of nan values seen so far
    :param yielded_nodes: number of yielded pair of nodes seen so far
    :param last_replay: last result dumped on disk for later replay
    """

    max_abs: float = 0.0
    n_inf: int = 0
    n_nan: int = 0
    yielded_nodes: int = 0
    last_replay: str = ""

    def to_str(self) -> str:
        "Nice display."
        s = (
            f"yielded={self.yielded_nodes} maxabs={self.max_abs:1.3f} "
            f"#inf={self.n_inf} #nan={self.n_nan}"
        )
        if self.last_replay:
            return f"{s} -PLAY({self.last_replay})"
        return s

    def update(self, err_abs: float):
        """
        Updates all attributes with the latest measure.

        BUG FIX: the counters were swapped — nan values were added to
        ``n_inf`` and huge finite values (> 1e6) to ``n_nan``, contradicting
        the documented meaning of both attributes. Now nan feeds ``n_nan``
        and infinite/overflowing values feed ``n_inf``; ``max_abs`` only
        tracks reasonable finite values so one outlier does not hide it.
        """
        if np.isnan(err_abs):
            self.n_nan += 1
        elif np.isinf(err_abs) or err_abs > 1e6:
            self.n_inf += 1
        else:
            self.max_abs = max(self.max_abs, err_abs)