onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
@@ -94,6 +94,20 @@ def size_type(dtype: Any) -> int:
94
94
  raise AssertionError(f"Unexpected dtype={dtype}")
95
95
 
96
96
 
97
+ def _string_tensor(obj, cls: str, with_shape: bool, with_device: bool, verbose: int) -> str:
98
+ from .torch_helper import torch_dtype_to_onnx_dtype
99
+
100
+ i = torch_dtype_to_onnx_dtype(obj.dtype)
101
+ prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
102
+ if not with_shape:
103
+ if verbose:
104
+ print(f"[string_type] {cls}1:{type(obj)}")
105
+ return f"{prefix}{cls}{i}r{len(obj.shape)}"
106
+ if verbose:
107
+ print(f"[string_type] {cls}2:{type(obj)}")
108
+ return f"{prefix}{cls}{i}s{'x'.join(map(str, obj.shape))}"
109
+
110
+
97
111
  def string_type(
98
112
  obj: Any,
99
113
  with_shape: bool = False,
@@ -453,17 +467,7 @@ def string_type(
453
467
 
454
468
  # Tensors
455
469
  if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
456
- from .torch_helper import torch_dtype_to_onnx_dtype
457
-
458
- i = torch_dtype_to_onnx_dtype(obj.dtype)
459
- prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
460
- if not with_shape:
461
- if verbose:
462
- print(f"[string_type] F1:{type(obj)}")
463
- return f"{prefix}F{i}r{len(obj.shape)}"
464
- if verbose:
465
- print(f"[string_type] F2:{type(obj)}")
466
- return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
470
+ return _string_tensor(obj, "F", with_shape, with_device, verbose)
467
471
 
468
472
  if isinstance(obj, torch.Tensor):
469
473
  from .torch_helper import torch_dtype_to_onnx_dtype
@@ -529,16 +533,23 @@ def string_type(
529
533
  return "OV(NO-NUMPY:FIXIT)"
530
534
  if verbose:
531
535
  print(f"[string_type] V4:{type(obj)}")
532
- return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
536
+ dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
537
+ return (
538
+ f"{dev}OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
539
+ )
533
540
  dt = obj.element_type()
534
541
  shape = obj.shape()
542
+ dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
535
543
  if with_shape:
536
544
  if verbose:
537
545
  print(f"[string_type] V5:{type(obj)}")
538
- return f"OV{dt}s{'x'.join(map(str, shape))}"
546
+ return f"{dev}OV{dt}s{'x'.join(map(str, shape))}"
539
547
  if verbose:
540
548
  print(f"[string_type] V6:{type(obj)}")
541
- return f"OV{dt}r{len(shape)}"
549
+ return f"{dev}OV{dt}r{len(shape)}"
550
+
551
+ if obj.__class__.__name__ == "SymbolicTensor":
552
+ return _string_tensor(obj, "ST", with_shape, with_device, verbose)
542
553
 
543
554
  # others classes
544
555
 
@@ -990,7 +1001,7 @@ def max_diff(
990
1001
  _index: int = 0,
991
1002
  allow_unique_tensor_with_list_of_one_element: bool = True,
992
1003
  hist: Optional[Union[bool, List[float]]] = None,
993
- ) -> Dict[str, Union[float, int, Tuple[int, ...]]]:
1004
+ ) -> Dict[str, Union[float, int, Tuple[Any, ...]]]:
994
1005
  """
995
1006
  Returns the maximum discrepancy.
996
1007
 
@@ -1015,6 +1026,7 @@ def max_diff(
1015
1026
  output, this number will be the number of elements
1016
1027
  of this output
1017
1028
  * dnan: difference in the number of nan
1029
+ * dev: tensor on the same device, if applicable
1018
1030
 
1019
1031
  You may use :func:`string_diff` to display the discrepancies in one string.
1020
1032
  """
@@ -1167,7 +1179,7 @@ def max_diff(
1167
1179
 
1168
1180
  if verbose >= 6:
1169
1181
  print(f"[max_diff] list,tuple,6: {string_type(expected)} ? {string_type(got)}")
1170
- am, rm, sm, n, dn, drep = 0, 0, 0.0, 0.0, 0, None
1182
+ am, rm, sm, n, dn, drep, dd = 0, 0, 0.0, 0.0, 0, None, None
1171
1183
  for ip, (e, g) in enumerate(zip(expected, got)):
1172
1184
  d = max_diff(
1173
1185
  e,
@@ -1199,7 +1211,15 @@ def max_diff(
1199
1211
  else:
1200
1212
  for k, v in d["rep"].items():
1201
1213
  drep[k] += v
1214
+ if "dev" in d and d["dev"] is not None:
1215
+ if dd is None:
1216
+ dd = d["dev"]
1217
+ else:
1218
+ dd += d["dev"] # type: ignore[operator]
1219
+
1202
1220
  res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
1221
+ if dd is not None:
1222
+ res["dev"] = dd
1203
1223
  if drep:
1204
1224
  res["rep"] = drep
1205
1225
  return res # type: ignore
@@ -1233,33 +1253,42 @@ def max_diff(
1233
1253
  import torch
1234
1254
 
1235
1255
  if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
1256
+ dev = None
1236
1257
  if isinstance(expected, torch.Tensor):
1237
1258
  from .torch_helper import to_numpy
1238
1259
 
1260
+ dev = 0 if expected.device.type == "cpu" else 1
1239
1261
  expected = to_numpy(expected)
1240
1262
  if isinstance(got, torch.Tensor):
1241
1263
  from .torch_helper import to_numpy
1242
1264
 
1265
+ dev = 0 if got.device.type == "cpu" else 1
1243
1266
  got = to_numpy(got)
1244
1267
  if verbose >= 6:
1245
1268
  print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1246
1269
 
1247
1270
  if _index < begin or (end != -1 and _index >= end):
1248
1271
  # out of boundary
1249
- return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1272
+ res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1273
+ if dev is not None:
1274
+ res["dev"] = dev # type: ignore[operator]
1275
+ return res # type: ignore[return-value]
1250
1276
  if isinstance(expected, (int, float)):
1251
1277
  if isinstance(got, np.ndarray) and len(got.shape) == 0:
1252
1278
  got = float(got)
1253
1279
  if isinstance(got, (int, float)):
1254
1280
  if expected == got:
1255
1281
  return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1256
- return dict(
1282
+ res = dict(
1257
1283
  abs=abs(expected - got),
1258
1284
  rel=abs(expected - got) / (abs(expected) + 1e-5),
1259
1285
  sum=abs(expected - got),
1260
1286
  n=1,
1261
1287
  dnan=0,
1262
1288
  )
1289
+ if dev is not None:
1290
+ res["dev"] = dev
1291
+ return res # type: ignore[return-value]
1263
1292
  return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1264
1293
  if expected.dtype in (np.complex64, np.complex128):
1265
1294
  if got.dtype == expected.dtype:
@@ -1339,6 +1368,8 @@ def max_diff(
1339
1368
  res: Dict[str, float] = dict( # type: ignore
1340
1369
  abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1341
1370
  )
1371
+ if dev is not None:
1372
+ res["dev"] = dev
1342
1373
  if hist:
1343
1374
  if isinstance(hist, bool):
1344
1375
  hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
@@ -1352,9 +1383,14 @@ def max_diff(
1352
1383
  if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
1353
1384
  if verbose >= 6:
1354
1385
  print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1386
+ dev = 0 if expected.device == got.device else 1
1355
1387
  if _index < begin or (end != -1 and _index >= end):
1356
1388
  # out of boundary
1357
- return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1389
+ if verbose >= 10:
1390
+ if debug_info:
1391
+ print("\n".join(debug_info))
1392
+ print("[max_diff] out of boundary")
1393
+ return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0, dev=dev)
1358
1394
  if expected.dtype in (torch.complex64, torch.complex128):
1359
1395
  if got.dtype == expected.dtype:
1360
1396
  got = torch.view_as_real(got)
@@ -1448,31 +1484,63 @@ def max_diff(
1448
1484
  )
1449
1485
 
1450
1486
  res: Dict[str, float] = dict( # type: ignore
1451
- abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1487
+ abs=abs_diff,
1488
+ rel=rel_diff,
1489
+ sum=sum_diff,
1490
+ n=n_diff,
1491
+ dnan=nan_diff,
1492
+ argm=argm,
1493
+ dev=dev,
1452
1494
  )
1453
1495
  if hist:
1454
- if isinstance(hist, bool):
1455
- hist = torch.tensor(
1456
- [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1457
- )
1458
- hist = hist.to(diff.device)
1459
- ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1460
- cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1461
- res["rep"] = dict(
1462
- zip(
1463
- [f">{x}" for x in hist],
1464
- [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1496
+ if isinstance(hist, list) and len(hist) == 1:
1497
+ res["rep"] = {f">{hist[0]}": (diff > hist[0]).sum().item()}
1498
+ elif isinstance(hist, list) and len(hist) == 2:
1499
+ res["rep"] = {
1500
+ f">{hist[0]}": (diff > hist[0]).sum().item(),
1501
+ f">{hist[1]}": (diff > hist[1]).sum().item(),
1502
+ }
1503
+ else:
1504
+ if isinstance(hist, bool):
1505
+ hist = torch.tensor(
1506
+ [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1507
+ )
1508
+ hist = torch.tensor(hist).to(diff.device)
1509
+ ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1510
+ cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1511
+ res["rep"] = dict(
1512
+ zip(
1513
+ [f">{x}" for x in hist],
1514
+ [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1515
+ )
1465
1516
  )
1466
- )
1467
1517
  return res # type: ignore
1468
1518
 
1519
+ if isinstance(expected, int) and isinstance(got, torch.Tensor):
1520
+ # a size
1521
+ if verbose >= 6:
1522
+ print(f"[max_diff] int: {string_type(expected)} ? {string_type(got)}")
1523
+ if got.shape != tuple():
1524
+ return dict( # type: ignore
1525
+ abs=np.inf,
1526
+ rel=np.inf,
1527
+ sum=np.inf,
1528
+ n=np.inf,
1529
+ dnan=np.inf,
1530
+ argm=np.inf,
1531
+ )
1532
+ return dict( # type: ignore
1533
+ abs=abs(expected - got.item()),
1534
+ rel=abs((expected - got.item()) / max(1, expected)),
1535
+ sum=abs(expected - got.item()),
1536
+ n=1,
1537
+ dnan=0,
1538
+ )
1539
+
1469
1540
  if "SquashedNormal" in expected.__class__.__name__:
1470
1541
  if verbose >= 6:
1471
1542
  print(f"[max_diff] SquashedNormal: {string_type(expected)} ? {string_type(got)}")
1472
- values = (
1473
- expected.mean.detach().to("cpu"),
1474
- expected.scale.detach().to("cpu"),
1475
- )
1543
+ values = (expected.mean, expected.scale)
1476
1544
  return max_diff(values, got, debug_info=_debug("SquashedNormal"), **_dkws)
1477
1545
 
1478
1546
  if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES:
@@ -1677,7 +1745,7 @@ def max_diff(
1677
1745
 
1678
1746
  raise AssertionError(
1679
1747
  f"Not implemented with implemented with expected="
1680
- f"{string_type(expected)}, got={string_type(got)},\n"
1748
+ f"{string_type(expected)} ({type(expected)}), got={string_type(got)},\n"
1681
1749
  f"level={level}"
1682
1750
  )
1683
1751
 
@@ -1685,6 +1753,9 @@ def max_diff(
1685
1753
  def string_diff(diff: Dict[str, Any]) -> str:
1686
1754
  """Renders discrepancies return by :func:`max_diff` into one string."""
1687
1755
  # dict(abs=, rel=, sum=, n=n_diff, dnan=)
1756
+ if "dev" in diff:
1757
+ ddiff = {k: v for k, v in diff.items() if k != "dev"}
1758
+ return f"{string_diff(ddiff)}, dev={diff['dev']}"
1688
1759
  suffix = ""
1689
1760
  if "rep" in diff:
1690
1761
  rows = []
@@ -159,7 +159,9 @@ class MiniOnnxBuilder:
159
159
  """
160
160
  if not tensors:
161
161
  # empty list
162
- self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
162
+ self.nodes.append(
163
+ oh.make_node("SequenceEmpty", [], [name], dtype=TensorProto.FLOAT)
164
+ )
163
165
  tensor_type_proto = oh.make_tensor_type_proto(
164
166
  elem_type=TensorProto.FLOAT, shape=None
165
167
  )
@@ -28,10 +28,37 @@ def download_model_builder_to_cache(
28
28
  if file_path.exists():
29
29
  return file_path
30
30
 
31
+ builders = cache_dir / "builders"
32
+ if not builders.exists():
33
+ builders.mkdir(parents=True, exist_ok=True)
34
+
35
+ for subfile in [
36
+ "__init__.py",
37
+ "base.py",
38
+ "chatglm.py",
39
+ "ernie.py",
40
+ "gemma.py",
41
+ "gptoss.py",
42
+ "granite.py",
43
+ "llama.py",
44
+ "mistral.py",
45
+ "nemotron.py",
46
+ "olmo.py",
47
+ "phi.py",
48
+ "qwen.py",
49
+ "smollm.py",
50
+ ]:
51
+ u = f"{'/'.join(url.split('/')[:-1])}/builders/{subfile}"
52
+ response = requests.get(u)
53
+ response.raise_for_status()
54
+ with open(builders / subfile, "wb") as f:
55
+ f.write(response.content)
56
+
31
57
  response = requests.get(url)
32
58
  response.raise_for_status()
33
59
  with open(file_path, "wb") as f:
34
60
  f.write(response.content)
61
+
35
62
  return file_path
36
63
 
37
64