onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,535 @@
1
+ import os
2
+ import textwrap
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Set, Tuple
5
+
6
+ try:
7
+ from typing import Self
8
+ except ImportError:
9
+ # python <= 3.10
10
+ Self = "Self" # type: ignore[assignment]
11
+ import onnx
12
+ import numpy as np
13
+ import torch
14
+ from ..helpers.onnx_helper import (
15
+ extract_subset_of_nodes,
16
+ make_submodel,
17
+ from_array_extended,
18
+ select_model_inputs_outputs,
19
+ )
20
+ from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
21
+
22
+
23
def make_torch_inputs(
    input_names: List[str],
    onnx_name_to_ep_name: Dict[str, str],
    onnx_results: Dict[str, torch.Tensor],
    torch_results: Dict[str, torch.Tensor],
    submodel: Optional[onnx.ModelProto],
) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
    """
    Gathers torch tensors (produced by the exported program) in place of the
    onnx tensors (produced by the onnx model) for the given input names.

    :param input_names: tensors to gather
    :param onnx_name_to_ep_name: mapping between onnx name to names in the exported program
    :param onnx_results: all onnx results (produced by the onnx model)
    :param torch_results: all tensors produced by the exported program
    :param submodel: onnx model, any tensor missing in `torch_results` is
        added as an initializer to this model
    :return: the list of tensors, the set of inputs for which there was no tensor coming
        from the exported program
    """
    gathered: Dict[str, torch.Tensor] = {}
    missing: Set[str] = set()
    for input_name in input_names:
        if input_name in onnx_name_to_ep_name:
            # The exported program produced a tensor for this input: use it.
            gathered[input_name] = torch_results[onnx_name_to_ep_name[input_name]]
            continue
        missing.add(input_name)
        if submodel is None:
            # No submodel to amend: fall back on the onnx value.
            gathered[input_name] = onnx_results[input_name]
        else:
            # We add that input as an initializer because it is probably a constant.
            submodel.graph.initializer.append(
                from_array_extended(onnx_results[input_name], name=input_name)
            )
    return gathered, missing
55
+
56
+
57
@dataclass
class ReplayConfiguration:
    """
    Configuration specifying how to replay or dump pieces of
    onnx graph in order to replay them later and investigate
    later possible sources of discrepancies.

    :param dump_folder: where to dump the onnx model corresponding to the
        pieces to investigate
    :param selected_names: list of results names to dump
    :param selected_op_types: list of onnx operators to dump
    :param threshold: only keep those whose discrepancies is greater than that threshold
    :param dump_prefix_model: after dumping the smallest model able to replicate
        one given output, if also dumps the models producing the inputs
        and the outputs truncated from the big one
    """

    dump_folder: str
    selected_names: Optional[Set[str]] = None
    selected_op_types: Optional[Set[str]] = None
    threshold: float = 0.1
    dump_prefix_model: bool = False

    def __post_init__(self):
        # dump() joins paths under dump_folder, so an empty value would write
        # into the current directory: reject it early.
        assert self.dump_folder, "dump_folder is empty and this is not allowed for the replay"

    def select(
        self,
        name: Optional[str] = None,
        op_type: Optional[str] = None,
        err_abs: Optional[float] = None,
    ) -> bool:
        """
        Returns true or false whether or not a piece of the onnx model should be dumped,
        around a particular node. The results is True if one of the condition is true:

        * ``name in self.selected_names``
        * ``op_type in self.selected_op_types``
        * ``err_abs >= self.threshold``

        :param name: result name
        :param op_type: operator type
        :param err_abs: measured discrepancy
        :return: True if this should be dumped
        """
        # Each criterion only applies when both the query value and the
        # corresponding configuration attribute are provided (truthy / not None).
        if name and self.selected_names and name in self.selected_names:
            return True
        if op_type and self.selected_op_types and op_type in self.selected_op_types:
            return True
        if err_abs is not None and self.threshold is not None and err_abs >= self.threshold:
            return True
        return False

    def get_replay_code(self) -> str:
        """
        Returns a code letting the user replay the onnx model.
        It looks like the following. It may have to be adapted.

        .. runpython::
            :showcode:

            from onnx_diagnostic.torch_onnx.sbs_dataclasses import ReplayConfiguration

            rc = ReplayConfiguration(dump_folder="unused")
            print(rc.get_replay_code())
        """
        # The returned script expects to be executed inside the folder created
        # by dump(): it loads model.onnx, torch_inputs.pt, onnx_inputs.pt and
        # torch_outputs_and_mapping.pt saved there.
        return textwrap.dedent(
            """
            import onnx
            import torch
            from onnx_diagnostic.helpers import max_diff, string_diff, string_type
            from onnx_diagnostic.helpers.torch_helper import study_discrepancies
            from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
            from onnx_diagnostic.reference import OnnxruntimeEvaluator

            skws = dict(with_shape=True, with_device=True)

            torch_inputs = torch.load("torch_inputs.pt")
            onnx_inputs = torch.load("onnx_inputs.pt")
            expected_outputs_and_mapping = torch.load("torch_outputs_and_mapping.pt")
            expected = expected_outputs_and_mapping["expected"]
            mapping = expected_outputs_and_mapping["mapping"]

            print(f"-- torch_inputs={string_type(torch_inputs, **skws)}")
            print(f"-- onnx_inputs={string_type(onnx_inputs, **skws)}")
            print(f"-- expected={string_type(expected, **skws)}")
            print(f"-- mapping={mapping}")

            print()
            print("-- model.onnx")
            print()

            model = onnx.load("model.onnx")
            print(pretty_onnx(model))

            print()
            print("-- range of inputs --")
            print()

            for k, v in onnx_inputs.items():
                print(f"-- {k}: {string_type(v, **skws, with_min_max=True)}")

            print()
            print("-- discrepancies of inputs --")
            print()

            ep_feeds = {}
            for k, v in onnx_inputs.items():
                tk = mapping.get(k, k)
                tkv = torch_inputs[k] if k in torch_inputs else torch_inputs[tk]
                ep_feeds[k] = tkv
                diff = max_diff(v, tkv)
                print(
                    f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} "
                    f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}"
                )

            print()
            print("-- SVD --")
            print()

            for k, v in onnx_inputs.items():
                if len(v.shape) == 2:
                    U, S, Vt = torch.linalg.svd(v.to(torch.float32))
                    print(f" -- {k}: {S[:5]}")

            print()
            print("-- run with onnx_inputs --")
            print()

            sess = OnnxruntimeEvaluator(model, whole=True)
            feeds = onnx_inputs
            obtained = sess.run(None, feeds)
            print(f"-- obtained={string_type(obtained, **skws)}")
            diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
            print(f"-- diff: {string_diff(diff)}")
            print()
            print("-- plots --")

            for i in range(len(expected)):
                study_discrepancies(
                    expected[i],
                    obtained[i],
                    title=f"study output {i}",
                    name=f"disc{i}.png",
                    bins=50,
                )

            print()
            print("-- run with torch_inputs --")
            print()

            obtained = sess.run(None, ep_feeds)
            print(f"-- obtained={string_type(obtained, **skws)}")
            diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01])
            print(f"-- diff: {string_diff(diff)}")

            print()
            print("-- end --")
            print()

            if False:
                # CUDA profiling
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CUDA],
                    record_shapes=True,
                    with_stack=True,
                ) as prof:
                    sess.run(None, ep_feeds)
                obj = prof.key_averages()
                print(obj.table())
            """
        )

    def dump(
        self,
        name: str,
        onnx_id_node: int,
        model: onnx.ModelProto,
        onnx_results: Dict[str, Any],
        torch_results: Dict[str, torch.Tensor],
        onnx_name_to_ep_name: Dict[str, str],
        verbose: int = 0,
    ) -> Optional[str]:
        """
        Dumps the minimal graph which can be replayed outside the model.

        :param name: name of the result to look into
        :param onnx_id_node: index of the node which produces it model `model`
        :param model: onnx model
        :param onnx_results: all known onnx results
        :param torch_results: all known torch results
        :param onnx_name_to_ep_name: correspondence between onnx_node name
            and exported program name
        :param verbose: verbosity level
        :return: the folder created to dump everything
        """
        if verbose:
            print(
                f"[ReplayConfiguration.dump] extract subset of nodes for "
                f"{name!r} (onnx_id_node={onnx_id_node})"
            )
        if verbose >= 10:
            print(f"[ReplayConfiguration.dump] onnx_results={sorted(onnx_results)}")
            print(f"[ReplayConfiguration.dump] torch_results={sorted(torch_results)}")
            print(
                f"[ReplayConfiguration.dump] onnx_name_to_ep_name="
                f"{sorted(onnx_name_to_ep_name)}"
            )
        # Walk the graph backwards from `name`, stopping at results that are
        # also known to the exported program (cut_points) so they can become
        # the inputs of the extracted submodel.
        nodes = extract_subset_of_nodes(
            model=model,
            name=name,
            node_index=onnx_id_node,
            cut_points=set(onnx_name_to_ep_name),
        )
        if not nodes:
            if verbose:
                print(
                    f"[ReplayConfiguration.dump] could not extract subset of "
                    f"nodes for {name!r}"
                )
            return None
        if verbose:
            print(f"[ReplayConfiguration.dump] make model with {len(nodes)} nodes")
        # type_rank_fn lets make_submodel type the new inputs/outputs from the
        # tensors observed at runtime (dtype + rank only).
        submodel = make_submodel(
            nodes,
            ir_version=model.ir_version,
            opset_imports=model.opset_import,
            output_names=[name],
            type_rank_fn=lambda name: (
                torch_dtype_to_onnx_dtype(onnx_results[name].dtype),
                len(onnx_results[name].shape),
            ),
        )
        input_names = [n.name for n in submodel.graph.input]
        if verbose:
            print(f"[ReplayConfiguration.dump] model inputs {input_names}")
        # Result names may contain characters unusable in file names.
        folder = os.path.join(self.dump_folder, name.replace(":", "_").replace("/", "_"))
        os.makedirs(folder, exist_ok=True)
        if verbose:
            print(f"[ReplayConfiguration.dump] dumps into folder {folder!r}")

        # Inputs without a torch counterpart are turned into initializers of
        # `submodel` by make_torch_inputs and reported in removed_inputs.
        torch_inputs, removed_inputs = make_torch_inputs(
            input_names, onnx_name_to_ep_name, onnx_results, torch_results, submodel
        )

        if removed_inputs:
            # Those names are now initializers: drop them from the graph inputs.
            input_names = [i for i in input_names if i not in removed_inputs]
            new_inputs = [i for i in submodel.graph.input if i.name not in removed_inputs]
            del submodel.graph.input[:]
            submodel.graph.input.extend(new_inputs)
            if verbose:
                print(f"[ReplayConfiguration.dump] removed inputs {removed_inputs}")
                print(f"[ReplayConfiguration.dump] final model inputs {input_names}")

        onnx.save(submodel, os.path.join(folder, "model.onnx"))
        onnx_inputs = {n: onnx_results[n] for n in input_names}
        assert (
            name in onnx_name_to_ep_name
        ), f"Unable to find {name!r} in {onnx_name_to_ep_name}"
        # Saved alongside the inputs so the replay script can compare against
        # the exported-program expected output and map input names.
        expected_outputs_and_mapping = dict(
            expected=(torch_results[onnx_name_to_ep_name[name]],),
            mapping={
                k: onnx_name_to_ep_name[k] for k in input_names if k in onnx_name_to_ep_name
            },
        )
        torch.save(torch_inputs, os.path.join(folder, "torch_inputs.pt"))
        torch.save(onnx_inputs, os.path.join(folder, "onnx_inputs.pt"))
        torch.save(
            expected_outputs_and_mapping, os.path.join(folder, "torch_outputs_and_mapping.pt")
        )
        with open(os.path.join(folder, "replay.py"), "w") as f:
            f.write(self.get_replay_code())

        if self.dump_prefix_model:
            main_inputs = {
                i.name: onnx_inputs.get(i.name, torch_inputs.get(i.name, None))
                for i in model.graph.input
            }
            # only saving onnx inputs, torch should be the same
            torch.save(main_inputs, os.path.join(folder, "onnx_main_inputs.pt"))

            # Prefix model: the part of the big model producing the submodel
            # inputs (skipping names that are already main inputs/initializers).
            model_inputs_file = os.path.join(folder, "model.inputs.onnx")
            exclude = {i.name for i in model.graph.input} | {
                i.name for i in model.graph.initializer
            }
            model_inputs = select_model_inputs_outputs(
                model, outputs=[i.name for i in submodel.graph.input if i.name not in exclude]
            )
            onnx.save(model_inputs, model_inputs_file)

            # Same big model truncated at the submodel outputs.
            model_outputs_file = os.path.join(folder, "model.outputs.onnx")
            model_outputs = select_model_inputs_outputs(
                model, outputs=[i.name for i in submodel.graph.output]
            )
            onnx.save(model_outputs, model_outputs_file)

        if verbose:
            print(f"[ReplayConfiguration.dump] done {folder!r}")
        return folder
357
+
358
+
359
@dataclass
class RunAlignedRecord:
    """
    The side-by-side ran by function :func:`run_aligned
    <onnx_diagnostic.torch_onnx.sbs.run_aligned>`
    yields instances of this type. If both `ep_name`
    and `onnx_name` are specified, then both results
    appear in the exported program (torch) and the onnx model.

    :param ep_id_node: node index in the exported program
    :param onnx_id_node: node index in the onnx model, -1 for an initializer
    :param ep_name: result name in the exported program
    :param onnx_name: result name in the onnx model, usually same as `ep_name`
        except for initializer
    :param ep_target: target name in the exported program producing the result
    :param onnx_op_type: operator type in the onnx model producing the result
    :param onnx_id_output: usually 0 unless this node has multiple output,
        in that case, it is the output index
    :param ep_shape_type: shape and type of the results in the exported program
    :param onnx_shape_type: shape and type of the results in the onnx mode,
        it should be the same as `ep_shape_type`, anything different probably
        means a bug
    :param err_abs: maximum absolute error for the considered result
        between the exported program and the onnx model
    :param err_rel: maximum relative error
    :param err_dev: 0 if the device is the same, 1 if not
    :param err_nan: number of nan values disagreeing
    :param err_h01: number of values for which the discrepancy is above 0.1
    :param err_h001: number of values for which the discrepancy is above 0.01
    :param ep_time_run: execution time for the exported program
    :param onnx_time_run: execution time for the onnx model, that includes
        the creation of the onnx model so that's probably not very usable
    :param err_abs2: same as `err_abs` if onnx kernel is run with torch results
    :param err_rel2: same as `err_rel` if onnx kernel is run with torch results
    :param err_dev2: same as `err_dev` if onnx kernel is run with torch results
    :param err_nan2: same as `err_nan` if onnx kernel is run with torch results
    :param err_h012: same as `err_h01` if onnx kernel is run with torch results
    :param err_h0012: same as `err_h001` if onnx kernel is run with torch results
    :param comment: any additional information
    """

    ep_id_node: Optional[int] = None
    onnx_id_node: Optional[int] = None
    ep_name: Optional[str] = None
    onnx_name: Optional[str] = None
    ep_target: Optional[str] = None
    onnx_op_type: Optional[str] = None
    onnx_id_output: Optional[int] = None
    ep_shape_type: Optional[str] = None
    onnx_shape_type: Optional[str] = None
    err_abs: Optional[float] = None
    err_rel: Optional[float] = None
    err_dev: Optional[float] = None
    err_nan: Optional[float] = None
    err_h01: Optional[float] = None
    err_h001: Optional[float] = None
    ep_time_run: Optional[float] = None
    onnx_time_run: Optional[float] = None
    err_abs2: Optional[float] = None
    err_rel2: Optional[float] = None
    err_dev2: Optional[float] = None
    err_nan2: Optional[float] = None
    err_h012: Optional[float] = None
    err_h0012: Optional[float] = None
    comment: Optional[str] = None

    def __post_init__(self):
        "Validation."
        assert self.ep_id_node is None or self.ep_id_node >= 0, (
            f"Node id are always positive in the exported program but "
            f"ep_id_node={self.ep_id_node}"
        )

    def _set_errors(self, diff: Optional[Dict[str, Any]], suffix: str) -> Self:
        # Shared implementation for set_diff (suffix="") and set_diff2
        # (suffix="2"): copies the measures found in `diff` into the
        # matching err_* attributes.
        if diff is None:
            return self
        for key, attr in (("abs", "err_abs"), ("rel", "err_rel"),
                          ("dev", "err_dev"), ("nan", "err_nan")):
            if key in diff:
                setattr(self, f"{attr}{suffix}", diff[key])
        if "rep" in diff:
            # histogram of discrepancies above the two fixed thresholds
            setattr(self, f"err_h01{suffix}", diff["rep"][">0.1"])
            setattr(self, f"err_h001{suffix}", diff["rep"][">0.01"])
        return self

    def set_diff(self, diff: Dict[str, Any]) -> Self:
        """
        Sets the error measures (onnx run with onnx inputs).

        :param diff: dictionary with optional keys ``abs``, ``rel``, ``dev``,
            ``nan``, ``rep``; may be None, in which case nothing is changed
        :return: self, so calls can be chained
            (previously returned None when *diff* was None, breaking chaining)
        """
        return self._set_errors(diff, "")

    def set_diff2(self, diff: Dict[str, Any]) -> Self:
        """
        Sets the error measures (onnx kernel run with torch results).

        :param diff: same structure as in :meth:`set_diff`; may be None
        :return: self, so calls can be chained
            (previously returned None when *diff* was None, breaking chaining)
        """
        return self._set_errors(diff, "2")

    @property
    def key(
        self,
    ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]]:
        "Creates a unique identifier."
        return (
            self.ep_id_node,
            self.onnx_id_node,
            self.onnx_id_output,
            self.ep_name,
            self.onnx_name,
        )

    def check(
        self,
        already_yielded: Dict[
            Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]],
            int,
        ],
    ) -> Self:
        "Checks a record was not already yielded."
        if self.onnx_op_type == "reset":
            # no record for this one
            return self
        key = self.key
        assert key not in already_yielded, (
            f"Record with key={key} was already yielded, "
            f"number of records={len(already_yielded)} and previous "
            f"record at position {already_yielded[key]} (self={self})"
        )
        # remember at which position this record was yielded
        already_yielded[key] = len(already_yielded)
        return self
498
+
499
+
500
@dataclass
class StatusRunAligned:
    """
    Information to display while running the side-by-side

    :param max_abs: maximum absolute seen so far (finite values only)
    :param n_inf: number of infinite values seen so far
        (values above 1e6 are counted as infinite as well)
    :param n_nan: number of nan values seen so for
    :param yielded_nodes: number of yielded pair of nodes seen so far
    :param last_replay: last result dumped on disk for later replay
    """

    max_abs: float = 0.0
    n_inf: int = 0
    n_nan: int = 0
    yielded_nodes: int = 0
    last_replay: str = ""

    def to_str(self) -> str:
        "Nice display."
        s = (
            f"yielded={self.yielded_nodes} maxabs={self.max_abs:1.3f} "
            f"#inf={self.n_inf} #nan={self.n_nan}"
        )
        if self.last_replay:
            return f"{s} -PLAY({self.last_replay})"
        return s

    def update(self, err_abs: float):
        """
        Updates all attributes with the latest measure.

        :param err_abs: latest measured absolute error
        """
        # Fix: the previous implementation counted nan values in n_inf and
        # finite values > 1e6 in n_nan, contradicting the documented meaning
        # of both attributes.
        if np.isnan(err_abs):
            self.n_nan += 1
        elif np.isinf(err_abs) or err_abs > 1e6:
            # absurdly large finite errors are treated as infinite
            self.n_inf += 1
        else:
            self.max_abs = max(self.max_abs, err_abs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-diagnostic
3
- Version: 0.8.2
3
+ Version: 0.8.4
4
4
  Summary: Tools to help converting pytorch models into ONNX.
5
5
  Home-page: https://github.com/sdpython/onnx-diagnostic
6
6
  Author: Xavier Dupré