onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,164 @@
1
+ from typing import Any, Dict, Optional, Tuple
2
+ import torch
3
+ from .helper import string_type
4
+
5
+
6
def validate_fx_tensor(
    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
) -> None:
    """
    Validates the shape of tensor is expected.

    :param node: node, only used to build informative error messages
    :param tensor: tensor whose shape is checked
    :param expected_shape: expected shape, may contain symbolic (non ``int``)
        dimensions which are not compared
    """
    assert len(tensor.shape) == len(expected_shape), (
        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )
    for got, expected in zip(tensor.shape, expected_shape):
        if not isinstance(expected, int):
            # Symbolic dimension: nothing concrete to compare against.
            continue
        # A 0/1 pair is tolerated (broadcast-like degenerate dimensions).
        assert got == expected or {got, expected} == {0, 1}, (
            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
29
+
30
+
31
def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
    """
    Validates the outputs of a node using metadata stored in the node.

    Dispatches on the runtime type of *outputs* (tensor, sequence, int,
    ``None``) and compares it against ``node.meta["val"]`` when present.

    :param node: node carrying the expected values in ``node.meta["val"]``
    :param outputs: outputs to validate
    :raises NotImplementedError: for output types with no validation rule
    """
    if "val" not in node.meta:
        # No recorded expectation: nothing to validate.
        return
    expected = node.meta["val"]
    if isinstance(outputs, torch.Tensor):
        validate_fx_tensor(node, outputs, expected.shape)
        return
    if isinstance(outputs, (tuple, list)):
        assert isinstance(expected, (list, tuple)), (
            f"Unexpected type {string_type(expected)} for node.meta['val'], "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        assert len(outputs) == len(expected), (
            f"Length mismatch, got {len(outputs)} expected {len(expected)}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        for got, meta_val in zip(outputs, expected):
            validate_fx_tensor(node, got, meta_val.shape)
        return
    if isinstance(outputs, int):
        # Symbolic expectations (SymInt/SymBool/SymFloat) are accepted as-is.
        symbolic = isinstance(expected, (torch.SymInt, torch.SymBool, torch.SymFloat))
        assert symbolic or outputs == expected, (
            f"Int mismatch, got {outputs} expected {expected}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        return
    if outputs is None:
        assert expected is None, (
            f"None mismatch, got {outputs} expected {expected}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        return
    raise NotImplementedError(
        f"Validation for output type {type(outputs)} is not implemented, "
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )
84
+
85
+
86
def run_fx_node(
    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> Tuple[Any, ...]:
    """
    Executes a node

    Only ``output`` and ``call_function`` nodes are supported.

    :param node: runs a node
    :param args: unnamed inputs to the node
    :param kwargs: named inputs to the node
    :return: results
    :raises NotImplementedError: for any other ``node.op``
    :raises RuntimeError: when the call itself fails, with input summaries
    """
    if node.op == "output":
        assert len(args) == 1 and not kwargs, (
            f"Unexpected inputs: args={string_type(args, limit=20)} "
            f"kwargs={string_type(kwargs, limit=20)}"
        )
        return args
    if node.op != "call_function":
        raise NotImplementedError(
            f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
        )
    assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
    # Sanity check: tensor arguments must match rank and dtype recorded
    # in the corresponding fx argument metadata, when available.
    for value, expected_arg in zip(args, node.args):
        if (
            not isinstance(value, torch.Tensor)
            or not hasattr(expected_arg, "meta")
            or "val" not in expected_arg.meta
        ):
            continue
        meta_val = expected_arg.meta["val"]
        assert (
            isinstance(meta_val, torch.Tensor)
            and len(value.shape) == len(meta_val.shape)
            and value.dtype == meta_val.dtype
        ), (
            f"Unable to run node {node!r}, target={node.target!r}, "
            f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
            f"args={string_type(args, with_shape=True, with_device=True)}, "
            f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
        )
    try:
        outputs = node.target(*args, **(kwargs or {}))
    except RuntimeError as e:
        # Re-raise with a summary of the inputs to ease debugging.
        raise RuntimeError(
            f"Unable to run node {node!r}, target={node.target!r}, "
            f"args={string_type(args, with_shape=True, with_device=True)}, "
            f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
        ) from e
    validate_fx_outputs(node, outputs)
    return outputs
131
+
132
+
133
+ def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
134
+ "See :func:`prepare_args_kwargs`."
135
+ if isinstance(ref, torch.fx.Node):
136
+ return torch_results[ref.name]
137
+ if isinstance(ref, list):
138
+ return [_pick_result(torch_results, n) for n in ref]
139
+ if isinstance(ref, tuple):
140
+ return tuple(_pick_result(torch_results, n) for n in ref)
141
+ if isinstance(ref, dict):
142
+ return {k: _pick_result(torch_results, v) for k, v in ref.items()}
143
+ if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
144
+ return ref
145
+ if ref is None:
146
+ return None
147
+ if isinstance(ref, torch.layout):
148
+ return ref
149
+ raise NotImplementedError(f"Unable to process args type {type(ref)}")
150
+
151
+
152
def prepare_args_kwargs(
    torch_results: Dict[str, Any], node: torch.fx.Node
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Prepares args and kwargs before executing a fx node.

    Every :class:`torch.fx.Node` reference found in ``node.args`` or
    ``node.kwargs`` is replaced by its computed value from *torch_results*.

    :param torch_results: existing results
    :param node: node to execute
    :return: new args and kwargs
    """
    return (
        _pick_result(torch_results, node.args),
        _pick_result(torch_results, node.kwargs),
    )
@@ -1,6 +1,7 @@
1
1
  import contextlib
2
2
  import ctypes
3
3
  import inspect
4
+ import math
4
5
  import os
5
6
  import sys
6
7
  import warnings
@@ -30,9 +31,7 @@ from .onnx_helper import (
30
31
 
31
32
 
32
33
  def proto_from_tensor(
33
- arr: "torch.Tensor", # noqa: F821
34
- name: Optional[str] = None,
35
- verbose: int = 0,
34
+ arr: torch.Tensor, name: Optional[str] = None, verbose: int = 0
36
35
  ) -> onnx.TensorProto:
37
36
  """
38
37
  Converts a torch Tensor into a TensorProto.
@@ -98,7 +97,7 @@ def proto_from_tensor(
98
97
  return tensor
99
98
 
100
99
 
101
- def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype": # noqa: F821
100
+ def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
102
101
  """
103
102
  Converts an onnx type into a torch dtype.
104
103
 
@@ -140,7 +139,7 @@ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype": # noqa: F821
140
139
  )
141
140
 
142
141
 
143
- def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int: # noqa: F821
142
+ def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
144
143
  """
145
144
  Converts a torch dtype into a onnx element type.
146
145
 
@@ -450,6 +449,11 @@ def fake_torchdynamo_exporting():
450
449
  """
451
450
  memorize = torch.compiler._is_exporting_flag
452
451
  torch.compiler._is_exporting_flag = True
452
+ assert torch.compiler.is_exporting(), (
453
+ f"Changes not detected "
454
+ f"torch.compiler._is_exporting_flag={torch.compiler._is_exporting_flag} "
455
+ f"and torch.compiler.is_exporting()={torch.compiler.is_exporting()}"
456
+ )
453
457
  try:
454
458
  yield
455
459
  finally:
@@ -478,7 +482,7 @@ def is_torchdynamo_exporting() -> bool:
478
482
  return False
479
483
 
480
484
 
481
- def to_numpy(tensor: "torch.Tensor") -> np.ndarray: # noqa: F821
485
+ def to_numpy(tensor: torch.Tensor) -> np.ndarray:
482
486
  """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
483
487
  try:
484
488
  return tensor.detach().cpu().numpy()
@@ -493,6 +497,21 @@ def to_numpy(tensor: "torch.Tensor") -> np.ndarray: # noqa: F821
493
497
  return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
494
498
 
495
499
 
500
def from_numpy(tensor: np.ndarray) -> torch.Tensor:
    """
    Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`.

    Dtypes natively supported by torch go straight through
    :func:`torch.from_numpy`. Arrays using an :mod:`ml_dtypes` dtype
    (currently only ``bfloat16``) are upcast to ``float32`` first and then
    converted back to the matching torch dtype.

    :param tensor: array to convert
    :return: tensor (shares memory with the input on the fast path)
    :raises AssertionError: if the dtype is not supported by torch nor listed
        in the ml_dtypes conversion table
    """
    try:
        return torch.from_numpy(tensor)
    except TypeError:
        # torch does not understand this dtype, we try with ml_dtypes.
        pass

    import ml_dtypes

    conv = {ml_dtypes.bfloat16: torch.bfloat16}
    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
    # BUG FIX: ``np.ndarray.astype`` needs a numpy dtype; the previous code
    # passed ``torch.float32`` which made this fallback raise TypeError.
    return torch.from_numpy(tensor.astype(np.float32)).to(conv[tensor.dtype])
513
+
514
+
496
515
  def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
497
516
  """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
498
517
  import torch
@@ -985,3 +1004,86 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
985
1004
  counts[dt] += 1
986
1005
  final = max(list(counts.items()))
987
1006
  return final[0]
1007
+
1008
+
1009
def closest_factor_pair(n: int):
    """Tries to find ``a, b`` such as ``n == a * b``."""
    assert n > 0, f"n={n} must be a positive integer"
    # Walk down from the integer square root: the first divisor found
    # yields the pair closest to a square.
    candidate = math.isqrt(n)
    while candidate > 0:
        if n % candidate == 0:
            return candidate, n // candidate
        candidate -= 1
    # Unreachable for n > 0 since candidate == 1 always divides n.
    return 1, n
1018
+
1019
+
1020
def study_discrepancies(
    t1: torch.Tensor,
    t2: torch.Tensor,
    bins: int = 50,
    figsize: Optional[Tuple[int, int]] = (15, 15),
    title: Optional[str] = None,
    name: Optional[str] = None,
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Computes different metrics for the discrepancies.
    Returns graphs.

    :param t1: first tensor (must have the same dtype and shape as *t2*)
    :param t2: second tensor, compared against *t1*
    :param bins: number of bins for the histograms
    :param figsize: figure size forwarded to ``plt.subplots``
    :param title: optional figure title
    :param name: optional file name; when given the figure is saved there
    :return: the 3x2 array of matplotlib axes

    .. plot::
        :include-source:

        import torch
        from onnx_diagnostic.helpers.torch_helper import study_discrepancies

        t1 = torch.randn((512, 1024)) * 10
        t2 = t1 + torch.randn((512, 1024))
        study_discrepancies(t1, t2, title="Random noise")
    """
    assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
    assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
    # float64 is kept as is, every other dtype is upcast to float32 so that
    # min/max/hist computations work for half/bfloat16 inputs.
    d1, d2 = (
        (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
    )

    # Reshape both tensors to a 2D layout so imshow can display them.
    d1 = d1.squeeze()
    d2 = d2.squeeze()
    if len(d1.shape) == 1:
        # 1D: fold into the most square 2D grid possible.
        new_shape = closest_factor_pair(d1.shape[0])
        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
    elif len(d1.shape) > 2:
        # NOTE(review): assumes numel is divisible by the largest dimension;
        # reshape raises otherwise — confirm for odd shapes.
        new_shape = (-1, max(d1.shape))
        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)

    import matplotlib.pyplot as plt

    # Layout: [0,0] first tensor, [0,1] differences, [1,*] histograms,
    # [2,1] XY scatter. ax[2,0] is intentionally left empty.
    fig, ax = plt.subplots(3, 2, figsize=figsize)
    vmin, vmax = d1.min().item(), d1.max().item()
    ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
    ax[0, 0].set_title(
        f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
    )

    # Signed differences on a diverging colormap centered visually on 0.
    diff = d2 - d1
    vmin, vmax = diff.min().item(), diff.max().item()
    ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
    ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]")

    ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
    ax[1, 0].set_title("Distribution of the first tensor")

    ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
    ax[1, 1].set_title("Distribution of the differences")

    # Scatter: difference as a function of the first tensor's value,
    # useful to spot value-dependent error patterns.
    tf1 = d1.ravel()
    td1 = diff.ravel()
    ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
    ax[2, 1].set_title("Graph XY")
    ax[2, 1].set_xlabel("First tensor values")
    ax[2, 1].set_ylabel("Difference values")

    if title:
        fig.suptitle(title)
    fig.tight_layout()
    if name:
        fig.savefig(name)
    return ax