onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,210 @@
1
+ from typing import Dict, Set
2
+ import onnx
3
+ import onnx.numpy_helper as onh
4
+ from .onnx_helper import onnx_dtype_name, pretty_onnx
5
+
6
+
7
def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
    """Collect names consumed inside *graph* but produced outside of it.

    Walks the nodes in order while tracking every name the graph defines
    locally (graph inputs, dense and sparse initializers, outputs of
    previously visited nodes).  Any consumed name missing from that set —
    including names pulled in by nested subgraphs — is a hidden input.
    """
    known = {init.name for init in graph.initializer}
    known.update(sp.values.name for sp in graph.sparse_initializer)
    known.update(inp.name for inp in graph.input)

    hidden: Set[str] = set()
    for node in graph.node:
        hidden.update(name for name in node.input if name not in known)
        for attribute in node.attribute:
            if attribute.type == onnx.AttributeProto.GRAPH and attribute.g:
                # A subgraph's hidden inputs are hidden for this graph too,
                # unless this graph defines them itself.
                hidden.update(
                    name
                    for name in _get_hidden_inputs(attribute.g)
                    if name not in known
                )
        known.update(node.output)
    return hidden
25
+
26
+
27
def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
    """Build a dot label for *node*: ``op_type(inputs/attributes)[ -> output]``.

    :param node: node to render
    :param tiny_inits: maps an initializer name to a short printable value;
        inputs found there are shown inline, other inputs show as ``.``
    :return: the label (dot syntax, ``\\n`` stands for a dot line break)
    """
    els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
    ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
    for att in node.attribute:
        if att.name == "to":
            # 'to' holds a dtype: show its name rather than the raw integer.
            ee.append(f"{att.name}={onnx_dtype_name(att.i)}")
        # fix: "to" was also listed in this set and therefore unreachable,
        # it is handled by the branch above
        elif att.name in {"axis", "value_int", "stash_type", "start", "end"}:
            ee.append(f"{att.name}={att.i}")
        elif att.name in {"value_float"}:
            ee.append(f"{att.name}={att.f}")
        elif att.name in {"value_floats"}:
            ee.append(f"{att.name}={att.floats}")
        elif att.name in {"value_ints", "perm"}:
            ee.append(f"{att.name}={att.ints}")
    els.append(", ".join(ee))
    els.append(")")
    if node.op_type == "Constant":
        # A Constant has no incoming edge: show the name it produces.
        els.extend([" -> ", node.output[0]])
    return "".join(els)
46
+
47
+
48
def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
    """Render ``DTYPE(shape)`` for a value info, or ``""`` if the type is undefined.

    Symbolic dimensions keep their name unless it starts with ``unk``,
    in which case ``?`` is shown.  With ``multi_line=True`` a shape string
    longer than 30 characters is wrapped with dot line breaks.
    """
    elem_type = value_info.type.tensor_type.elem_type
    if elem_type == onnx.TensorProto.UNDEFINED:
        return ""
    dims = [
        d.dim_param if d.dim_param else d.dim_value
        for d in value_info.type.tensor_type.shape.dim
    ]
    pieces = []
    for dim in dims:
        is_unknown = isinstance(dim, str) and dim.startswith("unk")
        pieces.append("?" if is_unknown else str(dim))
    joined = ",".join(pieces)
    if multi_line and len(joined) > 30:
        joined = ",\\n".join(pieces)
    return f"{onnx_dtype_name(elem_type)}({joined})"
64
+
65
+
66
def to_dot(model: onnx.ModelProto) -> str:
    """
    Converts a model into a dot graph.

    :param model: ONNX model to render
    :return: the graph in DOT syntax

    Here is an example:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from onnx_diagnostic.helpers.dot_helper import to_dot
        from onnx_diagnostic.export.api import to_onnx
        from onnx_diagnostic.torch_export_patches import torch_export_patches
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        with torch_export_patches(patch_transformers=True):
            em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
        dot = to_dot(em.model_proto)
        print("DOT-SECTION", dot)

    Or this one obtained with :func:`torch.onnx.export`.

    .. gdot::
        :script: DOT-SECTION
        :process:

        from onnx_diagnostic.helpers.dot_helper import to_dot
        from onnx_diagnostic.export.api import to_onnx
        from onnx_diagnostic.torch_export_patches import torch_export_patches
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        with torch_export_patches(patch_transformers=True):
            em = to_onnx(model, kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo")
        dot = to_dot(em.model_proto)
        print("DOT-SECTION", dot)
    """
    _unique: Dict[int, int] = {}

    def _mkn(obj: object) -> int:
        # Stable small integer per proto object, assigned on first use,
        # used to build unique dot node identifiers.
        id_obj = id(obj)
        if id_obj in _unique:
            return _unique[id_obj]
        i = len(_unique)
        _unique[id_obj] = i
        return i

    # Shape inference fills graph.value_info so edges can carry type labels.
    model = onnx.shape_inference.infer_shapes(model)

    op_type_colors = {
        "Shape": "#d2a81f",
        "MatMul": "#ee9999",
        "Transpose": "#ee99ee",
        "Reshape": "#eeeeee",
        "Squeeze": "#eeeeee",
        "Unsqueeze": "#eeeeee",
    }

    edge_label = {}
    for val in model.graph.value_info:
        edge_label[val.name] = _make_edge_label(val, multi_line=True)

    rows = [
        "digraph {",
        (
            " graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, "
            "ranksep=0.2, fontsize=8];"
        ),
        ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];',
        " edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];",
    ]
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    nodes = list(model.graph.node)
    inits = list(model.graph.initializer)
    tiny_inits = {}
    name_to_ids = {}
    for inp in inputs:
        if not inp.name:
            continue
        lab = _make_edge_label(inp)
        rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
        name_to_ids[inp.name] = f"I_{_mkn(inp)}"
        edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
    for init in inits:
        shape = tuple(init.dims)
        if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
            # Small constants are inlined in the node labels instead of
            # getting a dot node of their own.
            a = onh.to_array(init)
            tiny_inits[init.name] = (
                str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]"
            )
        else:
            ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
            rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
            name_to_ids[init.name] = f"i_{_mkn(init)}"
            edge_label[init.name] = ls
    for node in nodes:
        color = op_type_colors.get(node.op_type, "#cccccc")
        label = _make_node_label(node, tiny_inits)
        rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
        name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o})

    # edges into nodes
    done = set()
    for node in nodes:
        for i in node.input:
            if not i or i in tiny_inits:
                continue
            if i not in name_to_ids:
                raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")
            edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
            if edge in done:
                continue
            done.add(edge)
            lab = edge_label.get(i, "")
            if lab:
                lab = f' [label="{lab}"]'
            rows.append(f" {edge[0]} -> {edge[1]}{lab};")
        if node.op_type in {"Scan", "Loop", "If"}:
            # Control-flow nodes may read outer-scope names from inside
            # their subgraphs: draw those as dotted edges.
            unique = set()
            for att in node.attribute:
                if att.type == onnx.AttributeProto.GRAPH:
                    unique |= _get_hidden_inputs(att.g)
            for i in unique:
                # Skip inlined tiny initializers and names produced inside
                # nested scopes only (they have no dot node here).
                if not i or i not in name_to_ids:
                    continue
                # fix: the edge must target the node's dot identifier, not
                # the bare integer returned by _mkn
                edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
                if edge in done:
                    continue
                done.add(edge)
                rows.append(f" {edge[0]} -> {edge[1]} [style=dotted];")

    # outputs
    for out in outputs:
        if not out.name:
            continue
        lab = _make_edge_label(out)
        rows.append(f' O_{_mkn(out)} [label="{out.name}\\n{lab}", fillcolor="#aaaaee"];')
        edge = name_to_ids[out.name], f"O_{_mkn(out)}"
        rows.append(f" {edge[0]} -> {edge[1]};")

    rows.append("}")
    return "\n".join(rows)
@@ -529,16 +529,20 @@ def string_type(
529
529
  return "OV(NO-NUMPY:FIXIT)"
530
530
  if verbose:
531
531
  print(f"[string_type] V4:{type(obj)}")
532
- return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
532
+ dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
533
+ return (
534
+ f"{dev}OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
535
+ )
533
536
  dt = obj.element_type()
534
537
  shape = obj.shape()
538
+ dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
535
539
  if with_shape:
536
540
  if verbose:
537
541
  print(f"[string_type] V5:{type(obj)}")
538
- return f"OV{dt}s{'x'.join(map(str, shape))}"
542
+ return f"{dev}OV{dt}s{'x'.join(map(str, shape))}"
539
543
  if verbose:
540
544
  print(f"[string_type] V6:{type(obj)}")
541
- return f"OV{dt}r{len(shape)}"
545
+ return f"{dev}OV{dt}r{len(shape)}"
542
546
 
543
547
  # others classes
544
548
 
@@ -787,6 +791,8 @@ def string_type(
787
791
  return f"ultralytics.{obj.__class__.__name__}(...)"
788
792
  if obj.__class__.__name__ == "FakeTensorMode":
789
793
  return f"{obj}"
794
+ if obj.__class__.__name__ == "FakeTensorContext":
795
+ return "FakeTensorContext(...)"
790
796
 
791
797
  if verbose:
792
798
  print(f"[string_type] END:{type(obj)}")
@@ -988,7 +994,7 @@ def max_diff(
988
994
  _index: int = 0,
989
995
  allow_unique_tensor_with_list_of_one_element: bool = True,
990
996
  hist: Optional[Union[bool, List[float]]] = None,
991
- ) -> Dict[str, Union[float, int, Tuple[int, ...]]]:
997
+ ) -> Dict[str, Union[float, int, Tuple[Any, ...]]]:
992
998
  """
993
999
  Returns the maximum discrepancy.
994
1000
 
@@ -1013,6 +1019,7 @@ def max_diff(
1013
1019
  output, this number will be the number of elements
1014
1020
  of this output
1015
1021
  * dnan: difference in the number of nan
1022
+ * dev: tensor on the same device, if applicable
1016
1023
 
1017
1024
  You may use :func:`string_diff` to display the discrepancies in one string.
1018
1025
  """
@@ -1165,7 +1172,7 @@ def max_diff(
1165
1172
 
1166
1173
  if verbose >= 6:
1167
1174
  print(f"[max_diff] list,tuple,6: {string_type(expected)} ? {string_type(got)}")
1168
- am, rm, sm, n, dn, drep = 0, 0, 0.0, 0.0, 0, None
1175
+ am, rm, sm, n, dn, drep, dd = 0, 0, 0.0, 0.0, 0, None, None
1169
1176
  for ip, (e, g) in enumerate(zip(expected, got)):
1170
1177
  d = max_diff(
1171
1178
  e,
@@ -1197,7 +1204,15 @@ def max_diff(
1197
1204
  else:
1198
1205
  for k, v in d["rep"].items():
1199
1206
  drep[k] += v
1207
+ if "dev" in d and d["dev"] is not None:
1208
+ if dd is None:
1209
+ dd = d["dev"]
1210
+ else:
1211
+ dd += d["dev"] # type: ignore[operator]
1212
+
1200
1213
  res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
1214
+ if dd is not None:
1215
+ res["dev"] = dd
1201
1216
  if drep:
1202
1217
  res["rep"] = drep
1203
1218
  return res # type: ignore
@@ -1231,33 +1246,42 @@ def max_diff(
1231
1246
  import torch
1232
1247
 
1233
1248
  if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
1249
+ dev = None
1234
1250
  if isinstance(expected, torch.Tensor):
1235
1251
  from .torch_helper import to_numpy
1236
1252
 
1253
+ dev = 0 if expected.device.type == "cpu" else 1
1237
1254
  expected = to_numpy(expected)
1238
1255
  if isinstance(got, torch.Tensor):
1239
1256
  from .torch_helper import to_numpy
1240
1257
 
1258
+ dev = 0 if got.device.type == "cpu" else 1
1241
1259
  got = to_numpy(got)
1242
1260
  if verbose >= 6:
1243
1261
  print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1244
1262
 
1245
1263
  if _index < begin or (end != -1 and _index >= end):
1246
1264
  # out of boundary
1247
- return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1265
+ res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1266
+ if dev is not None:
1267
+ res["dev"] = dev # type: ignore[operator]
1268
+ return res # type: ignore[return-value]
1248
1269
  if isinstance(expected, (int, float)):
1249
1270
  if isinstance(got, np.ndarray) and len(got.shape) == 0:
1250
1271
  got = float(got)
1251
1272
  if isinstance(got, (int, float)):
1252
1273
  if expected == got:
1253
1274
  return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1254
- return dict(
1275
+ res = dict(
1255
1276
  abs=abs(expected - got),
1256
1277
  rel=abs(expected - got) / (abs(expected) + 1e-5),
1257
1278
  sum=abs(expected - got),
1258
1279
  n=1,
1259
1280
  dnan=0,
1260
1281
  )
1282
+ if dev is not None:
1283
+ res["dev"] = dev
1284
+ return res # type: ignore[return-value]
1261
1285
  return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1262
1286
  if expected.dtype in (np.complex64, np.complex128):
1263
1287
  if got.dtype == expected.dtype:
@@ -1337,6 +1361,8 @@ def max_diff(
1337
1361
  res: Dict[str, float] = dict( # type: ignore
1338
1362
  abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1339
1363
  )
1364
+ if dev is not None:
1365
+ res["dev"] = dev
1340
1366
  if hist:
1341
1367
  if isinstance(hist, bool):
1342
1368
  hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
@@ -1350,9 +1376,14 @@ def max_diff(
1350
1376
  if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
1351
1377
  if verbose >= 6:
1352
1378
  print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1379
+ dev = 0 if expected.device == got.device else 1
1353
1380
  if _index < begin or (end != -1 and _index >= end):
1354
1381
  # out of boundary
1355
- return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1382
+ if verbose >= 10:
1383
+ if debug_info:
1384
+ print("\n".join(debug_info))
1385
+ print("[max_diff] out of boundary")
1386
+ return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0, dev=dev)
1356
1387
  if expected.dtype in (torch.complex64, torch.complex128):
1357
1388
  if got.dtype == expected.dtype:
1358
1389
  got = torch.view_as_real(got)
@@ -1446,31 +1477,63 @@ def max_diff(
1446
1477
  )
1447
1478
 
1448
1479
  res: Dict[str, float] = dict( # type: ignore
1449
- abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1480
+ abs=abs_diff,
1481
+ rel=rel_diff,
1482
+ sum=sum_diff,
1483
+ n=n_diff,
1484
+ dnan=nan_diff,
1485
+ argm=argm,
1486
+ dev=dev,
1450
1487
  )
1451
1488
  if hist:
1452
- if isinstance(hist, bool):
1453
- hist = torch.tensor(
1454
- [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1455
- )
1456
- hist = hist.to(diff.device)
1457
- ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1458
- cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1459
- res["rep"] = dict(
1460
- zip(
1461
- [f">{x}" for x in hist],
1462
- [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1489
+ if isinstance(hist, list) and len(hist) == 1:
1490
+ res["rep"] = {f">{hist[0]}": (diff > hist[0]).sum().item()}
1491
+ elif isinstance(hist, list) and len(hist) == 2:
1492
+ res["rep"] = {
1493
+ f">{hist[0]}": (diff > hist[0]).sum().item(),
1494
+ f">{hist[1]}": (diff > hist[1]).sum().item(),
1495
+ }
1496
+ else:
1497
+ if isinstance(hist, bool):
1498
+ hist = torch.tensor(
1499
+ [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1500
+ )
1501
+ hist = torch.tensor(hist).to(diff.device)
1502
+ ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1503
+ cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1504
+ res["rep"] = dict(
1505
+ zip(
1506
+ [f">{x}" for x in hist],
1507
+ [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1508
+ )
1463
1509
  )
1464
- )
1465
1510
  return res # type: ignore
1466
1511
 
1512
+ if isinstance(expected, int) and isinstance(got, torch.Tensor):
1513
+ # a size
1514
+ if verbose >= 6:
1515
+ print(f"[max_diff] int: {string_type(expected)} ? {string_type(got)}")
1516
+ if got.shape != tuple():
1517
+ return dict( # type: ignore
1518
+ abs=np.inf,
1519
+ rel=np.inf,
1520
+ sum=np.inf,
1521
+ n=np.inf,
1522
+ dnan=np.inf,
1523
+ argm=np.inf,
1524
+ )
1525
+ return dict( # type: ignore
1526
+ abs=abs(expected - got.item()),
1527
+ rel=abs((expected - got.item()) / max(1, expected)),
1528
+ sum=abs(expected - got.item()),
1529
+ n=1,
1530
+ dnan=0,
1531
+ )
1532
+
1467
1533
  if "SquashedNormal" in expected.__class__.__name__:
1468
1534
  if verbose >= 6:
1469
1535
  print(f"[max_diff] SquashedNormal: {string_type(expected)} ? {string_type(got)}")
1470
- values = (
1471
- expected.mean.detach().to("cpu"),
1472
- expected.scale.detach().to("cpu"),
1473
- )
1536
+ values = (expected.mean, expected.scale)
1474
1537
  return max_diff(values, got, debug_info=_debug("SquashedNormal"), **_dkws)
1475
1538
 
1476
1539
  if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES:
@@ -1675,7 +1738,7 @@ def max_diff(
1675
1738
 
1676
1739
  raise AssertionError(
1677
1740
  f"Not implemented with implemented with expected="
1678
- f"{string_type(expected)}, got={string_type(got)},\n"
1741
+ f"{string_type(expected)} ({type(expected)}), got={string_type(got)},\n"
1679
1742
  f"level={level}"
1680
1743
  )
1681
1744
 
@@ -1683,6 +1746,9 @@ def max_diff(
1683
1746
  def string_diff(diff: Dict[str, Any]) -> str:
1684
1747
  """Renders discrepancies return by :func:`max_diff` into one string."""
1685
1748
  # dict(abs=, rel=, sum=, n=n_diff, dnan=)
1749
+ if "dev" in diff:
1750
+ ddiff = {k: v for k, v in diff.items() if k != "dev"}
1751
+ return f"{string_diff(ddiff)}, dev={diff['dev']}"
1686
1752
  suffix = ""
1687
1753
  if "rep" in diff:
1688
1754
  rows = []
@@ -901,13 +901,19 @@ class CubeLogs:
901
901
  else g.groupby([*key_index, *key_columns], dropna=False).sum()
902
902
  )
903
903
  not_unique = r[r["count"] > 1]
904
+ if not_unique.shape[0] > 0 and os.environ.get("DUPLICATE", ""):
905
+ filename = os.environ.get("DUPLICATE")
906
+ subset = data.set_index([*key_index, *key_columns]).merge(
907
+ not_unique.head(), left_index=True, right_index=True
908
+ )
909
+ subset.to_excel(filename)
904
910
  assert not_unique.shape[0] == 0, (
905
911
  f"view_def.name={view_def.name!r}, "
906
912
  f"unable to run the pivot with index={sorted(key_index)}, "
907
913
  f"key={sorted(key_columns)}, key_agg={key_agg}, values={sorted(values)}, "
908
914
  f"columns={sorted(data.columns)}, ignored={view_def.ignore_columns}, "
909
- f"not unique={set(data.columns) - unique}"
910
- f"\n--\n{not_unique.head(10)}"
915
+ f"not unique={set(data.columns) - unique}, set DUPLICATE=<filename> "
916
+ f"to store the duplicates in a excel file\n--\n{not_unique.head(10)}"
911
917
  )
912
918
 
913
919
  # pivot
@@ -1000,8 +1006,12 @@ class CubeLogs:
1000
1006
  keys = set(self.keys_time) - {columns_to_fix}
1001
1007
  select = data[self.keys_time]
1002
1008
  select_agg = select.groupby(list(keys)).count()
1009
+ if select_agg.shape[0] == 0:
1010
+ # nothing to fix
1011
+ return data
1003
1012
  assert select_agg[columns_to_fix].max() <= 1, (
1004
- f"Column {columns_to_fix!r} has two distinct values at least for one date\n"
1013
+ f"Column {columns_to_fix!r} has two distinct values at least for one date, "
1014
+ f"max={select_agg[columns_to_fix].max()}\n"
1005
1015
  f"{select_agg[select_agg[columns_to_fix] > 1]}"
1006
1016
  )
1007
1017
 
@@ -1038,6 +1048,16 @@ class CubeLogs:
1038
1048
  f"data.columns.equals(res.columns)={data.columns.equals(res.columns)}, "
1039
1049
  f"data.index.equals(res.columns)={data.index.equals(res.columns)}, "
1040
1050
  )
1051
+ select = res[self.keys_time]
1052
+ select_agg = select.groupby(list(keys)).count()
1053
+ if select_agg.shape[0] == 0:
1054
+ # nothing to fix
1055
+ return data
1056
+ # assert select_agg[columns_to_fix].max() <= 1, (
1057
+ # f"Column {columns_to_fix!r} has two distinct values at least for one date, "
1058
+ # f"max={select_agg[columns_to_fix].max()}\n"
1059
+ # f"{select_agg[select_agg[columns_to_fix] > 1]}"
1060
+ # )
1041
1061
  return res
1042
1062
 
1043
1063
  def _dropna(
@@ -1977,7 +1997,9 @@ class CubeLogsPerformance(CubeLogs):
1977
1997
  * **cmd:** command lines
1978
1998
  * **raw-short:** raw data without all the unused columns
1979
1999
  """
1980
- fix_aggregation_change = ["model_speedup_input_set", "model_test_with"]
2000
+ # This does not work.
2001
+ # used to be ["model_speedup_input_set", "model_test_with"]
2002
+ fix_aggregation_change = [] # type: ignore[var-annotated]
1981
2003
  fs = ["suite", "model_suite", "task", "model_name", "model_task"]
1982
2004
  index_cols = self._filter_column(fs, self.keys_time)
1983
2005
  assert index_cols, (
@@ -159,7 +159,9 @@ class MiniOnnxBuilder:
159
159
  """
160
160
  if not tensors:
161
161
  # empty list
162
- self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
162
+ self.nodes.append(
163
+ oh.make_node("SequenceEmpty", [], [name], dtype=TensorProto.FLOAT)
164
+ )
163
165
  tensor_type_proto = oh.make_tensor_type_proto(
164
166
  elem_type=TensorProto.FLOAT, shape=None
165
167
  )
@@ -422,6 +424,27 @@ def create_onnx_model_from_input_tensors(
422
424
  :return: ModelProto
423
425
 
424
426
  The function raises an error if not supported.
427
+ An example:
428
+
429
+ .. code-block:: python
430
+
431
+ from onnx_diagnostic.helpers.mini_onnx_builder import (
432
+ create_onnx_model_from_input_tensors,
433
+ )
434
+ import onnx
435
+
436
+ proto = create_onnx_model_from_input_tensors(
437
+ dict(
438
+ query_states=query_states,
439
+ key_states=key_states,
440
+ value_states=value_states,
441
+ cu_seqlens=cu_seqlens,
442
+ max_seqlen=(cu_seqlens[1:] - cu_seqlens[:-1]).max(),
443
+ scaling=self.scaling,
444
+ attn_output=attn_output,
445
+ )
446
+ )
447
+ onnx.save(proto, "attention_inputs.onnx")
425
448
  """
426
449
  if switch_low_high is None:
427
450
  switch_low_high = sys.byteorder != "big"
@@ -461,7 +484,17 @@ def _unflatten(
461
484
  if spl[-1] == "array":
462
485
  return pos + 1, outputs[pos]
463
486
  if spl[-1] == "tensor":
464
- return pos + 1, torch.from_numpy(outputs[pos]).to(device)
487
+ try:
488
+ return pos + 1, torch.from_numpy(outputs[pos]).to(device)
489
+ except TypeError:
490
+ # it should be more robust
491
+ import ml_dtypes
492
+
493
+ if outputs[pos].dtype == ml_dtypes.bfloat16:
494
+ return pos + 1, torch.from_numpy(outputs[pos].astype(float)).to(device).to(
495
+ torch.bfloat16
496
+ )
497
+ raise
465
498
  raise AssertionError(f"Unexpected name {name!r} in {names}")
466
499
 
467
500
  res: List[Any] = []
@@ -532,6 +565,12 @@ def _unflatten(
532
565
  return d
533
566
  return ty(res)
534
567
 
568
+ if end and len(res) == 1:
569
+ if res[0] is None:
570
+ return next_pos, ty()
571
+ if isinstance(res[0], tuple) and len(res[0]) == 2 and res[0] == ("dict.", None):
572
+ return next_pos, ty()
573
+ return next_pos, _make(ty, res)
535
574
  return next_pos, (
536
575
  ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res)
537
576
  )
@@ -557,6 +596,19 @@ def create_input_tensors_from_onnx_model(
557
596
  :return: restored data
558
597
 
559
598
  See example :ref:`l-plot-intermediate-results` for an example.
599
+
600
+ .. code-block:: python
601
+
602
+ import os
603
+ from onnx_diagnostic.helpers.mini_onnx_builder import (
604
+ create_input_tensors_from_onnx_model,
605
+ )
606
+ from onnx_diagnostic.helpers import string_type
607
+
608
+ restored = create_input_tensors_from_onnx_model("attention_inputs.onnx")
609
+ for k, v in restored.items():
610
+ print(f"{k}: {string_type(v, with_shape=True, with_min_max=True)}")
611
+
560
612
  """
561
613
  if engine == "ExtendedReferenceEvaluator":
562
614
  from ..reference import ExtendedReferenceEvaluator
@@ -595,6 +647,8 @@ def create_input_tensors_from_onnx_model(
595
647
  return float(output[0])
596
648
  if name == "tensor":
597
649
  return torch.from_numpy(output).to(device)
598
- raise AssertionError(f"Unexpected name {name!r} in {names}")
650
+ assert name.startswith(
651
+ ("list_", "list.", "dict.", "tuple_", "tuple.")
652
+ ), f"Unexpected name {name!r} in {names}"
599
653
 
600
654
  return _unflatten(sep, names, got, device=device)[1]
@@ -28,10 +28,37 @@ def download_model_builder_to_cache(
28
28
  if file_path.exists():
29
29
  return file_path
30
30
 
31
+ builders = cache_dir / "builders"
32
+ if not builders.exists():
33
+ builders.mkdir(parents=True, exist_ok=True)
34
+
35
+ for subfile in [
36
+ "__init__.py",
37
+ "base.py",
38
+ "chatglm.py",
39
+ "ernie.py",
40
+ "gemma.py",
41
+ "gptoss.py",
42
+ "granite.py",
43
+ "llama.py",
44
+ "mistral.py",
45
+ "nemotron.py",
46
+ "olmo.py",
47
+ "phi.py",
48
+ "qwen.py",
49
+ "smollm.py",
50
+ ]:
51
+ u = f"{'/'.join(url.split('/')[:-1])}/builders/{subfile}"
52
+ response = requests.get(u)
53
+ response.raise_for_status()
54
+ with open(builders / subfile, "wb") as f:
55
+ f.write(response.content)
56
+
31
57
  response = requests.get(url)
32
58
  response.raise_for_status()
33
59
  with open(file_path, "wb") as f:
34
60
  f.write(response.content)
61
+
35
62
  return file_path
36
63
 
37
64