onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/ort_session.py
@@ -1,7 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import numpy as np
-import numpy.typing as npt
 import torch
 from torch._C import _from_dlpack
 import onnxruntime
@@ -16,6 +15,7 @@ from .torch_helper import torch_dtype_to_onnx_dtype
 
 
 DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
+TensorLike = Union[np.ndarray, torch.Tensor]
 
 
 class _InferenceSession:
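
The new ``TensorLike`` alias replaces ``numpy.typing.ArrayLike`` in the signatures below. A minimal sketch of what the alias is meant to accept (the ``describe`` helper is illustrative, not part of the package)::

    from typing import Union
    import numpy as np
    import torch

    TensorLike = Union[np.ndarray, torch.Tensor]

    def describe(feed: TensorLike) -> tuple:
        # numpy arrays and torch tensors both expose .shape
        return tuple(feed.shape)

    describe(np.zeros((2, 3)))   # -> (2, 3)
    describe(torch.zeros(2, 3))  # -> (2, 3)
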
@@ -108,7 +108,10 @@ class _InferenceSession:
                 session_options,
                 providers=providers,
             )
-        except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
+        except (
+            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
+            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
+        ) as e:
             if isinstance(sess, onnx.ModelProto):
                 debug_path = "_debug_InferenceSession_last_failure.onnx"
                 onnx.save(
@@ -134,7 +137,13 @@
 
         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Input name cannot be empty but input_names={self.input_names}"
         self.output_names = [i.name for i in sess.get_outputs()]
+        assert (
+            "" not in self.output_names
+        ), f"Output name cannot be empty but output_names={self.output_names}"
         self.input_shapes = [i.shape for i in sess.get_inputs()]
         self.output_shapes = [i.shape for i in sess.get_outputs()]
         self.input_types = [i.type for i in sess.get_inputs()]
@@ -234,16 +243,16 @@ class InferenceSessionForNumpy(_InferenceSession):
         )
 
     def run(
-        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
-    ) -> List[Optional[npt.ArrayLike]]:
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> List[Optional[TensorLike]]:
         """Calls :meth:`onnxruntime.InferenceSession.run`."""
         # sess.run does not support bfloat16
         # res = self.sess.run(output_names, feeds)
         return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
 
     def run_dlpack(
-        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
-    ) -> Tuple[Optional[npt.ArrayLike], ...]:
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> Tuple[Optional[TensorLike], ...]:
         """
         Same as :meth:`onnxruntime.InferenceSession.run` except that
         feeds is a dictionary of :class:`np.ndarray`.
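
A hedged usage sketch of the updated signatures, assuming the session wraps a model the way :class:`onnxruntime.InferenceSession` does (the model path and input name are hypothetical)::

    import numpy as np
    from onnx_diagnostic.helpers.ort_session import InferenceSessionForNumpy

    sess = InferenceSessionForNumpy("model.onnx")  # hypothetical model file
    # feeds now type as Dict[str, TensorLike] instead of Dict[str, npt.ArrayLike]
    (out,) = sess.run(None, {"X": np.random.rand(2, 3).astype(np.float32)})
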
@@ -280,13 +289,13 @@ class InferenceSessionForNumpy(_InferenceSession):
     def _ortvalues_to_numpy_tensor(
         self,
         ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
-    ) -> Tuple[Optional[npt.ArrayLike], ...]:
+    ) -> Tuple[Optional[TensorLike], ...]:
         if len(ortvalues) == 0:
             return tuple()
 
         if self.nvtx:
             self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
-        res: List[Optional[npt.ArrayLike]] = []  # noqa: F823
+        res: List[Optional[TensorLike]] = []  # noqa: F823
         for i in range(len(ortvalues)):
             if not ortvalues[i].has_value():
                 res.append(None)
@@ -338,6 +347,7 @@ class InferenceSessionForTorch(_InferenceSession):
     :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
     :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
     :param use_training_api: use the onnxruntime-training API
+    :param cpu_outputs: if True, force the outputs to be on CPU
     """
 
     def __init__(
@@ -353,6 +363,7 @@
         optimized_model_filepath: Optional[str] = None,
         disable_aot_function_inlining: Optional[bool] = None,
         use_training_api: Optional[bool] = None,
+        cpu_outputs: bool = False,
     ):
         super().__init__(
             sess,
@@ -367,6 +378,7 @@
             disable_aot_function_inlining=disable_aot_function_inlining,
             use_training_api=use_training_api,
         )
+        self.cpu_outputs = cpu_outputs
 
     def _get_ortvalues_from_torch_tensors(
         self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
@@ -490,23 +502,37 @@
         feeds is a dictionary of :class:`torch.Tensor`.
         The output device is CPU even if the outputs are on CUDA.
         """
-        new_feeds = {}
+        input_names = []
+        values = ORTC.OrtValueVector()
+        device = -1
         for k, v in feeds.items():
+            assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
+            device = max(device, v.get_device())
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
                 v = v.contiguous()
             if v.dtype == torch.bool:
-                # It does not work with dlpack
-                # unless onnxruntime updates the version it is using.
-                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                    v.detach().numpy(), onnx.TensorProto.BOOL
-                )
+                v = v.to(torch.uint8)
+                v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
             else:
-                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
+                v = ORTC.OrtValue.from_dlpack(v.detach().__dlpack__(), False)
+            input_names.append(k)
+            values.push_back(v)
         if self.nvtx:
-            self.torch.cuda.nvtx.range_push("run_with_ort_values")
-        ort_outputs = self.sess._sess.run_with_ort_values(
-            new_feeds, output_names or self.output_names, self.run_options
+            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")
+
+        # ort_outputs = self.sess._sess.run_with_ort_values(
+        #     new_feeds, output_names or self.output_names, self.run_options
+        # )
+        ort_outputs = ORTC.OrtValueVector()
+        out_names = output_names or self.output_names
+        self.sess._sess.run_with_ortvaluevector(
+            self.run_options,
+            input_names,
+            values,
+            out_names,
+            ort_outputs,
+            [DEVICES[-1 if self.cpu_outputs else device] for o in out_names],
         )
         if self.nvtx:
             self.torch.cuda.nvtx.range_pop()
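
The rewritten loop drops the numpy round trip for boolean inputs: DLPack exchange of ``torch.bool`` tensors is not supported by the onnxruntime build targeted here, so the tensor is viewed as ``uint8`` and flagged as boolean on the ORT side. A standalone sketch of that workaround, assuming ``ORTC`` binds ``onnxruntime.capi._pybind_state`` as the file's other calls suggest::

    import torch
    from onnxruntime.capi import _pybind_state as ORTC

    t = torch.tensor([True, False, True])
    # view the bool tensor as uint8 before the DLPack exchange;
    # the second argument tells ORT the value is really boolean
    v = ORTC.OrtValue.from_dlpack(t.to(torch.uint8).__dlpack__(), True)
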
@@ -530,7 +556,7 @@ def investigate_onnxruntime_issue(
         Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
     ] = None,
     # if model needs to be run.
-    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, npt.ArrayLike]]] = None,
+    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, TensorLike]]] = None,
     verbose: int = 0,
     dump_filename: Optional[str] = None,
     infer_shapes: bool = True,
onnx_diagnostic/helpers/torch_fx_graph_helper.py
@@ -0,0 +1,164 @@
+from typing import Any, Dict, Optional, Tuple
+import torch
+from .helper import string_type
+
+
+def validate_fx_tensor(
+    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
+) -> None:
+    """
+    Validates that a tensor's shape matches the expected shape.
+
+    :param node: node
+    :param tensor: tensor
+    :param expected_shape: expected shape
+    """
+    assert len(tensor.shape) == len(expected_shape), (
+        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+    for a, b in zip(tensor.shape, expected_shape):
+        assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, (
+            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+
+
+def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
+    """
+    Validates the outputs of a node using metadata stored in the node.
+
+    :param node: node
+    :param outputs: outputs
+    """
+    if "val" not in node.meta:
+        return
+    if isinstance(outputs, torch.Tensor):
+        validate_fx_tensor(node, outputs, node.meta["val"].shape)
+        return
+    if isinstance(outputs, (tuple, list)):
+        assert isinstance(node.meta["val"], (list, tuple)), (
+            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        assert len(outputs) == len(node.meta["val"]), (
+            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        for a, b in zip(outputs, node.meta["val"]):
+            validate_fx_tensor(node, a, b.shape)
+        return
+    if isinstance(outputs, int):
+        assert (
+            isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat))
+            or outputs == node.meta["val"]
+        ), (
+            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    if outputs is None:
+        assert node.meta["val"] is None, (
+            f"None mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    raise NotImplementedError(
+        f"Validation for output type {type(outputs)} is not implemented, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+
+
+def run_fx_node(
+    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
+) -> Tuple[Any, ...]:
+    """
+    Executes a node.
+
+    :param node: node to run
+    :param args: unnamed inputs to the node
+    :param kwargs: named inputs to the node
+    :return: results
+    """
+    if node.op == "output":
+        assert len(args) == 1 and not kwargs, (
+            f"Unexpected inputs: args={string_type(args, limit=20)} "
+            f"kwargs={string_type(kwargs, limit=20)}"
+        )
+        return args
+    if node.op == "call_function":
+        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
+        for a, ea in zip(args, node.args):
+            if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta:
+                ta = ea.meta["val"]
+                assert (
+                    isinstance(ta, torch.Tensor)
+                    and len(a.shape) == len(ta.shape)
+                    and a.dtype == ta.dtype
+                ), (
+                    f"Unable to run node {node!r}, target={node.target!r}, "
+                    f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
+                    f"args={string_type(args, with_shape=True, with_device=True)}, "
+                    f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+                )
+        try:
+            outputs = node.target(*args, **(kwargs or {}))
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Unable to run node {node!r}, target={node.target!r}, "
+                f"args={string_type(args, with_shape=True, with_device=True)}, "
+                f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+            ) from e
+        validate_fx_outputs(node, outputs)
+        return outputs
+    raise NotImplementedError(
+        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
+    )
+
+
+def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
+    "See :func:`prepare_args_kwargs`."
+    if isinstance(ref, torch.fx.Node):
+        return torch_results[ref.name]
+    if isinstance(ref, list):
+        return [_pick_result(torch_results, n) for n in ref]
+    if isinstance(ref, tuple):
+        return tuple(_pick_result(torch_results, n) for n in ref)
+    if isinstance(ref, dict):
+        return {k: _pick_result(torch_results, v) for k, v in ref.items()}
+    if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
+        return ref
+    if ref is None:
+        return None
+    if isinstance(ref, torch.layout):
+        return ref
+    raise NotImplementedError(f"Unable to process args type {type(ref)}")
+
+
+def prepare_args_kwargs(
+    torch_results: Dict[str, Any], node: torch.fx.Node
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Prepares args and kwargs before executing an fx node.
+
+    :param torch_results: existing results
+    :param node: node to execute
+    :return: new args and kwargs
+    """
+    new_args = _pick_result(torch_results, node.args)
+    new_kwargs = _pick_result(torch_results, node.kwargs)
+    return new_args, new_kwargs
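
A minimal sketch of how these helpers compose to replay an exported graph node by node (the toy model and the way inputs are fed are illustrative)::

    import torch
    from onnx_diagnostic.helpers.torch_fx_graph_helper import (
        prepare_args_kwargs,
        run_fx_node,
    )

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    ep = torch.export.export(Model(), (torch.randn(2, 3),))
    results = {}
    for node in ep.graph.nodes:
        if node.op == "placeholder":
            results[node.name] = torch.randn(2, 3)  # feed an input
            continue
        args, kwargs = prepare_args_kwargs(results, node)
        # run_fx_node also validates outputs against node.meta["val"]
        results[node.name] = run_fx_node(node, args, kwargs)
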
onnx_diagnostic/helpers/torch_helper.py
@@ -1,6 +1,7 @@
 import contextlib
 import ctypes
 import inspect
+import math
 import os
 import sys
 import warnings
@@ -30,9 +31,7 @@ from .onnx_helper import (
 
 
 def proto_from_tensor(
-    arr: "torch.Tensor",  # noqa: F821
-    name: Optional[str] = None,
-    verbose: int = 0,
+    arr: torch.Tensor, name: Optional[str] = None, verbose: int = 0
 ) -> onnx.TensorProto:
     """
     Converts a torch Tensor into a TensorProto.
@@ -98,7 +97,7 @@
     return tensor
 
 
-def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
     """
     Converts an onnx type into a torch dtype.
 
@@ -140,7 +139,16 @@
     )
 
 
-def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
+_TYPENAME = dict(
+    FLOAT=onnx.TensorProto.FLOAT,
+    INT64=onnx.TensorProto.INT64,
+    INT32=onnx.TensorProto.INT32,
+    FLOAT16=onnx.TensorProto.FLOAT16,
+    BFLOAT16=onnx.TensorProto.BFLOAT16,
+)
+
+
+def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
     """
     Converts a torch dtype into an onnx element type.
 
@@ -183,7 +191,13 @@
         return onnx.TensorProto.COMPLEX64
     if to == torch.complex128:
         return onnx.TensorProto.COMPLEX128
-    raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
+    # SymbolicTensor
+    sto = str(to)
+    if sto in _TYPENAME:
+        return _TYPENAME[sto]
+    raise NotImplementedError(
+        f"Unable to convert torch dtype {to!r} ({type(to)}) to onnx dtype."
+    )
 
 
 def _forward_(
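
The new tail keys on ``str(to)``, so objects that are not real ``torch.dtype`` instances (the ``# SymbolicTensor`` comment suggests symbolic tensor dtypes) still resolve when their string form matches a ``_TYPENAME`` key. A hedged illustration with a hypothetical stand-in object, assuming the earlier equality checks in the function simply compare unequal for such objects::

    import onnx
    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_dtype_to_onnx_dtype

    assert torch_dtype_to_onnx_dtype(torch.float16) == onnx.TensorProto.FLOAT16

    class FakeSymbolicDtype:  # hypothetical stand-in for a symbolic dtype
        def __str__(self):
            return "FLOAT16"

    assert torch_dtype_to_onnx_dtype(FakeSymbolicDtype()) == onnx.TensorProto.FLOAT16
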
@@ -483,7 +497,7 @@ def is_torchdynamo_exporting() -> bool:
     return False
 
 
-def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
+def to_numpy(tensor: torch.Tensor) -> np.ndarray:
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
         return tensor.detach().cpu().numpy()
@@ -498,6 +512,21 @@
     return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
 
 
+def from_numpy(tensor: np.ndarray) -> torch.Tensor:
+    """Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`."""
+    try:
+        return torch.from_numpy(tensor)
+    except TypeError:
+        # We try with ml_dtypes.
+        pass
+
+    import ml_dtypes
+
+    conv = {ml_dtypes.bfloat16: torch.bfloat16}
+    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
+    return torch.from_numpy(tensor.astype(np.float32)).to(conv[tensor.dtype])
+
+
 def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
     """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
     import torch
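
``from_numpy`` mirrors ``to_numpy``: the common path goes straight through :func:`torch.from_numpy`, and dtypes numpy only knows via ``ml_dtypes`` take the fallback branch. A short sketch of the common path::

    import numpy as np
    import torch
    from onnx_diagnostic.helpers.torch_helper import from_numpy

    t = from_numpy(np.arange(6, dtype=np.float32).reshape(2, 3))
    assert t.dtype == torch.float32
    # an array with an ml_dtypes.bfloat16 dtype would instead be upcast to
    # float32 and converted back to torch.bfloat16 (requires ml_dtypes)
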
@@ -797,7 +826,8 @@ def torch_deepcopy(value: Any) -> Any:
     if isinstance(value, tuple):
         return tuple(torch_deepcopy(v) for v in value)
     if isinstance(value, list):
-        return [torch_deepcopy(v) for v in value]
+        if type(value) is list:
+            return [torch_deepcopy(v) for v in value]
     if isinstance(value, set):
         return {torch_deepcopy(v) for v in value}
     if isinstance(value, dict):
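
The ``type(value) is list`` guard narrows the copy to exact lists, so list *subclasses* no longer take this branch: they fall through to the later, type-specific cases instead of coming back as plain lists. A sketch of the distinction (``TaggedList`` is hypothetical)::

    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    class TaggedList(list):  # hypothetical list subclass
        pass

    assert type(torch_deepcopy([1, 2])) is list
    # a TaggedList instance skips the plain-list branch, so it is either
    # rebuilt by a later case that knows the subclass or fails loudly
    # instead of being silently downgraded to list
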
@@ -990,3 +1020,96 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
         counts[dt] += 1
     final = max(list(counts.items()))
     return final[0]
+
+
+def closest_factor_pair(n: int):
+    """Tries to find ``a, b`` such that ``n == a * b``."""
+    assert n > 0, f"n={n} must be a positive integer"
+    start = math.isqrt(n)
+    for a in range(start, 0, -1):
+        if n % a == 0:
+            b = n // a
+            return a, b
+    return 1, n
+
+
+def study_discrepancies(
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+    bins: int = 50,
+    figsize: Optional[Tuple[int, int]] = (15, 15),
+    title: Optional[str] = None,
+    name: Optional[str] = None,
+) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Computes different metrics for the discrepancies.
+    Returns graphs.
+
+    .. plot::
+        :include-source:
+
+        import torch
+        from onnx_diagnostic.helpers.torch_helper import study_discrepancies
+
+        t1 = torch.randn((512, 1024)) * 10
+        t2 = t1 + torch.randn((512, 1024))
+        study_discrepancies(t1, t2, title="Random noise")
+    """
+    assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
+    assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
+    d1, d2 = (
+        (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
+    )
+
+    d1 = d1.squeeze()
+    d2 = d2.squeeze()
+    if len(d1.shape) == 1:
+        new_shape = closest_factor_pair(d1.shape[0])
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+    elif len(d1.shape) > 2:
+        new_shape = (-1, max(d1.shape))
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(3, 2, figsize=figsize)
+    vmin, vmax = d1.min().item(), d1.max().item()
+    ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
+    ax[0, 0].set_title(
+        f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
+    )
+
+    diff = d2 - d1
+    vmin, vmax = diff.min().item(), diff.max().item()
+    ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
+    ax[0, 1].set_title(f"Color plot of the differences in\n[{vmin}, {vmax}]")
+
+    ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 0].set_title("Distribution of the first tensor")
+
+    ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 1].set_title("Distribution of the differences")
+
+    tf1 = d1.ravel()
+    td1 = diff.ravel()
+    ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
+    ax[2, 1].set_title("Graph XY")
+    ax[2, 1].set_xlabel("First tensor values")
+    ax[2, 1].set_ylabel("Difference values")
+
+    if title:
+        fig.suptitle(title)
+    fig.tight_layout()
+    if name:
+        fig.savefig(name)
+    return ax
+
+
+def int_device_to_torch_device(device_id: int) -> torch.device:
+    """
+    Converts a device given as an integer (as returned by
+    :meth:`torch.Tensor.get_device`) into a ``torch.device``.
+    """
+    if device_id < 0:
+        return torch.device("cpu")
+    return torch.device("cuda", device_id)
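
A few worked examples of the two small helpers added at the end of the file::

    import torch
    from onnx_diagnostic.helpers.torch_helper import (
        closest_factor_pair,
        int_device_to_torch_device,
    )

    assert closest_factor_pair(12) == (3, 4)   # isqrt(12) == 3 and 12 % 3 == 0
    assert closest_factor_pair(13) == (1, 13)  # a == 1 always divides, so primes give (1, n)
    assert int_device_to_torch_device(-1) == torch.device("cpu")
    assert int_device_to_torch_device(0) == torch.device("cuda", 0)
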