onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +47 -10
  3. onnx_diagnostic/export/api.py +81 -50
  4. onnx_diagnostic/export/control_flow_research.py +10 -5
  5. onnx_diagnostic/export/onnx_plug.py +250 -61
  6. onnx_diagnostic/ext_test_case.py +99 -53
  7. onnx_diagnostic/helpers/dot_helper.py +37 -25
  8. onnx_diagnostic/helpers/helper.py +44 -38
  9. onnx_diagnostic/helpers/onnx_helper.py +441 -18
  10. onnx_diagnostic/helpers/ort_session.py +8 -8
  11. onnx_diagnostic/helpers/torch_helper.py +28 -2
  12. onnx_diagnostic/reference/ort_evaluator.py +6 -29
  13. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
  14. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
  15. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
  16. onnx_diagnostic/torch_models/code_sample.py +2 -1
  17. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  18. onnx_diagnostic/torch_models/validate.py +14 -1
  19. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  20. onnx_diagnostic/torch_onnx/sbs.py +11 -5
  21. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
  22. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
  23. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
  24. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
  25. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  26. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/dot_helper.py

@@ -1,31 +1,13 @@
-from typing import Dict, Set
+from typing import Dict
+import numpy as np
 import onnx
 import onnx.numpy_helper as onh
-from .onnx_helper import onnx_dtype_name, pretty_onnx
-
-
-def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
-    hidden = set()
-    memo = (
-        {i.name for i in graph.initializer}
-        | {i.values.name for i in graph.sparse_initializer}
-        | {i.name for i in graph.input}
-    )
-    for node in graph.node:
-        for i in node.input:
-            if i not in memo:
-                hidden.add(i)
-        for att in node.attribute:
-            if att.type == onnx.AttributeProto.GRAPH and att.g:
-                hid = _get_hidden_inputs(att.g)
-                less = set(h for h in hid if h not in memo)
-                hidden |= less
-        memo |= set(node.output)
-    return hidden
+from ..reference import ExtendedReferenceEvaluator as Inference
+from .onnx_helper import onnx_dtype_name, pretty_onnx, get_hidden_inputs
 
 
 def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
-    els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
+    els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "\\n("]
     ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
     for att in node.attribute:
         if att.name == "to":
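
The private `_get_hidden_inputs` shown above is removed here and the functionality is now imported from `.onnx_helper` as a public `get_hidden_inputs`. A minimal usage sketch, assuming the public helper keeps the signature of the removed private function (the model path and loop below are purely illustrative):

import onnx
from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

model = onnx.load("model.onnx")  # hypothetical model containing subgraphs (If/Loop/Scan)
for node in model.graph.node:
    for att in node.attribute:
        if att.type == onnx.AttributeProto.GRAPH and att.g:
            # names used inside the subgraph but defined in an outer scope
            print(node.op_type, sorted(get_hidden_inputs(att.g)))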
@@ -42,7 +24,10 @@ def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
     els.append(")")
     if node.op_type == "Constant":
         els.extend([" -> ", node.output[0]])
-    return "".join(els)
+    res = "".join(els)
+    if len(res) < 40:
+        return res.replace("\\n(", "(")
+    return res
 
 
 def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
@@ -142,6 +127,7 @@ def to_dot(model: onnx.ModelProto) -> str:
     inits = list(model.graph.initializer)
     tiny_inits = {}
     name_to_ids = {}
+
     for inp in inputs:
         if not inp.name:
             continue
@@ -149,7 +135,30 @@ def to_dot(model: onnx.ModelProto) -> str:
         rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
         name_to_ids[inp.name] = f"I_{_mkn(inp)}"
         edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
+
+    # Small constant --> initializer
+    output_names = {n.name for n in outputs}
+    for node in nodes:
+        if node.op_type != "Constant" or node.output[0] in output_names:
+            continue
+        skip = False
+        for att in node.attribute:
+            if att.name == "value" and (
+                len(att.t.dims) > 1 or np.prod(tuple(att.t.dims)) > 10
+            ):
+                skip = True
+                break
+        if skip:
+            continue
+
+        sess = Inference(node)
+        value = sess.run(None, {})[0]
+        inits.append(onh.from_array(value, name=node.output[0]))
+
     for init in inits:
+        if init.name in name_to_ids:
+            # hide optional inputs
+            continue
         shape = tuple(init.dims)
         if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
             a = onh.to_array(init)
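
The new block evaluates each small `Constant` node with the package's `ExtendedReferenceEvaluator` and re-registers its value as an initializer, so it can later be drawn as a tiny value box instead of a full node. A standalone sketch of that step, assuming the evaluator accepts a single `NodeProto` as the diff shows (the node built below is illustrative only):

import numpy as np
import onnx
import onnx.numpy_helper as onh
from onnx_diagnostic.reference import ExtendedReferenceEvaluator as Inference

# a small constant node (rank 1, fewer than 10 elements, so it is not skipped above)
node = onnx.helper.make_node(
    "Constant", [], ["cst"],
    value=onh.from_array(np.array([1.0, 2.0], dtype=np.float32), name="cst"),
)
sess = Inference(node)            # evaluate just this node
value = sess.run(None, {})[0]     # -> array([1., 2.], dtype=float32)
init = onh.from_array(value, name=node.output[0])  # becomes a graph initializer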
@@ -161,7 +170,10 @@ def to_dot(model: onnx.ModelProto) -> str:
         rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
         name_to_ids[init.name] = f"i_{_mkn(init)}"
         edge_label[init.name] = ls
+
     for node in nodes:
+        if node.op_type == "Constant" and node.output[0] in tiny_inits:
+            continue
         color = op_type_colors.get(node.op_type, "#cccccc")
         label = _make_node_label(node, tiny_inits)
         rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
@@ -189,7 +201,7 @@ def to_dot(model: onnx.ModelProto) -> str:
         unique = set()
         for att in node.attribute:
             if att.type == onnx.AttributeProto.GRAPH:
-                unique |= _get_hidden_inputs(att.g)
+                unique |= get_hidden_inputs(att.g)
         for i in unique:
             edge = name_to_ids[i], _mkn(node)  # type: ignore[assignment]
             if edge in done:
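
Taken together, these `dot_helper.py` changes only affect how the DOT source is built; the entry point stays `to_dot`. A small usage sketch (the file names are hypothetical, and graphviz is assumed for the final rendering step):

import onnx
from onnx_diagnostic.helpers.dot_helper import to_dot

model = onnx.load("model.onnx")
dot = to_dot(model)   # DOT source; small Constant nodes now show up as initializers
with open("model.dot", "w", encoding="utf-8") as f:
    f.write(dot)
# render with graphviz, e.g.: dot -Tpng model.dot -o model.png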
onnx_diagnostic/helpers/helper.py

@@ -2,6 +2,7 @@ import ast
 import enum
 import inspect
 import itertools
+import json
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
@@ -94,6 +95,20 @@ def size_type(dtype: Any) -> int:
     raise AssertionError(f"Unexpected dtype={dtype}")
 
 
+def _string_tensor(obj, cls: str, with_shape: bool, with_device: bool, verbose: int) -> str:
+    from .torch_helper import torch_dtype_to_onnx_dtype
+
+    i = torch_dtype_to_onnx_dtype(obj.dtype)
+    prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
+    if not with_shape:
+        if verbose:
+            print(f"[string_type] {cls}1:{type(obj)}")
+        return f"{prefix}{cls}{i}r{len(obj.shape)}"
+    if verbose:
+        print(f"[string_type] {cls}2:{type(obj)}")
+    return f"{prefix}{cls}{i}s{'x'.join(map(str, obj.shape))}"
+
+
 def string_type(
     obj: Any,
     with_shape: bool = False,
@@ -453,17 +468,7 @@ def string_type(
 
     # Tensors
     if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
-        from .torch_helper import torch_dtype_to_onnx_dtype
-
-        i = torch_dtype_to_onnx_dtype(obj.dtype)
-        prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
-        if not with_shape:
-            if verbose:
-                print(f"[string_type] F1:{type(obj)}")
-            return f"{prefix}F{i}r{len(obj.shape)}"
-        if verbose:
-            print(f"[string_type] F2:{type(obj)}")
-        return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
+        return _string_tensor(obj, "F", with_shape, with_device, verbose)
 
     if isinstance(obj, torch.Tensor):
         from .torch_helper import torch_dtype_to_onnx_dtype
@@ -544,6 +549,9 @@ def string_type(
             print(f"[string_type] V6:{type(obj)}")
         return f"{dev}OV{dt}r{len(shape)}"
 
+    if obj.__class__.__name__ == "SymbolicTensor":
+        return _string_tensor(obj, "ST", with_shape, with_device, verbose)
+
     # others classes
 
     if obj.__class__.__name__ == "MambaCache":
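
The three hunks above factor the tensor formatting into `_string_tensor` and reuse it for `FakeTensor` ("F") and the new `SymbolicTensor` case ("ST"). Reading the helper: optional device prefix ("C"/"G"), the class tag, the ONNX dtype code, then either `r<rank>` or `s<d0>x<d1>...`. A minimal sketch of the expected output on a fake tensor, assuming `string_type` keeps this contract:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from onnx_diagnostic.helpers.helper import string_type

with FakeTensorMode():
    t = torch.empty(2, 3, dtype=torch.float32)   # a FakeTensor

print(string_type(t, with_shape=True))                    # expected: "F1s2x3"  (1 = ONNX FLOAT)
print(string_type(t, with_shape=False))                   # expected: "F1r2"
print(string_type(t, with_shape=True, with_device=True))  # expected: "CF1s2x3" on CPU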
@@ -1366,11 +1374,7 @@ def max_diff(
         if hist:
             if isinstance(hist, bool):
                 hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
-            ind = np.digitize(diff.reshape((-1,)), hist, right=True)
-            cou = np.bincount(ind, minlength=ind.shape[0] + 1)
-            res["rep"] = dict(
-                zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))])
-            )
+            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
         return res  # type: ignore
 
     if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
@@ -1486,27 +1490,11 @@ def max_diff(
             dev=dev,
         )
         if hist:
-            if isinstance(hist, list) and len(hist) == 1:
-                res["rep"] = {f">{hist[0]}": (diff > hist[0]).sum().item()}
-            elif isinstance(hist, list) and len(hist) == 2:
-                res["rep"] = {
-                    f">{hist[0]}": (diff > hist[0]).sum().item(),
-                    f">{hist[1]}": (diff > hist[1]).sum().item(),
-                }
-            else:
-                if isinstance(hist, bool):
-                    hist = torch.tensor(
-                        [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
-                    )
-                hist = torch.tensor(hist).to(diff.device)
-                ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
-                cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
-                res["rep"] = dict(
-                    zip(
-                        [f">{x}" for x in hist],
-                        [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
-                    )
+            if isinstance(hist, bool):
+                hist = torch.tensor(
+                    [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
                 )
+            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
         return res  # type: ignore
 
     if isinstance(expected, int) and isinstance(got, torch.Tensor):
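
Both the NumPy and the torch branches of `max_diff` now build the histogram report the same way: a direct count of elements above each threshold, replacing the equivalent counts previously computed via `digitize`/`bucketize` and cumulative sums. A sketch of the resulting `rep` entry, assuming `hist` is exposed as a keyword argument of `max_diff` and that `diff` holds the element-wise absolute difference:

import torch
from onnx_diagnostic.helpers.helper import max_diff

expected = torch.tensor([1.0, 2.0, 3.0])
got = torch.tensor([1.0, 2.5, 3.0])   # one element off by 0.5
d = max_diff(expected, got, hist=True)
# d["rep"] maps each threshold to the number of elements whose difference exceeds it,
# roughly: {">0": 1, ">0.0001": 1, ..., ">0.1": 1, ">1": 0, ">10": 0, ">100": 0}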
@@ -1743,8 +1731,26 @@ def max_diff(
         )
 
 
-def string_diff(diff: Dict[str, Any]) -> str:
-    """Renders discrepancies return by :func:`max_diff` into one string."""
+def string_diff(diff: Dict[str, Any], js: bool = False, ratio: bool = False, **kwargs) -> str:
+    """
+    Renders discrepancies return by :func:`max_diff` into one string.
+
+    :param diff: differences
+    :param js: json format
+    :param ratio: display mismatch ratio
+    :param kwargs: addition values to add in the json format
+    """
+    if js:
+        if "rep" in diff:
+            rep = diff["rep"]
+            diff = {**{k: v for k, v in diff.items() if k != "rep"}, **rep}
+            if ratio:
+                for k, v in rep.items():
+                    diff[f"%{k}"] = v / diff["n"]
+        diff["mean"] = diff["sum"] / diff["n"]
+        diff.update(kwargs)
+        return json.dumps(diff)
+
     # dict(abs=, rel=, sum=, n=n_diff, dnan=)
     if "dev" in diff:
         ddiff = {k: v for k, v in diff.items() if k != "dev"}
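
With the new `js` path, `string_diff` can emit one JSON record per comparison, which is easier to grep or to load into a dataframe; `ratio=True` additionally normalizes the `>x` counts by `n`, and extra keyword arguments are appended to the record. A sketch under the same assumptions as above (the `model` field is an arbitrary extra key, not part of the library's output):

import torch
from onnx_diagnostic.helpers.helper import max_diff, string_diff

d = max_diff(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.5]), hist=True)
print(string_diff(d, js=True, ratio=True, model="demo"))
# -> a single JSON object with abs/rel/sum/n, the ">x" counts, their "%>x" ratios,
#    a computed "mean" (= sum / n) and the extra "model": "demo" field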