onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff shows the contents of the two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +91 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +3 -3
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +92 -23
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +90 -26
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +103 -1
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +103 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/helpers/ort_session.py
+++ b/onnx_diagnostic/helpers/ort_session.py
@@ -108,7 +108,10 @@ class _InferenceSession:
                 session_options,
                 providers=providers,
             )
-        except
+        except (
+            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
+            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
+        ) as e:
             if isinstance(sess, onnx.ModelProto):
                 debug_path = "_debug_InferenceSession_last_failure.onnx"
                 onnx.save(
@@ -134,7 +137,13 @@ class _InferenceSession:

         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Input name cannot be empty but input_names={self.input_names}"
         self.output_names = [i.name for i in sess.get_outputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Output name cannot be empty but output_names={self.output_names}"
         self.input_shapes = [i.shape for i in sess.get_inputs()]
         self.output_shapes = [i.shape for i in sess.get_outputs()]
         self.input_types = [i.type for i in sess.get_inputs()]
@@ -338,6 +347,7 @@ class InferenceSessionForTorch(_InferenceSession):
     :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
     :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
     :param use_training_api: use onnxruntime-traning API
+    :param cpu_output: if True, force the outputs to be on CPU
     """

     def __init__(
@@ -353,6 +363,7 @@ class InferenceSessionForTorch(_InferenceSession):
         optimized_model_filepath: Optional[str] = None,
         disable_aot_function_inlining: Optional[bool] = None,
         use_training_api: Optional[bool] = None,
+        cpu_outputs: bool = False,
     ):
         super().__init__(
             sess,
@@ -367,6 +378,7 @@ class InferenceSessionForTorch(_InferenceSession):
             disable_aot_function_inlining=disable_aot_function_inlining,
             use_training_api=use_training_api,
         )
+        self.cpu_outputs = cpu_outputs

     def _get_ortvalues_from_torch_tensors(
         self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
@@ -490,23 +502,37 @@ class InferenceSessionForTorch(_InferenceSession):
         feeds is a dictionary of :class:`torch.Tensor`.
         The output device is CPU even if the outputs are on CUDA.
         """
-
+        input_names = []
+        values = ORTC.OrtValueVector()
+        device = -1
         for k, v in feeds.items():
+            assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
+            device = max(device, v.get_device())
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
                 v = v.contiguous()
             if v.dtype == torch.bool:
-
-
-                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                    v.detach().numpy(), onnx.TensorProto.BOOL
-                )
+                v = v.to(torch.uint8)
+                v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
             else:
-
+                v = ORTC.OrtValue.from_dlpack(v.detach().__dlpack__(), False)
+            input_names.append(k)
+            values.push_back(v)
         if self.nvtx:
-            self.torch.cuda.nvtx.range_push("
-
-
+            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")
+
+        # ort_outputs = self.sess._sess.run_with_ort_values(
+        #     new_feeds, output_names or self.output_names, self.run_options
+        # )
+        ort_outputs = ORTC.OrtValueVector()
+        out_names = output_names or self.output_names
+        self.sess._sess.run_with_ortvaluevector(
+            self.run_options,
+            input_names,
+            values,
+            out_names,
+            ort_outputs,
+            [DEVICES[-1 if self.cpu_outputs else device] for o in out_names],
         )
         if self.nvtx:
             self.torch.cuda.nvtx.range_pop()
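Note: with 0.8.3, `run` binds every torch input through DLPack into a single `OrtValueVector` and calls `run_with_ortvaluevector`, so outputs can stay on the input device unless the new `cpu_outputs` flag forces them back to CPU. A minimal usage sketch, hedged: the model path, the feed shapes, and passing `None` for the output names are illustrative assumptions, not taken from the diff.

import torch
from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch

# "model.onnx" is a placeholder path; per the hunk above, the class also
# accepts an onnx.ModelProto directly.
sess = InferenceSessionForTorch("model.onnx", cpu_outputs=True)  # cpu_outputs is new in 0.8.3
feeds = {name: torch.rand(2, 8) for name in sess.input_names}    # torch tensors, not numpy
outputs = sess.run(None, feeds)                                  # results forced onto CPU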
--- /dev/null
+++ b/onnx_diagnostic/helpers/torch_fx_graph_helper.py
@@ -0,0 +1,164 @@
+from typing import Any, Dict, Optional, Tuple
+import torch
+from .helper import string_type
+
+
+def validate_fx_tensor(
+    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
+) -> None:
+    """
+    Validates the shape of tensor is expected.
+
+    :param node: node
+    :param tensor: tensor
+    :param expected_shape: expected shape
+    """
+    assert len(tensor.shape) == len(expected_shape), (
+        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+    for a, b in zip(tensor.shape, expected_shape):
+        assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, (
+            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+
+
+def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
+    """
+    Validates the outputs of a node using metadata stored in the node.
+
+    :param node: node
+    :param outputs: outputs
+    """
+    if "val" not in node.meta:
+        return
+    if isinstance(outputs, torch.Tensor):
+        validate_fx_tensor(node, outputs, node.meta["val"].shape)
+        return
+    if isinstance(outputs, (tuple, list)):
+        assert isinstance(node.meta["val"], (list, tuple)), (
+            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        assert len(outputs) == len(node.meta["val"]), (
+            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        for a, b in zip(outputs, node.meta["val"]):
+            validate_fx_tensor(node, a, b.shape)
+        return
+    if isinstance(outputs, int):
+        assert (
+            isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat))
+            or outputs == node.meta["val"]
+        ), (
+            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    if outputs is None:
+        assert node.meta["val"] is None, (
+            f"None mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    raise NotImplementedError(
+        f"Validation for output type {type(outputs)} is not implemented, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+
+
+def run_fx_node(
+    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
+) -> Tuple[Any, ...]:
+    """
+    Executes a node
+
+    :param node: runs a node
+    :param args: unnamed inputs to the node
+    :param kwargs: named inputs to the node
+    :return: results
+    """
+    if node.op == "output":
+        assert len(args) == 1 and not kwargs, (
+            f"Unexpected inputs: args={string_type(args, limit=20)} "
+            f"kwargs={string_type(kwargs, limit=20)}"
+        )
+        return args
+    if node.op == "call_function":
+        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
+        for a, ea in zip(args, node.args):
+            if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta:
+                ta = ea.meta["val"]
+                assert (
+                    isinstance(ta, torch.Tensor)
+                    and len(a.shape) == len(ta.shape)
+                    and a.dtype == ta.dtype
+                ), (
+                    f"Unable to run node {node!r}, target={node.target!r}, "
+                    f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
+                    f"args={string_type(args, with_shape=True, with_device=True)}, "
+                    f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+                )
+        try:
+            outputs = node.target(*args, **(kwargs or {}))
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Unable to run node {node!r}, target={node.target!r}, "
+                f"args={string_type(args, with_shape=True, with_device=True)}, "
+                f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+            ) from e
+        validate_fx_outputs(node, outputs)
+        return outputs
+    raise NotImplementedError(
+        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
+    )
+
+
+def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
+    "See :func:`prepare_args_kwargs`."
+    if isinstance(ref, torch.fx.Node):
+        return torch_results[ref.name]
+    if isinstance(ref, list):
+        return [_pick_result(torch_results, n) for n in ref]
+    if isinstance(ref, tuple):
+        return tuple(_pick_result(torch_results, n) for n in ref)
+    if isinstance(ref, dict):
+        return {k: _pick_result(torch_results, v) for k, v in ref.items()}
+    if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
+        return ref
+    if ref is None:
+        return None
+    if isinstance(ref, torch.layout):
+        return ref
+    raise NotImplementedError(f"Unable to process args type {type(ref)}")
+
+
+def prepare_args_kwargs(
+    torch_results: Dict[str, Any], node: torch.fx.Node
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Prepares args and kwargs before executing a fx node.
+
+    :param torch_results: existing results
+    :param node: node to execute
+    :return: new args and kwargs
+    """
+    new_args = _pick_result(torch_results, node.args)
+    new_kwargs = _pick_result(torch_results, node.kwargs)
+    return new_args, new_kwargs
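Together, `prepare_args_kwargs` and `run_fx_node` are enough to replay an FX graph node by node, checking each intermediate result against the shapes recorded in `node.meta["val"]` when present. A sketch of that loop; the traced module and the input tensor are illustrative, not from the package.

import torch
from onnx_diagnostic.helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node

class Tiny(torch.nn.Module):  # illustrative module, not from the package
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(Tiny())
results = {}
for node in gm.graph.nodes:
    if node.op == "placeholder":
        results[node.name] = torch.randn(2, 3)  # feed graph inputs by hand
    else:
        args, kwargs = prepare_args_kwargs(results, node)  # resolve fx.Node refs to values
        results[node.name] = run_fx_node(node, args, kwargs)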
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -1,6 +1,7 @@
 import contextlib
 import ctypes
 import inspect
+import math
 import os
 import sys
 import warnings
@@ -30,9 +31,7 @@ from .onnx_helper import (


 def proto_from_tensor(
-    arr:
-    name: Optional[str] = None,
-    verbose: int = 0,
+    arr: torch.Tensor, name: Optional[str] = None, verbose: int = 0
 ) -> onnx.TensorProto:
     """
     Converts a torch Tensor into a TensorProto.
@@ -98,7 +97,7 @@ def proto_from_tensor(
     return tensor


-def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
     """
     Converts an onnx type into a torch dtype.

@@ -140,7 +139,7 @@ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
     )


-def torch_dtype_to_onnx_dtype(to:
+def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
     """
     Converts a torch dtype into a onnx element type.

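These two converter changes only tighten the annotations from the quoted `"torch.dtype"` form to real types; behavior is unchanged. A quick sanity check, assuming the standard ONNX/torch mapping:

import onnx
import torch
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype

# TensorProto.FLOAT (element type 1) round-trips with torch.float32.
assert onnx_dtype_to_torch_dtype(onnx.TensorProto.FLOAT) == torch.float32
assert torch_dtype_to_onnx_dtype(torch.float32) == onnx.TensorProto.FLOAT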
@@ -483,7 +482,7 @@ def is_torchdynamo_exporting() -> bool:
     return False


-def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
+def to_numpy(tensor: torch.Tensor) -> np.ndarray:
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
         return tensor.detach().cpu().numpy()
@@ -498,6 +497,21 @@ def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
     return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])


+def from_numpy(tensor: np.ndarray) -> torch.Tensor:
+    """Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`."""
+    try:
+        return torch.from_numpy(tensor)
+    except TypeError:
+        # We try with ml_dtypes
+        pass
+
+    import ml_dtypes
+
+    conv = {ml_dtypes.bfloat16: torch.bfloat16}
+    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
+    return torch.from_numpy(tensor.astype(torch.float32)).to(conv[tensor.dtype])
+
+
 def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
     """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
     import torch
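`from_numpy` is the new mirror of `to_numpy`: ordinary dtypes go straight through `torch.from_numpy`, while dtypes numpy only knows through `ml_dtypes` (bfloat16) take the cast fallback. A sketch of the common path:

import numpy as np
from onnx_diagnostic.helpers.torch_helper import from_numpy, to_numpy

t = from_numpy(np.array([1.5, -2.0], dtype=np.float32))  # plain dtype: torch.from_numpy path
a = to_numpy(t)                                          # round-trip back to numpy
assert (a == np.array([1.5, -2.0], dtype=np.float32)).all()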
@@ -990,3 +1004,86 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
         counts[dt] += 1
     final = max(list(counts.items()))
     return final[0]
+
+
+def closest_factor_pair(n: int):
+    """Tries to find ``a, b`` such as ``n == a * b``."""
+    assert n > 0, f"n={n} must be a positive integer"
+    start = math.isqrt(n)
+    for a in range(start, 0, -1):
+        if n % a == 0:
+            b = n // a
+            return a, b
+    return 1, n
+
+
+def study_discrepancies(
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+    bins: int = 50,
+    figsize: Optional[Tuple[int, int]] = (15, 15),
+    title: Optional[str] = None,
+    name: Optional[str] = None,
+) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Computes different metrics for the discrepancies.
+    Returns graphs.
+
+    .. plot::
+        :include-source:
+
+        import torch
+        from onnx_diagnostic.helpers.torch_helper import study_discrepancies
+
+        t1 = torch.randn((512, 1024)) * 10
+        t2 = t1 + torch.randn((512, 1024))
+        study_discrepancies(t1, t2, title="Random noise")
+    """
+    assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
+    assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
+    d1, d2 = (
+        (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
+    )
+
+    d1 = d1.squeeze()
+    d2 = d2.squeeze()
+    if len(d1.shape) == 1:
+        new_shape = closest_factor_pair(d1.shape[0])
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+    elif len(d1.shape) > 2:
+        new_shape = (-1, max(d1.shape))
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(3, 2, figsize=figsize)
+    vmin, vmax = d1.min().item(), d1.max().item()
+    ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
+    ax[0, 0].set_title(
+        f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
+    )
+
+    diff = d2 - d1
+    vmin, vmax = diff.min().item(), diff.max().item()
+    ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
+    ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]")
+
+    ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 0].set_title("Distribution of the first tensor")
+
+    ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 1].set_title("Distribution of the differences")
+
+    tf1 = d1.ravel()
+    td1 = diff.ravel()
+    ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
+    ax[2, 1].set_title("Graph XY")
+    ax[2, 1].set_xlabel("First tensor values")
+    ax[2, 1].set_ylabel("Difference values")
+
+    if title:
+        fig.suptitle(title)
+    fig.tight_layout()
+    if name:
+        fig.savefig(name)
+    return ax
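`closest_factor_pair` is what lets `study_discrepancies` render a flat tensor as a 2D image: it scans down from `math.isqrt(n)` to the largest divisor, so the returned pair is as close to square as possible. For example:

from onnx_diagnostic.helpers.torch_helper import closest_factor_pair

closest_factor_pair(12)    # (3, 4)
closest_factor_pair(1024)  # (32, 32)
closest_factor_pair(97)    # (1, 97): a prime only factors trivially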