onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl

Files changed (43)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +91 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +3 -3
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +92 -23
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +90 -26
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +103 -1
  15. onnx_diagnostic/helpers/ort_session.py +37 -11
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +103 -6
  18. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/validate.py +50 -1
  37. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  38. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  39. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -108,7 +108,10 @@ class _InferenceSession:
                 session_options,
                 providers=providers,
             )
-        except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
+        except (
+            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
+            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
+        ) as e:
             if isinstance(sess, onnx.ModelProto):
                 debug_path = "_debug_InferenceSession_last_failure.onnx"
                 onnx.save(
@@ -134,7 +137,13 @@ class _InferenceSession:
 
         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Input name cannot be empty but input_names={self.input_names}"
         self.output_names = [i.name for i in sess.get_outputs()]
+        assert (
+            "" not in self.output_names
+        ), f"Output name cannot be empty but output_names={self.output_names}"
         self.input_shapes = [i.shape for i in sess.get_inputs()]
         self.output_shapes = [i.shape for i in sess.get_outputs()]
         self.input_types = [i.type for i in sess.get_inputs()]
@@ -338,6 +347,7 @@ class InferenceSessionForTorch(_InferenceSession):
     :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
     :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
     :param use_training_api: use onnxruntime-training API
+    :param cpu_outputs: if True, force the outputs to be on CPU
     """
 
     def __init__(
@@ -353,6 +363,7 @@ class InferenceSessionForTorch(_InferenceSession):
         optimized_model_filepath: Optional[str] = None,
         disable_aot_function_inlining: Optional[bool] = None,
         use_training_api: Optional[bool] = None,
+        cpu_outputs: bool = False,
     ):
         super().__init__(
             sess,
@@ -367,6 +378,7 @@ class InferenceSessionForTorch(_InferenceSession):
             disable_aot_function_inlining=disable_aot_function_inlining,
             use_training_api=use_training_api,
         )
+        self.cpu_outputs = cpu_outputs
 
     def _get_ortvalues_from_torch_tensors(
         self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
@@ -490,23 +502,37 @@ class InferenceSessionForTorch(_InferenceSession):
         feeds is a dictionary of :class:`torch.Tensor`.
         The output device is CPU even if the outputs are on CUDA.
         """
-        new_feeds = {}
+        input_names = []
+        values = ORTC.OrtValueVector()
+        device = -1
         for k, v in feeds.items():
+            assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
+            device = max(device, v.get_device())
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
                 v = v.contiguous()
             if v.dtype == torch.bool:
-                # It does not work with dlpack
-                # unless onnxruntime updates the version it is using.
-                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                    v.detach().numpy(), onnx.TensorProto.BOOL
-                )
+                v = v.to(torch.uint8)
+                v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
             else:
-                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
+                v = ORTC.OrtValue.from_dlpack(v.detach().__dlpack__(), False)
+            input_names.append(k)
+            values.push_back(v)
         if self.nvtx:
-            self.torch.cuda.nvtx.range_push("run_with_ort_values")
-        ort_outputs = self.sess._sess.run_with_ort_values(
-            new_feeds, output_names or self.output_names, self.run_options
+            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")
+
+        # ort_outputs = self.sess._sess.run_with_ort_values(
+        #     new_feeds, output_names or self.output_names, self.run_options
+        # )
+        ort_outputs = ORTC.OrtValueVector()
+        out_names = output_names or self.output_names
+        self.sess._sess.run_with_ortvaluevector(
+            self.run_options,
+            input_names,
+            values,
+            out_names,
+            ort_outputs,
+            [DEVICES[-1 if self.cpu_outputs else device] for o in out_names],
         )
         if self.nvtx:
             self.torch.cuda.nvtx.range_pop()
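For context, the new code path binds each output to an explicit device: with cpu_outputs=True every output is bound to CPU, otherwise outputs follow the highest device among the inputs. A minimal sketch of how a caller might use the flag; the model path and the run call below are illustrative only, not taken from this diff:

    import torch
    from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch

    # Hypothetical model file; the class also accepts an onnx.ModelProto.
    wrapped = InferenceSessionForTorch("model.onnx", cpu_outputs=True)
    feeds = {"x": torch.randn(4, 8, device="cuda")}
    # Inputs stay on CUDA; with cpu_outputs=True the outputs come back on CPU,
    # with the default cpu_outputs=False they stay on the inputs' device.
    outputs = wrapped.run(None, feeds)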
@@ -0,0 +1,164 @@
+from typing import Any, Dict, Optional, Tuple
+import torch
+from .helper import string_type
+
+
+def validate_fx_tensor(
+    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
+) -> None:
+    """
+    Validates that the shape of a tensor is the expected one.
+
+    :param node: node
+    :param tensor: tensor
+    :param expected_shape: expected shape
+    """
+    assert len(tensor.shape) == len(expected_shape), (
+        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+    for a, b in zip(tensor.shape, expected_shape):
+        assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, (
+            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+
+
+def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
+    """
+    Validates the outputs of a node using the metadata stored in the node.
+
+    :param node: node
+    :param outputs: outputs
+    """
+    if "val" not in node.meta:
+        return
+    if isinstance(outputs, torch.Tensor):
+        validate_fx_tensor(node, outputs, node.meta["val"].shape)
+        return
+    if isinstance(outputs, (tuple, list)):
+        assert isinstance(node.meta["val"], (list, tuple)), (
+            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        assert len(outputs) == len(node.meta["val"]), (
+            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        for a, b in zip(outputs, node.meta["val"]):
+            validate_fx_tensor(node, a, b.shape)
+        return
+    if isinstance(outputs, int):
+        assert (
+            isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat))
+            or outputs == node.meta["val"]
+        ), (
+            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    if outputs is None:
+        assert node.meta["val"] is None, (
+            f"None mismatch, got {outputs} expected {node.meta['val']}, "
+            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+            f"node.args={node.args}, node.kwargs={node.kwargs}, "
+            f"node.meta={node.meta}"
+        )
+        return
+    raise NotImplementedError(
+        f"Validation for output type {type(outputs)} is not implemented, "
+        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
+        f"node.args={node.args}, node.kwargs={node.kwargs}, "
+        f"node.meta={node.meta}"
+    )
+
+
+def run_fx_node(
+    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
+) -> Tuple[Any, ...]:
+    """
+    Executes a node.
+
+    :param node: node to run
+    :param args: unnamed inputs to the node
+    :param kwargs: named inputs to the node
+    :return: results
+    """
+    if node.op == "output":
+        assert len(args) == 1 and not kwargs, (
+            f"Unexpected inputs: args={string_type(args, limit=20)} "
+            f"kwargs={string_type(kwargs, limit=20)}"
+        )
+        return args
+    if node.op == "call_function":
+        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
+        for a, ea in zip(args, node.args):
+            if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta:
+                ta = ea.meta["val"]
+                assert (
+                    isinstance(ta, torch.Tensor)
+                    and len(a.shape) == len(ta.shape)
+                    and a.dtype == ta.dtype
+                ), (
+                    f"Unable to run node {node!r}, target={node.target!r}, "
+                    f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, "
+                    f"args={string_type(args, with_shape=True, with_device=True)}, "
+                    f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+                )
+        try:
+            outputs = node.target(*args, **(kwargs or {}))
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Unable to run node {node!r}, target={node.target!r}, "
+                f"args={string_type(args, with_shape=True, with_device=True)}, "
+                f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}"
+            ) from e
+        validate_fx_outputs(node, outputs)
+        return outputs
+    raise NotImplementedError(
+        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
+    )
+
+
+def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
+    "See :func:`prepare_args_kwargs`."
+    if isinstance(ref, torch.fx.Node):
+        return torch_results[ref.name]
+    if isinstance(ref, list):
+        return [_pick_result(torch_results, n) for n in ref]
+    if isinstance(ref, tuple):
+        return tuple(_pick_result(torch_results, n) for n in ref)
+    if isinstance(ref, dict):
+        return {k: _pick_result(torch_results, v) for k, v in ref.items()}
+    if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
+        return ref
+    if ref is None:
+        return None
+    if isinstance(ref, torch.layout):
+        return ref
+    raise NotImplementedError(f"Unable to process args type {type(ref)}")
+
+
+def prepare_args_kwargs(
+    torch_results: Dict[str, Any], node: torch.fx.Node
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Prepares args and kwargs before executing an fx node.
+
+    :param torch_results: existing results
+    :param node: node to execute
+    :return: new args and kwargs
+    """
+    new_args = _pick_result(torch_results, node.args)
+    new_kwargs = _pick_result(torch_results, node.kwargs)
+    return new_args, new_kwargs
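Together these helpers form a small FX interpreter: prepare_args_kwargs resolves a node's arguments from already computed results, and run_fx_node executes and validates the node. A hedged sketch of a driver loop built on top of them; run_graph is illustrative and not part of the package (run_fx_node only implements call_function and output, so placeholders are resolved by hand and a model with parameters would need a get_attr branch):

    import torch
    from onnx_diagnostic.helpers.torch_fx_graph_helper import (
        prepare_args_kwargs,
        run_fx_node,
    )

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    def run_graph(gm: torch.fx.GraphModule, *inputs):
        results, it = {}, iter(inputs)
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                # Placeholders take the provided inputs in order.
                results[node.name] = next(it)
            elif node.op == "output":
                args, kwargs = prepare_args_kwargs(results, node)
                # For an output node, run_fx_node returns its single
                # argument: the graph's output tuple.
                return run_fx_node(node, args, kwargs)[0]
            else:
                args, kwargs = prepare_args_kwargs(results, node)
                results[node.name] = run_fx_node(node, args, kwargs)

    ep = torch.export.export(Tiny(), (torch.randn(2, 3),))
    print(run_graph(ep.module(), torch.randn(2, 3)))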
@@ -1,6 +1,7 @@
 import contextlib
 import ctypes
 import inspect
+import math
 import os
 import sys
 import warnings
@@ -30,9 +31,7 @@ from .onnx_helper import (
 
 
 def proto_from_tensor(
-    arr: "torch.Tensor",  # noqa: F821
-    name: Optional[str] = None,
-    verbose: int = 0,
+    arr: torch.Tensor, name: Optional[str] = None, verbose: int = 0
 ) -> onnx.TensorProto:
     """
     Converts a torch Tensor into a TensorProto.
@@ -98,7 +97,7 @@ def proto_from_tensor(
     return tensor
 
 
-def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
    """
    Converts an onnx type into a torch dtype.
 
@@ -140,7 +139,7 @@ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
     )
 
 
-def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
+def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
     """
     Converts a torch dtype into an onnx element type.
 
@@ -483,7 +482,7 @@ def is_torchdynamo_exporting() -> bool:
     return False
 
 
-def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
+def to_numpy(tensor: torch.Tensor) -> np.ndarray:
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
         return tensor.detach().cpu().numpy()
@@ -498,6 +497,21 @@ def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
     return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
 
 
+def from_numpy(tensor: np.ndarray) -> torch.Tensor:
+    """Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`."""
+    try:
+        return torch.from_numpy(tensor)
+    except TypeError:
+        # We try with ml_dtypes
+        pass
+
+    import ml_dtypes
+
+    conv = {ml_dtypes.bfloat16: torch.bfloat16}
+    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
+    return torch.from_numpy(tensor.astype(np.float32)).to(conv[tensor.dtype])
+
+
 def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
     """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
     import torch
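A quick check of the new fallback, assuming ml_dtypes is installed; torch.from_numpy raises TypeError on an ml_dtypes array, which routes the call through the conversion branch:

    import ml_dtypes
    import numpy as np
    import torch
    from onnx_diagnostic.helpers.torch_helper import from_numpy

    a = np.array([0.5, -1.25], dtype=ml_dtypes.bfloat16)
    t = from_numpy(a)  # hits the ml_dtypes branch
    assert t.dtype == torch.bfloat16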
@@ -990,3 +1004,86 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
         counts[dt] += 1
     final = max(list(counts.items()))
     return final[0]
+
+
+def closest_factor_pair(n: int):
+    """Tries to find ``a, b`` such that ``n == a * b``."""
+    assert n > 0, f"n={n} must be a positive integer"
+    start = math.isqrt(n)
+    for a in range(start, 0, -1):
+        if n % a == 0:
+            b = n // a
+            return a, b
+    return 1, n
+
+
+def study_discrepancies(
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+    bins: int = 50,
+    figsize: Optional[Tuple[int, int]] = (15, 15),
+    title: Optional[str] = None,
+    name: Optional[str] = None,
+) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Computes different metrics for the discrepancies.
+    Returns graphs.
+
+    .. plot::
+        :include-source:
+
+        import torch
+        from onnx_diagnostic.helpers.torch_helper import study_discrepancies
+
+        t1 = torch.randn((512, 1024)) * 10
+        t2 = t1 + torch.randn((512, 1024))
+        study_discrepancies(t1, t2, title="Random noise")
+    """
+    assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}"
+    assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}"
+    d1, d2 = (
+        (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32))
+    )
+
+    d1 = d1.squeeze()
+    d2 = d2.squeeze()
+    if len(d1.shape) == 1:
+        new_shape = closest_factor_pair(d1.shape[0])
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+    elif len(d1.shape) > 2:
+        new_shape = (-1, max(d1.shape))
+        d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape)
+
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(3, 2, figsize=figsize)
+    vmin, vmax = d1.min().item(), d1.max().item()
+    ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax)
+    ax[0, 0].set_title(
+        f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}"
+    )
+
+    diff = d2 - d1
+    vmin, vmax = diff.min().item(), diff.max().item()
+    ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax)
+    ax[0, 1].set_title(f"Color plot of the differences in\n[{vmin}, {vmax}]")
+
+    ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 0].set_title("Distribution of the first tensor")
+
+    ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins)
+    ax[1, 1].set_title("Distribution of the differences")
+
+    tf1 = d1.ravel()
+    td1 = diff.ravel()
+    ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".")
+    ax[2, 1].set_title("Graph XY")
+    ax[2, 1].set_xlabel("First tensor values")
+    ax[2, 1].set_ylabel("Difference values")
+
+    if title:
+        fig.suptitle(title)
+    fig.tight_layout()
+    if name:
+        fig.savefig(name)
+    return ax
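closest_factor_pair is what lets study_discrepancies render a flattened 1-D tensor as a near-square image. A few worked values:

    from onnx_diagnostic.helpers.torch_helper import closest_factor_pair

    assert closest_factor_pair(524288) == (512, 1024)  # 512 * 1024 elements
    assert closest_factor_pair(12) == (3, 4)           # largest divisor <= isqrt(12)
    assert closest_factor_pair(7) == (1, 7)            # primes fall back to (1, n)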