onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +91 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +3 -3
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +92 -23
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +90 -26
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +103 -1
  15. onnx_diagnostic/helpers/ort_session.py +37 -11
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +103 -6
  18. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/validate.py +50 -1
  37. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  38. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  39. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,210 @@
1
+ from typing import Dict, Set
2
+ import onnx
3
+ import onnx.numpy_helper as onh
4
+ from .onnx_helper import onnx_dtype_name, pretty_onnx
5
+
6
+
7
+ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
8
+ hidden = set()
9
+ memo = (
10
+ {i.name for i in graph.initializer}
11
+ | {i.values.name for i in graph.sparse_initializer}
12
+ | {i.name for i in graph.input}
13
+ )
14
+ for node in graph.node:
15
+ for i in node.input:
16
+ if i not in memo:
17
+ hidden.add(i)
18
+ for att in node.attribute:
19
+ if att.type == onnx.AttributeProto.GRAPH and att.g:
20
+ hid = _get_hidden_inputs(att.g)
21
+ less = set(h for h in hid if h not in memo)
22
+ hidden |= less
23
+ memo |= set(node.output)
24
+ return hidden
25
+
26
+
27
+ def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
28
+ els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
29
+ ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
30
+ for att in node.attribute:
31
+ if att.name == "to":
32
+ ee.append(f"{att.name}={onnx_dtype_name(att.i)}")
33
+ elif att.name in {"to", "axis", "value_int", "stash_type", "start", "end"}:
34
+ ee.append(f"{att.name}={att.i}")
35
+ elif att.name in {"value_float"}:
36
+ ee.append(f"{att.name}={att.f}")
37
+ elif att.name in {"value_floats"}:
38
+ ee.append(f"{att.name}={att.floats}")
39
+ elif att.name in {"value_ints", "perm"}:
40
+ ee.append(f"{att.name}={att.ints}")
41
+ els.append(", ".join(ee))
42
+ els.append(")")
43
+ if node.op_type == "Constant":
44
+ els.extend([" -> ", node.output[0]])
45
+ return "".join(els)
46
+
47
+
48
+ def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
49
+ itype = value_info.type.tensor_type.elem_type
50
+ if itype == onnx.TensorProto.UNDEFINED:
51
+ return ""
52
+ shape = tuple(
53
+ d.dim_param if d.dim_param else d.dim_value
54
+ for d in value_info.type.tensor_type.shape.dim
55
+ )
56
+ res = [
57
+ str(a)
58
+ for a in [("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape]
59
+ ]
60
+ sshape = ",".join(res)
61
+ if multi_line and len(sshape) > 30:
62
+ sshape = ",\\n".join(res)
63
+ return f"{onnx_dtype_name(itype)}({sshape})"
64
+
65
+
66
+ def to_dot(model: onnx.ModelProto) -> str:
67
+ """
68
+ Converts a model into a dot graph.
69
+ Here is an example:
70
+
71
+ .. gdot::
72
+ :script: DOT-SECTION
73
+ :process:
74
+
75
+ from onnx_diagnostic.helpers.dot_helper import to_dot
76
+ from onnx_diagnostic.export.api import to_onnx
77
+ from onnx_diagnostic.torch_export_patches import torch_export_patches
78
+ from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
79
+
80
+ data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
81
+ model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
82
+ with torch_export_patches(patch_transformers=True):
83
+ em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
84
+ dot = to_dot(em.model_proto)
85
+ print("DOT-SECTION", dot)
86
+
87
+ Or this one obtained with :func:`torch.onnx.export`.
88
+
89
+ .. gdot::
90
+ :script: DOT-SECTION
91
+ :process:
92
+
93
+ from onnx_diagnostic.helpers.dot_helper import to_dot
94
+ from onnx_diagnostic.export.api import to_onnx
95
+ from onnx_diagnostic.torch_export_patches import torch_export_patches
96
+ from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
97
+
98
+ data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
99
+ model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
100
+ with torch_export_patches(patch_transformers=True):
101
+ em = to_onnx(model, kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo")
102
+ dot = to_dot(em.model_proto)
103
+ print("DOT-SECTION", dot)
104
+ """
105
+ _unique: Dict[int, int] = {}
106
+
107
+ def _mkn(obj: object) -> int:
108
+ id_obj = id(obj)
109
+ if id_obj in _unique:
110
+ return _unique[id_obj]
111
+ i = len(_unique)
112
+ _unique[id_obj] = i
113
+ return i
114
+
115
+ model = onnx.shape_inference.infer_shapes(model)
116
+
117
+ op_type_colors = {
118
+ "Shape": "#d2a81f",
119
+ "MatMul": "#ee9999",
120
+ "Transpose": "#ee99ee",
121
+ "Reshape": "#eeeeee",
122
+ "Squeeze": "#eeeeee",
123
+ "Unsqueeze": "#eeeeee",
124
+ }
125
+
126
+ edge_label = {}
127
+ for val in model.graph.value_info:
128
+ edge_label[val.name] = _make_edge_label(val, multi_line=True)
129
+
130
+ rows = [
131
+ "digraph {",
132
+ (
133
+ " graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, "
134
+ "ranksep=0.2, fontsize=8];"
135
+ ),
136
+ ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];',
137
+ " edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];",
138
+ ]
139
+ inputs = list(model.graph.input)
140
+ outputs = list(model.graph.output)
141
+ nodes = list(model.graph.node)
142
+ inits = list(model.graph.initializer)
143
+ tiny_inits = {}
144
+ name_to_ids = {}
145
+ for inp in inputs:
146
+ if not inp.name:
147
+ continue
148
+ lab = _make_edge_label(inp)
149
+ rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
150
+ name_to_ids[inp.name] = f"I_{_mkn(inp)}"
151
+ edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
152
+ for init in inits:
153
+ shape = tuple(init.dims)
154
+ if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
155
+ a = onh.to_array(init)
156
+ tiny_inits[init.name] = (
157
+ str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]"
158
+ )
159
+ else:
160
+ ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
161
+ rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
162
+ name_to_ids[init.name] = f"i_{_mkn(init)}"
163
+ edge_label[init.name] = ls
164
+ for node in nodes:
165
+ color = op_type_colors.get(node.op_type, "#cccccc")
166
+ label = _make_node_label(node, tiny_inits)
167
+ rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
168
+ name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o})
169
+
170
+ # nodes
171
+ done = set()
172
+ for node in nodes:
173
+ names = list(node.input)
174
+ for i in names:
175
+ if not i or i in tiny_inits:
176
+ continue
177
+ if i not in name_to_ids:
178
+ raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")
179
+ edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
180
+ if edge in done:
181
+ continue
182
+ done.add(edge)
183
+ lab = edge_label.get(i, "")
184
+ if lab:
185
+ ls = ",".join([f'label="{lab}"'])
186
+ lab = f" [{ls}]"
187
+ rows.append(f" {edge[0]} -> {edge[1]}{lab};")
188
+ if node.op_type in {"Scan", "Loop", "If"}:
189
+ unique = set()
190
+ for att in node.attribute:
191
+ if att.type == onnx.AttributeProto.GRAPH:
192
+ unique |= _get_hidden_inputs(att.g)
193
+ for i in unique:
194
+ edge = name_to_ids[i], _mkn(node) # type: ignore[assignment]
195
+ if edge in done:
196
+ continue
197
+ done.add(edge)
198
+ rows.append(f" {edge[0]} -> {edge[1]} [style=dotted];")
199
+
200
+ # outputs
201
+ for out in outputs:
202
+ if not out.name:
203
+ continue
204
+ lab = _make_edge_label(out)
205
+ rows.append(f' O_{_mkn(out)} [label="{out.name}\\n{lab}", fillcolor="#aaaaee"];')
206
+ edge = name_to_ids[out.name], f"O_{_mkn(out)}"
207
+ rows.append(f" {edge[0]} -> {edge[1]};")
208
+
209
+ rows.append("}")
210
+ return "\n".join(rows)
@@ -529,16 +529,20 @@ def string_type(
529
529
  return "OV(NO-NUMPY:FIXIT)"
530
530
  if verbose:
531
531
  print(f"[string_type] V4:{type(obj)}")
532
- return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
532
+ dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
533
+ return (
534
+ f"{dev}OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
535
+ )
533
536
  dt = obj.element_type()
534
537
  shape = obj.shape()
538
+ dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
535
539
  if with_shape:
536
540
  if verbose:
537
541
  print(f"[string_type] V5:{type(obj)}")
538
- return f"OV{dt}s{'x'.join(map(str, shape))}"
542
+ return f"{dev}OV{dt}s{'x'.join(map(str, shape))}"
539
543
  if verbose:
540
544
  print(f"[string_type] V6:{type(obj)}")
541
- return f"OV{dt}r{len(shape)}"
545
+ return f"{dev}OV{dt}r{len(shape)}"
542
546
 
543
547
  # others classes
544
548
 
@@ -990,7 +994,7 @@ def max_diff(
990
994
  _index: int = 0,
991
995
  allow_unique_tensor_with_list_of_one_element: bool = True,
992
996
  hist: Optional[Union[bool, List[float]]] = None,
993
- ) -> Dict[str, Union[float, int, Tuple[int, ...]]]:
997
+ ) -> Dict[str, Union[float, int, Tuple[Any, ...]]]:
994
998
  """
995
999
  Returns the maximum discrepancy.
996
1000
 
@@ -1015,6 +1019,7 @@ def max_diff(
1015
1019
  output, this number will be the number of elements
1016
1020
  of this output
1017
1021
  * dnan: difference in the number of nan
1022
+ * dev: tensor on the same device, if applicable
1018
1023
 
1019
1024
  You may use :func:`string_diff` to display the discrepancies in one string.
1020
1025
  """
@@ -1167,7 +1172,7 @@ def max_diff(
1167
1172
 
1168
1173
  if verbose >= 6:
1169
1174
  print(f"[max_diff] list,tuple,6: {string_type(expected)} ? {string_type(got)}")
1170
- am, rm, sm, n, dn, drep = 0, 0, 0.0, 0.0, 0, None
1175
+ am, rm, sm, n, dn, drep, dd = 0, 0, 0.0, 0.0, 0, None, None
1171
1176
  for ip, (e, g) in enumerate(zip(expected, got)):
1172
1177
  d = max_diff(
1173
1178
  e,
@@ -1199,7 +1204,15 @@ def max_diff(
1199
1204
  else:
1200
1205
  for k, v in d["rep"].items():
1201
1206
  drep[k] += v
1207
+ if "dev" in d and d["dev"] is not None:
1208
+ if dd is None:
1209
+ dd = d["dev"]
1210
+ else:
1211
+ dd += d["dev"] # type: ignore[operator]
1212
+
1202
1213
  res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
1214
+ if dd is not None:
1215
+ res["dev"] = dd
1203
1216
  if drep:
1204
1217
  res["rep"] = drep
1205
1218
  return res # type: ignore
@@ -1233,33 +1246,42 @@ def max_diff(
1233
1246
  import torch
1234
1247
 
1235
1248
  if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
1249
+ dev = None
1236
1250
  if isinstance(expected, torch.Tensor):
1237
1251
  from .torch_helper import to_numpy
1238
1252
 
1253
+ dev = 0 if expected.device.type == "cpu" else 1
1239
1254
  expected = to_numpy(expected)
1240
1255
  if isinstance(got, torch.Tensor):
1241
1256
  from .torch_helper import to_numpy
1242
1257
 
1258
+ dev = 0 if got.device.type == "cpu" else 1
1243
1259
  got = to_numpy(got)
1244
1260
  if verbose >= 6:
1245
1261
  print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1246
1262
 
1247
1263
  if _index < begin or (end != -1 and _index >= end):
1248
1264
  # out of boundary
1249
- return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1265
+ res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1266
+ if dev is not None:
1267
+ res["dev"] = dev # type: ignore[operator]
1268
+ return res # type: ignore[return-value]
1250
1269
  if isinstance(expected, (int, float)):
1251
1270
  if isinstance(got, np.ndarray) and len(got.shape) == 0:
1252
1271
  got = float(got)
1253
1272
  if isinstance(got, (int, float)):
1254
1273
  if expected == got:
1255
1274
  return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1256
- return dict(
1275
+ res = dict(
1257
1276
  abs=abs(expected - got),
1258
1277
  rel=abs(expected - got) / (abs(expected) + 1e-5),
1259
1278
  sum=abs(expected - got),
1260
1279
  n=1,
1261
1280
  dnan=0,
1262
1281
  )
1282
+ if dev is not None:
1283
+ res["dev"] = dev
1284
+ return res # type: ignore[return-value]
1263
1285
  return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1264
1286
  if expected.dtype in (np.complex64, np.complex128):
1265
1287
  if got.dtype == expected.dtype:
@@ -1339,6 +1361,8 @@ def max_diff(
1339
1361
  res: Dict[str, float] = dict( # type: ignore
1340
1362
  abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1341
1363
  )
1364
+ if dev is not None:
1365
+ res["dev"] = dev
1342
1366
  if hist:
1343
1367
  if isinstance(hist, bool):
1344
1368
  hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
@@ -1352,9 +1376,14 @@ def max_diff(
1352
1376
  if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
1353
1377
  if verbose >= 6:
1354
1378
  print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1379
+ dev = 0 if expected.device == got.device else 1
1355
1380
  if _index < begin or (end != -1 and _index >= end):
1356
1381
  # out of boundary
1357
- return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1382
+ if verbose >= 10:
1383
+ if debug_info:
1384
+ print("\n".join(debug_info))
1385
+ print("[max_diff] out of boundary")
1386
+ return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0, dev=dev)
1358
1387
  if expected.dtype in (torch.complex64, torch.complex128):
1359
1388
  if got.dtype == expected.dtype:
1360
1389
  got = torch.view_as_real(got)
@@ -1448,31 +1477,63 @@ def max_diff(
1448
1477
  )
1449
1478
 
1450
1479
  res: Dict[str, float] = dict( # type: ignore
1451
- abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1480
+ abs=abs_diff,
1481
+ rel=rel_diff,
1482
+ sum=sum_diff,
1483
+ n=n_diff,
1484
+ dnan=nan_diff,
1485
+ argm=argm,
1486
+ dev=dev,
1452
1487
  )
1453
1488
  if hist:
1454
- if isinstance(hist, bool):
1455
- hist = torch.tensor(
1456
- [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1457
- )
1458
- hist = hist.to(diff.device)
1459
- ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1460
- cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1461
- res["rep"] = dict(
1462
- zip(
1463
- [f">{x}" for x in hist],
1464
- [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1489
+ if isinstance(hist, list) and len(hist) == 1:
1490
+ res["rep"] = {f">{hist[0]}": (diff > hist[0]).sum().item()}
1491
+ elif isinstance(hist, list) and len(hist) == 2:
1492
+ res["rep"] = {
1493
+ f">{hist[0]}": (diff > hist[0]).sum().item(),
1494
+ f">{hist[1]}": (diff > hist[1]).sum().item(),
1495
+ }
1496
+ else:
1497
+ if isinstance(hist, bool):
1498
+ hist = torch.tensor(
1499
+ [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1500
+ )
1501
+ hist = torch.tensor(hist).to(diff.device)
1502
+ ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1503
+ cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1504
+ res["rep"] = dict(
1505
+ zip(
1506
+ [f">{x}" for x in hist],
1507
+ [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1508
+ )
1465
1509
  )
1466
- )
1467
1510
  return res # type: ignore
1468
1511
 
1512
+ if isinstance(expected, int) and isinstance(got, torch.Tensor):
1513
+ # a size
1514
+ if verbose >= 6:
1515
+ print(f"[max_diff] int: {string_type(expected)} ? {string_type(got)}")
1516
+ if got.shape != tuple():
1517
+ return dict( # type: ignore
1518
+ abs=np.inf,
1519
+ rel=np.inf,
1520
+ sum=np.inf,
1521
+ n=np.inf,
1522
+ dnan=np.inf,
1523
+ argm=np.inf,
1524
+ )
1525
+ return dict( # type: ignore
1526
+ abs=abs(expected - got.item()),
1527
+ rel=abs((expected - got.item()) / max(1, expected)),
1528
+ sum=abs(expected - got.item()),
1529
+ n=1,
1530
+ dnan=0,
1531
+ )
1532
+
1469
1533
  if "SquashedNormal" in expected.__class__.__name__:
1470
1534
  if verbose >= 6:
1471
1535
  print(f"[max_diff] SquashedNormal: {string_type(expected)} ? {string_type(got)}")
1472
- values = (
1473
- expected.mean.detach().to("cpu"),
1474
- expected.scale.detach().to("cpu"),
1475
- )
1536
+ values = (expected.mean, expected.scale)
1476
1537
  return max_diff(values, got, debug_info=_debug("SquashedNormal"), **_dkws)
1477
1538
 
1478
1539
  if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES:
@@ -1677,7 +1738,7 @@ def max_diff(
1677
1738
 
1678
1739
  raise AssertionError(
1679
1740
  f"Not implemented with implemented with expected="
1680
- f"{string_type(expected)}, got={string_type(got)},\n"
1741
+ f"{string_type(expected)} ({type(expected)}), got={string_type(got)},\n"
1681
1742
  f"level={level}"
1682
1743
  )
1683
1744
 
@@ -1685,6 +1746,9 @@ def max_diff(
1685
1746
  def string_diff(diff: Dict[str, Any]) -> str:
1686
1747
  """Renders discrepancies return by :func:`max_diff` into one string."""
1687
1748
  # dict(abs=, rel=, sum=, n=n_diff, dnan=)
1749
+ if "dev" in diff:
1750
+ ddiff = {k: v for k, v in diff.items() if k != "dev"}
1751
+ return f"{string_diff(ddiff)}, dev={diff['dev']}"
1688
1752
  suffix = ""
1689
1753
  if "rep" in diff:
1690
1754
  rows = []
@@ -159,7 +159,9 @@ class MiniOnnxBuilder:
159
159
  """
160
160
  if not tensors:
161
161
  # empty list
162
- self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
162
+ self.nodes.append(
163
+ oh.make_node("SequenceEmpty", [], [name], dtype=TensorProto.FLOAT)
164
+ )
163
165
  tensor_type_proto = oh.make_tensor_type_proto(
164
166
  elem_type=TensorProto.FLOAT, shape=None
165
167
  )
@@ -28,10 +28,37 @@ def download_model_builder_to_cache(
28
28
  if file_path.exists():
29
29
  return file_path
30
30
 
31
+ builders = cache_dir / "builders"
32
+ if not builders.exists():
33
+ builders.mkdir(parents=True, exist_ok=True)
34
+
35
+ for subfile in [
36
+ "__init__.py",
37
+ "base.py",
38
+ "chatglm.py",
39
+ "ernie.py",
40
+ "gemma.py",
41
+ "gptoss.py",
42
+ "granite.py",
43
+ "llama.py",
44
+ "mistral.py",
45
+ "nemotron.py",
46
+ "olmo.py",
47
+ "phi.py",
48
+ "qwen.py",
49
+ "smollm.py",
50
+ ]:
51
+ u = f"{'/'.join(url.split('/')[:-1])}/builders/{subfile}"
52
+ response = requests.get(u)
53
+ response.raise_for_status()
54
+ with open(builders / subfile, "wb") as f:
55
+ f.write(response.content)
56
+
31
57
  response = requests.get(url)
32
58
  response.raise_for_status()
33
59
  with open(file_path, "wb") as f:
34
60
  f.write(response.content)
61
+
35
62
  return file_path
36
63
 
37
64
 
@@ -3,7 +3,7 @@ import json
3
3
  import os
4
4
  import sys
5
5
  import warnings
6
- from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
6
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
7
7
  import numpy as np
8
8
  import numpy.typing as npt
9
9
  import onnx
@@ -15,6 +15,7 @@ from onnx import (
15
15
  GraphProto,
16
16
  ModelProto,
17
17
  NodeProto,
18
+ OperatorSetIdProto,
18
19
  TensorProto,
19
20
  ValueInfoProto,
20
21
  load as onnx_load,
@@ -1195,3 +1196,104 @@ def shadowing_names(
1195
1196
  existing |= not_empty
1196
1197
  created |= not_empty
1197
1198
  return shadow, post_shadow, created
1199
+
1200
+
1201
+ def extract_subset_of_nodes(
1202
+ model: ModelProto,
1203
+ name: str,
1204
+ node_index: Optional[int] = None,
1205
+ cut_points: Optional[Set[str]] = None,
1206
+ ) -> List[NodeProto]:
1207
+ """
1208
+ Extracts the minimal subgraphs which can produce the output ``name``
1209
+ knowing ``cut_points``.
1210
+
1211
+ :param model: original model
1212
+ :param name: result name
1213
+ :param node_index: if the node index is known, otherwise searches for it
1214
+ :param cut_points: the known results or input name otherwise
1215
+ :return: minimal list of nodes
1216
+ """
1217
+ if node_index is None:
1218
+ for i, node in enumerate(model.graph.node):
1219
+ if name in node.output:
1220
+ node_index = i
1221
+ break
1222
+ assert (
1223
+ node_index is not None
1224
+ and node_index < len(model.graph.node)
1225
+ and name in model.graph.node[node_index].output
1226
+ ), f"node_index is still empty or wrong for result {name!r}"
1227
+ if cut_points is None:
1228
+ cut_points = {n.name for n in model.graph.input} | {
1229
+ n.name for n in model.graph.initializer
1230
+ }
1231
+ elif model.graph.initializer:
1232
+ cut_points = cut_points | {n.name for n in model.graph.initializer}
1233
+
1234
+ node = model.graph.node[node_index]
1235
+ selected = {node_index}
1236
+ current_node_index = node_index
1237
+ current_input_index = 0
1238
+ intermediate = {name}
1239
+ inputs = set(k for k in node.input if k)
1240
+ while not (inputs <= cut_points) and current_node_index >= 0:
1241
+ node = model.graph.node[current_node_index]
1242
+ if current_input_index == 0:
1243
+ needs = [o for o in node.output if o in intermediate and o not in cut_points]
1244
+ if needs:
1245
+ selected.add(current_node_index)
1246
+ else:
1247
+ current_node_index -= 1
1248
+ continue
1249
+ res = node.input[current_input_index]
1250
+ if res not in cut_points:
1251
+ intermediate.add(res)
1252
+ current_input_index += 1
1253
+ if current_input_index >= len(node.input):
1254
+ current_node_index -= 1
1255
+ current_input_index = 0
1256
+
1257
+ return [model.graph.node[i] for i in sorted(selected)]
1258
+
1259
+
1260
+ def make_submodel(
1261
+ nodes: List[NodeProto],
1262
+ ir_version: int,
1263
+ opset_imports: List[OperatorSetIdProto],
1264
+ output_names: List[str],
1265
+ type_rank_fn: Callable[[str], Tuple[int, int]],
1266
+ ) -> ModelProto:
1267
+ """
1268
+ Creates a model with the given list of nodes.
1269
+ It computes the minimum list of inputs needed for this model.
1270
+ The function assumes the nodes are sorted.
1271
+ It does not handle yet subgraphs.
1272
+
1273
+ :param nodes: list of nodes
1274
+ :param ir_version: ir version
1275
+ :param opset_imports: opset import
1276
+ :param output_names: desired outputs
1277
+ :param function: function returning the type and the rank of a result
1278
+ :return: model proto
1279
+ """
1280
+
1281
+ def _mkv_(name, itype, irank):
1282
+ return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)])
1283
+
1284
+ not_known: Set[str] = set()
1285
+ for node in nodes[::-1]:
1286
+ not_known -= set(node.output)
1287
+ not_known |= set(node.input)
1288
+
1289
+ model = oh.make_model(
1290
+ oh.make_graph(
1291
+ nodes,
1292
+ "submodel",
1293
+ [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
1294
+ [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
1295
+ ),
1296
+ ir_version=ir_version,
1297
+ opset_imports=opset_imports,
1298
+ )
1299
+ return model