onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/ort_session.py

@@ -1,7 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import numpy as np
-import numpy.typing as npt
 import torch
 from torch._C import _from_dlpack
 import onnxruntime
@@ -16,6 +15,7 @@ from .torch_helper import torch_dtype_to_onnx_dtype


 DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
+TensorLike = Union[np.ndarray, torch.Tensor]


 class _InferenceSession:
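The removed `numpy.typing` import is replaced by the local `TensorLike` alias, which is what the session classes below now use for feeds and outputs. A minimal sketch of what the alias admits (the feed names are made up):

```python
from typing import Dict, Union

import numpy as np
import torch

TensorLike = Union[np.ndarray, torch.Tensor]

# a feeds dictionary may now mix numpy and torch values under a single type
feeds: Dict[str, TensorLike] = {
    "input_ids": np.zeros((1, 8), dtype=np.int64),
    "attention_mask": torch.ones((1, 8), dtype=torch.int64),
}
```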
@@ -108,7 +108,10 @@ class _InferenceSession:
                 session_options,
                 providers=providers,
             )
-        except
+        except (
+            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
+            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
+        ) as e:
             if isinstance(sess, onnx.ModelProto):
                 debug_path = "_debug_InferenceSession_last_failure.onnx"
                 onnx.save(
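The bare `except` becomes a tuple of the two onnxruntime errors that signal a broken or rejected model, so unrelated exceptions propagate unchanged. The same pattern in isolation, as a sketch (the wrapper function is hypothetical; the dump path mirrors the one above):

```python
import onnx
import onnxruntime


def try_create_session(model: onnx.ModelProto) -> onnxruntime.InferenceSession:
    try:
        return onnxruntime.InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
    except (
        onnxruntime.capi.onnxruntime_pybind11_state.Fail,
        onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
    ):
        # keep the failing model around for offline inspection, then re-raise
        onnx.save(model, "_debug_InferenceSession_last_failure.onnx")
        raise
```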
@@ -134,7 +137,13 @@ class _InferenceSession:

         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Input name cannot be empty but input_names={self.input_names}"
         self.output_names = [i.name for i in sess.get_outputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Output name cannot be empty but output_names={self.output_names}"
         self.input_shapes = [i.shape for i in sess.get_inputs()]
         self.output_shapes = [i.shape for i in sess.get_outputs()]
         self.input_types = [i.type for i in sess.get_inputs()]
@@ -234,16 +243,16 @@ class InferenceSessionForNumpy(_InferenceSession):
         )

     def run(
-        self, output_names: Optional[List[str]], feeds: Dict[str,
-    ) -> List[Optional[
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> List[Optional[TensorLike]]:
         """Calls :meth:`onnxruntime.InferenceSession.run`."""
         # sess.run does not support blfoat16
         # res = self.sess.run(output_names, feeds)
         return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

     def run_dlpack(
-        self, output_names: Optional[List[str]], feeds: Dict[str,
-    ) -> Tuple[Optional[
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> Tuple[Optional[TensorLike], ...]:
         """
         Same as :meth:`onnxruntime.InferenceSession.run` except that
         feeds is a dictionary of :class:`np.ndarray`.
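With the alias in place, `run` both accepts and returns `TensorLike` values. A hedged usage sketch, assuming the constructor takes an onnx file path or `ModelProto` like `onnxruntime.InferenceSession` does (`model.onnx` and the input name `X` are hypothetical):

```python
import numpy as np

from onnx_diagnostic.helpers.ort_session import InferenceSessionForNumpy

sess = InferenceSessionForNumpy("model.onnx")  # hypothetical model file
outputs = sess.run(None, {"X": np.random.rand(2, 3).astype(np.float32)})
# entries can be None when onnxruntime returns an empty OrtValue
print([None if o is None else o.shape for o in outputs])
```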
@@ -280,13 +289,13 @@ class InferenceSessionForNumpy(_InferenceSession):
     def _ortvalues_to_numpy_tensor(
         self,
         ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
-    ) -> Tuple[Optional[
+    ) -> Tuple[Optional[TensorLike], ...]:
         if len(ortvalues) == 0:
             return tuple()

         if self.nvtx:
             self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
-        res: List[Optional[
+        res: List[Optional[TensorLike]] = []  # noqa: F823
         for i in range(len(ortvalues)):
             if not ortvalues[i].has_value():
                 res.append(None)
@@ -338,6 +347,7 @@ class InferenceSessionForTorch(_InferenceSession):
     :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
     :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use onnxruntime-traning API
+    :param cpu_output: if True, force the outputs to be on CPU
     """

     def __init__(
@@ -353,6 +363,7 @@ class InferenceSessionForTorch(_InferenceSession):
         optimized_model_filepath: Optional[str] = None,
         disable_aot_function_inlining: Optional[bool] = None,
         use_training_api: Optional[bool] = None,
+        cpu_outputs: bool = False,
     ):
         super().__init__(
             sess,
@@ -367,6 +378,7 @@ class InferenceSessionForTorch(_InferenceSession):
             disable_aot_function_inlining=disable_aot_function_inlining,
             use_training_api=use_training_api,
         )
+        self.cpu_outputs = cpu_outputs

     def _get_ortvalues_from_torch_tensors(
         self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
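`cpu_outputs` threads through to the run path in the next hunk: left at `False`, outputs are requested on the highest device id seen among the inputs; set to `True`, they are pinned to CPU. A sketch with a one-node model, assuming the torch session exposes the same `run` API as the numpy one above:

```python
import onnx
import onnx.helper as oh
import torch

from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch

# a one-node Identity model, just to have something to run
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Identity", ["X"], ["Y"])],
        "m",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2, 3])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2, 3])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)
model.ir_version = 10  # pin an IR version widely supported by onnxruntime

sess = InferenceSessionForTorch(model, cpu_outputs=True)
(y,) = sess.run(None, {"X": torch.rand(2, 3)})
print(y.device)  # cpu, even when the feed lives on CUDA
```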
@@ -490,23 +502,37 @@
         feeds is a dictionary of :class:`torch.Tensor`.
         The output device is CPU even if the outputs are on CUDA.
         """
-
+        input_names = []
+        values = ORTC.OrtValueVector()
+        device = -1
         for k, v in feeds.items():
+            assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
+            device = max(device, v.get_device())
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
                 v = v.contiguous()
             if v.dtype == torch.bool:
-
-
-                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                    v.detach().numpy(), onnx.TensorProto.BOOL
-                )
+                v = v.to(torch.uint8)
+                v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
             else:
-
+                v = ORTC.OrtValue.from_dlpack(v.detach().__dlpack__(), False)
+            input_names.append(k)
+            values.push_back(v)
         if self.nvtx:
-            self.torch.cuda.nvtx.range_push("
-
-
+            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")
+
+        # ort_outputs = self.sess._sess.run_with_ort_values(
+        #     new_feeds, output_names or self.output_names, self.run_options
+        # )
+        ort_outputs = ORTC.OrtValueVector()
+        out_names = output_names or self.output_names
+        self.sess._sess.run_with_ortvaluevector(
+            self.run_options,
+            input_names,
+            values,
+            out_names,
+            ort_outputs,
+            [DEVICES[-1 if self.cpu_outputs else device] for o in out_names],
         )
         if self.nvtx:
             self.torch.cuda.nvtx.range_pop()
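The rewrite pushes every feed through DLPack into an `OrtValueVector` and calls `run_with_ortvaluevector`, replacing the numpy copy the old boolean branch needed (bool tensors are viewed as uint8 and flagged as boolean via the second argument of `from_dlpack`). The zero-copy exchange underneath, sketched in isolation; it assumes `ORTC` aliases `onnxruntime.capi._pybind_state`, consistent with the `ORTC.OrtValue`/`ORTC.OrtValueVector` usage above, and that the C `OrtValue` exposes `to_dlpack`:

```python
import torch
from onnxruntime.capi import _pybind_state as ORTC
from torch._C import _from_dlpack  # the same private import the module itself uses

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
# the OrtValue aliases the torch storage through DLPack: no copy is made
ov = ORTC.OrtValue.from_dlpack(t.detach().__dlpack__(), False)
back = _from_dlpack(ov.to_dlpack())
assert torch.equal(t, back)
```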
@@ -530,7 +556,7 @@ def investigate_onnxruntime_issue(
         Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
     ] = None,
     # if model needs to be run.
-    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str,
+    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, TensorLike]]] = None,
     verbose: int = 0,
     dump_filename: Optional[str] = None,
     infer_shapes: bool = True,
onnx_diagnostic/helpers/torch_fx_graph_helper.py (new file)

@@ -0,0 +1,164 @@
+from typing import Any, Dict, Optional, Tuple
+import torch
+from .helper import string_type
+
+
+def validate_fx_tensor(
+    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
+) -> None:
+    """
+    Validates the shape of tensor is expected.
+
+    :param node: node
+    :param tensor: tensor
+    :param expected_shape: expected shape
+    """
+    assert len(tensor.shape) == len(expected_shape), (
+        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+    for a, b in zip(tensor.shape, expected_shape):
+        assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, (
+            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+
+
+def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
+    """
+    Validates the outputs of a node using metadata stored in the node.
+
+    :param node: node
+    :param outputs: outputs
+    """
+    if "val" not in node.meta:
+        return
+    if isinstance(outputs, torch.Tensor):
+        validate_fx_tensor(node, outputs, node.meta["val"].shape)
+        return
+    if isinstance(outputs, (tuple, list)):
+        assert isinstance(node.meta["val"], (list, tuple)), (
+            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        assert len(outputs) == len(node.meta["val"]), (
+            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        for a, b in zip(outputs, node.meta["val"]):
+            validate_fx_tensor(node, a, b.shape)
+        return
+    if isinstance(outputs, int):
+        assert (
+            isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat))
+            or outputs == node.meta["val"]
+        ), (
+            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    if outputs is None:
+        assert node.meta["val"] is None, (
+            f"None mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    raise NotImplementedError(
+        f"Validation for output type {type(outputs)} is not implemented, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+
+
+def run_fx_node(
+    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
+) -> Tuple[Any, ...]:
+    """
+    Executes a node
+
+    :param node: runs a node
+    :param args: unnamed inputs to the node
+    :param kwargs: named inputs to the node
+    :return: results
+    """
+    if node.op == "output":
+        assert len(args) == 1 and not kwargs, (
+            f"Unexpected inputs: args={string_type(args, limit=20)} "
+            f"kwargs={string_type(kwargs, limit=20)}"
+        )
+        return args
+    if node.op == "call_function":
+        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
+        for a, ea in zip(args, node.args):
+            if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta:
+                ta = ea.meta["val"]
+                assert (
+                    isinstance(ta, torch.Tensor)
+                    and len(a.shape) == len(ta.shape)
+                    and a.dtype == ta.dtype
+                ), (
+                    f"Unable to run node {node!r}, target={node.target!r}, "
+                    f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
+                    f"args={string_type(args, with_shape=True, with_device=True)}, "
+                    f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+                )
+        try:
+            outputs = node.target(*args, **(kwargs or {}))
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Unable to run node {node!r}, target={node.target!r}, "
+                f"args={string_type(args, with_shape=True, with_device=True)}, "
+                f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+            ) from e
+        validate_fx_outputs(node, outputs)
+        return outputs
+    raise NotImplementedError(
+        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
+    )
+
+
+def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
+    "See :func:`prepare_args_kwargs`."
+    if isinstance(ref, torch.fx.Node):
+        return torch_results[ref.name]
+    if isinstance(ref, list):
+        return [_pick_result(torch_results, n) for n in ref]
+    if isinstance(ref, tuple):
+        return tuple(_pick_result(torch_results, n) for n in ref)
+    if isinstance(ref, dict):
+        return {k: _pick_result(torch_results, v) for k, v in ref.items()}
+    if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
+        return ref
+    if ref is None:
+        return None
+    if isinstance(ref, torch.layout):
+        return ref
+    raise NotImplementedError(f"Unable to process args type {type(ref)}")
+
+
+def prepare_args_kwargs(
+    torch_results: Dict[str, Any], node: torch.fx.Node
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Prepares args and kwargs before executing a fx node.
+
+    :param torch_results: existing results
+    :param node: node to execute
+    :return: new args and kwargs
+    """
+    new_args = _pick_result(torch_results, node.args)
+    new_kwargs = _pick_result(torch_results, node.kwargs)
+    return new_args, new_kwargs
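Together these helpers support replaying an exported FX graph node by node (the reworked side-by-side runner in `torch_onnx/sbs.py` is a plausible consumer). A minimal sketch: placeholders are fed by hand, everything else goes through `prepare_args_kwargs` and `run_fx_node`; a parameter-free module keeps it short, since lifted parameters would also show up as placeholders:

```python
import torch

from onnx_diagnostic.helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node


class Model(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) + 1


ep = torch.export.export(Model(), (torch.randn(2, 3),))
inputs = iter([torch.randn(2, 3)])

results = {}
for node in ep.graph.nodes:
    if node.op == "placeholder":
        results[node.name] = next(inputs)  # feed graph inputs in order
        continue
    # resolve fx.Node references against what has been computed so far
    args, kwargs = prepare_args_kwargs(results, node)
    # executes the node and validates shapes/dtypes against node.meta["val"]
    results[node.name] = run_fx_node(node, args, kwargs)
```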
onnx_diagnostic/helpers/torch_helper.py

@@ -1,6 +1,7 @@
 import contextlib
 import ctypes
 import inspect
+import math
 import os
 import sys
 import warnings
@@ -30,9 +31,7 @@ from .onnx_helper import (


 def proto_from_tensor(
-    arr:
-    name: Optional[str] = None,
-    verbose: int = 0,
+    arr: torch.Tensor, name: Optional[str] = None, verbose: int = 0
 ) -> onnx.TensorProto:
     """
     Converts a torch Tensor into a TensorProto.
@@ -98,7 +97,7 @@ def proto_from_tensor(
     return tensor


-def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
     """
     Converts an onnx type into a torch dtype.

@@ -140,7 +139,16 @@ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype": # noqa: F821
     )


-def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
+_TYPENAME = dict(
+    FLOAT=onnx.TensorProto.FLOAT,
+    INT64=onnx.TensorProto.INT64,
+    INT32=onnx.TensorProto.INT32,
+    FLOAT16=onnx.TensorProto.FLOAT16,
+    BFLOAT16=onnx.TensorProto.BFLOAT16,
+)
+
+
+def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
     """
     Converts a torch dtype into a onnx element type.

@@ -183,7 +191,13 @@ def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int: # noqa: F821
         return onnx.TensorProto.COMPLEX64
     if to == torch.complex128:
         return onnx.TensorProto.COMPLEX128
-
+    # SymbolicTensor
+    sto = str(to)
+    if sto in _TYPENAME:
+        return _TYPENAME[sto]
+    raise NotImplementedError(
+        f"Unable to convert torch dtype {to!r} ({type(to)}) to onnx dtype."
+    )


 def _forward_(
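The `_TYPENAME` fallback keys on `str(to)`: a regular dtype stringifies as `torch.float32` and never matches, so the explicit branches above keep handling real dtypes, while symbolic tensors whose dtype stringifies to the bare onnx name (`FLOAT`, `INT64`, ...) now resolve instead of failing. Quick checks against the branches visible in this hunk:

```python
import onnx
import torch

from onnx_diagnostic.helpers.torch_helper import torch_dtype_to_onnx_dtype

assert torch_dtype_to_onnx_dtype(torch.complex64) == onnx.TensorProto.COMPLEX64
assert torch_dtype_to_onnx_dtype(torch.complex128) == onnx.TensorProto.COMPLEX128
```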
@@ -483,7 +497,7 @@ def is_torchdynamo_exporting() -> bool:
     return False


-def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
+def to_numpy(tensor: torch.Tensor) -> np.ndarray:
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
         return tensor.detach().cpu().numpy()
@@ -498,6 +512,21 @@ def to_numpy(tensor: "torch.Tensor") -> np.ndarray: # noqa: F821
     return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])


+def from_numpy(tensor: np.ndarray) -> torch.Tensor:
+    """Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`."""
+    try:
+        return torch.from_numpy(tensor)
+    except TypeError:
+        # We try with ml_dtypes
+        pass
+
+    import ml_dtypes
+
+    conv = {ml_dtypes.bfloat16: torch.bfloat16}
+    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
+    return torch.from_numpy(tensor.astype(torch.float32)).to(conv[tensor.dtype])
+
+
 def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
     """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
     import torch
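numpy has no native bfloat16, hence the `ml_dtypes` escape hatch in `from_numpy` (mirroring the one `to_numpy` already has). The round trip it relies on, sketched directly; it is lossless because every bfloat16 value is exactly representable in float32:

```python
import ml_dtypes
import numpy as np
import torch

t = torch.randn(4, dtype=torch.bfloat16)
a = t.to(torch.float32).numpy().astype(ml_dtypes.bfloat16)  # bfloat16 on the numpy side
back = torch.from_numpy(a.astype(np.float32)).to(torch.bfloat16)
assert torch.equal(t, back)
```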
@@ -797,7 +826,8 @@ def torch_deepcopy(value: Any) -> Any:
     if isinstance(value, tuple):
         return tuple(torch_deepcopy(v) for v in value)
     if isinstance(value, list):
-
+        if type(value) is list:
+            return [torch_deepcopy(v) for v in value]
     if isinstance(value, set):
         return {torch_deepcopy(v) for v in value}
     if isinstance(value, dict):
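The `type(value) is list` test (instead of relying on the enclosing `isinstance`) makes exact lists copy eagerly while list subclasses fall through to the later, more specific branches. A quick illustration with a made-up subclass:

```python
class TracedList(list):  # hypothetical stand-in for a specialized container
    pass


v = TracedList([1, 2])
print(isinstance(v, list), type(v) is list)  # True False: the subclass skips the fast path
```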
@@ -990,3 +1020,96 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
         counts[dt] += 1
     final = max(list(counts.items()))
     return final[0]
+
+
+def closest_factor_pair(n: int):
+    """Tries to find ``a, b`` such as ``n == a * b``."""
+    assert n > 0, f"n={n} must be a positive integer"
+    start = math.isqrt(n)
+    for a in range(start, 0, -1):
+        if n % a == 0:
+            b = n // a
+            return a, b
+    return 1, n
+
+
+def study_discrepancies(
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+    bins: int = 50,
+    figsize: Optional[Tuple[int, int]] = (15, 15),
+    title: Optional[str] = None,
+    name: Optional[str] = None,
+) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Computes different metrics for the discrepancies.
+    Returns graphs.
+
+    .. plot::
+        :include-source:
+
+        import torch
+        from onnx_diagnostic.helpers.torch_helper import study_discrepancies
+
+        t1 = torch.randn((512, 1024)) * 10
+        t2 = t1 + torch.randn((512, 1024))
+        study_discrepancies(t1, t2, title="Random noise")
+    """
+    assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
+    assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
+    d1, d2 = (
+        (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
+    )
+
+    d1 = d1.squeeze()
+    d2 = d2.squeeze()
+    if len(d1.shape) == 1:
+        new_shape = closest_factor_pair(d1.shape[0])
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+    elif len(d1.shape) > 2:
+        new_shape = (-1, max(d1.shape))
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(3, 2, figsize=figsize)
+    vmin, vmax = d1.min().item(), d1.max().item()
+    ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
+    ax[0, 0].set_title(
+        f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
+    )
+
+    diff = d2 - d1
+    vmin, vmax = diff.min().item(), diff.max().item()
+    ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
+    ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]")
+
+    ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 0].set_title("Distribution of the first tensor")
+
+    ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 1].set_title("Distribution of the differences")
+
+    tf1 = d1.ravel()
+    td1 = diff.ravel()
+    ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
+    ax[2, 1].set_title("Graph XY")
+    ax[2, 1].set_xlabel("First tensor values")
+    ax[2, 1].set_ylabel("Difference values")
+
+    if title:
+        fig.suptitle(title)
+    fig.tight_layout()
+    if name:
+        fig.savefig(name)
+    return ax
+
+
+def int_device_to_torch_device(device_id: int) -> torch.device:
+    """
+    Converts a device defined as an integer (coming from :meth:`torch.Tensor.get_device`)
+    into a ``torch.device``.
+    """
+    if device_id < 0:
+        return torch.device("cpu")
+    return torch.device("cuda", device_id)