onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0

onnx_diagnostic/ext_test_case.py
@@ -9,6 +9,7 @@ import itertools
 import logging
 import os
 import re
+import shutil
 import sys
 import unittest
 import warnings
@@ -63,7 +64,7 @@ def skipif_ci_apple(msg) -> Callable:
     return lambda x: x


-def unit_test_going():
+def unit_test_going() -> bool:
    """
    Enables a flag telling the script that it is running during testing.
    Avoids unit tests being very long.
@@ -147,7 +148,7 @@ def hide_stdout(f: Optional[Callable] = None) -> Callable:

    def wrapper(fct):
        def call_f(self):
-            if os.environ.get("UNHIDE", ""):
+            if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"):
                fct(self)
                return
            st = StringIO()
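
The stricter check changes behavior for falsy-looking values: previously any non-empty string unhid the output, so UNHIDE=0 also disabled the capture. A minimal sketch of the new predicate in isolation:

    import os

    os.environ["UNHIDE"] = "0"
    # old check: bool(os.environ.get("UNHIDE", "")) -> True, output was unhidden
    # new check: membership test -> False, stdout stays captured
    print(os.environ.get("UNHIDE", "") in (1, "1", "True", "true"))  # False
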
@@ -609,6 +610,21 @@ def requires_onnxruntime(version: str, msg: str = "") -> Callable:
    return lambda x: x


+def has_onnxruntime(version: str, msg: str = "") -> bool:
+    """Tells if :epkg:`onnxruntime` is recent enough."""
+    import packaging.version as pv
+    import onnxruntime
+
+    if not hasattr(onnxruntime, "__version__"):
+        # development version
+        return True
+
+    if pv.Version(onnxruntime.__version__) < pv.Version(version):
+        msg = f"onnxruntime version {onnxruntime.__version__} < {version}: {msg}"
+        return False
+    return True
+
+
 def has_onnxruntime_training(push_back_batch: bool = False):
    """Tells if onnxruntime_training is installed."""
    try:
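
Unlike requires_onnxruntime, which returns a skip decorator, the new has_onnxruntime returns a plain boolean, so it composes with standard unittest tooling. A short sketch, assuming an illustrative version number and test body:

    import unittest
    from onnx_diagnostic.ext_test_case import ExtTestCase, has_onnxruntime

    class TestRecentOrt(ExtTestCase):
        # skipped when the installed onnxruntime is older than 1.18
        @unittest.skipIf(not has_onnxruntime("1.18"), reason="onnxruntime too old")
        def test_feature(self):
            self.assertTrue(has_onnxruntime("0.1"))
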
@@ -742,8 +758,15 @@ class ExtTestCase(unittest.TestCase):
    _warns: List[Tuple[str, int, Warning]] = []
    _todos: List[Tuple[Callable, str]] = []

+    def unit_test_going(self) -> bool:
+        """
+        Enables a flag telling the script that it is running during testing.
+        Avoids unit tests being very long.
+        """
+        return unit_test_going()
+
    @property
-    def verbose(self):
+    def verbose(self) -> int:
        "Returns the value of environment variable ``VERBOSE``."
        return int(os.environ.get("VERBOSE", "0"))

@@ -768,13 +791,13 @@ class ExtTestCase(unittest.TestCase):
        cls._todos.append((f, msg))

    @classmethod
-    def ort(cls):
+    def ort(cls) -> unittest.__class__:
        import onnxruntime

        return onnxruntime

    @classmethod
-    def to_onnx(self, *args, **kwargs):
+    def to_onnx(self, *args, **kwargs) -> "ModelProto":  # noqa: F821
        from experimental_experiment.torch_interpreter import to_onnx

        return to_onnx(*args, **kwargs)
@@ -806,18 +829,29 @@ class ExtTestCase(unittest.TestCase):
            os.makedirs(folder)
        return folder

-    def dump_onnx(
-        self,
-        name: str,
-        proto: Any,
-        folder: Optional[str] = None,
-    ) -> str:
+    def clean_dump(self, folder: str = "dump_test"):
+        """Cleans this folder."""
+        for item in os.listdir(folder):
+            item_path = os.path.join(folder, item)
+            if os.path.isfile(item_path) or os.path.islink(item_path):
+                os.remove(item_path)
+            elif os.path.isdir(item_path):
+                shutil.rmtree(item_path)
+
+    def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
        """Dumps an onnx file."""
        fullname = self.get_dump_file(name, folder=folder)
        with open(fullname, "wb") as f:
            f.write(proto.SerializeToString())
        return fullname

+    def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str:
+        """Dumps text in a file."""
+        fullname = self.get_dump_file(name, folder=folder)
+        with open(fullname, "w") as f:
+            f.write(text)
+        return fullname
+
    def assertExists(self, name):
        """Checks the existence of a file."""
        if not os.path.exists(name):
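
A minimal sketch combining the new dump helpers inside a test; file name and content are illustrative, and clean_dump assumes its target folder already exists:

    from onnx_diagnostic.ext_test_case import ExtTestCase

    class TestDumpHelpers(ExtTestCase):
        def test_dump_text(self):
            # writes under the dump folder and returns the full path
            path = self.dump_text("report.txt", "all checks passed")
            self.assertExists(path)
            # removes files and subfolders left by previous runs
            self.clean_dump()  # default folder is "dump_test"
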
@@ -1094,10 +1128,15 @@ class ExtTestCase(unittest.TestCase):
            value = numpy.array(value).astype(expected.dtype)
        self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

-    def check_ort(self, onx: "onnx.ModelProto") -> bool:  # noqa: F821
+    def check_ort(
+        self, onx: Union[str, "onnx.ModelProto"]  # noqa: F821
+    ) -> "onnxruntime.InferenceSession":  # noqa: F821
        from onnxruntime import InferenceSession

-        return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
+        return InferenceSession(
+            onx if isinstance(onx, str) else onx.SerializeToString(),
+            providers=["CPUExecutionProvider"],
+        )

    def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
        """In the name"""
@@ -1137,7 +1176,7 @@ class ExtTestCase(unittest.TestCase):
        if not full.endswith(suffix):
            raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")

-    def capture(self, fct: Callable):
+    def capture(self, fct: Callable) -> Tuple[Any, str, str]:
        """
        Runs a function and captures standard output and error.

@@ -1179,9 +1218,9 @@
    def assert_onnx_disc(
        self,
        test_name: str,
-        proto: "onnx.ModelProto",  # noqa: F821
+        proto: Union[str, "onnx.ModelProto"],  # noqa: F821
        model: "torch.nn.Module",  # noqa: F821
-        inputs: Union[Tuple[Any], Dict[str, Any]],
+        inputs: Union[Tuple[Any], Dict[str, Any], List[Any]],
        verbose: int = 0,
        atol: float = 1e-5,
        rtol: float = 1e-3,
@@ -1189,6 +1228,7 @@
        expected: Optional[Any] = None,
        use_ort: bool = False,
        ort_optimized_graph: bool = False,
+        ep: Optional[Union["torch.export.ExportedProgram", str]] = None,  # noqa: F821
        **kwargs,
    ):
        """
@@ -1208,6 +1248,7 @@
        :param copy_inputs: to copy the inputs
        :param use_ort: use :class:`onnxruntime.InferenceSession`
        :param ort_optimized_graph: dumps the optimized onnxruntime graph
+        :param ep: exported program (or saved exported program)
        :param kwargs: arguments sent to
            :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
        """
@@ -1223,71 +1264,135 @@
        name = f"{test_name}.onnx"
        if verbose:
            print(f"[{vname}] save the onnx model into {name!r}")
+        model_file = None
        if isinstance(proto, str):
+            model_file = proto
            name = proto
            proto = onnx.load(name)
-        else:
+        elif not self.unit_test_going():
            assert isinstance(
                proto, onnx.ModelProto
            ), f"Unexpected type {type(proto)} for proto"
            name = self.dump_onnx(name, proto)
-        if verbose:
+        if verbose and not self.unit_test_going():
            print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
        if verbose:
            print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
+
+        if not isinstance(inputs, list):
+            inputs = [inputs]
+            if expected is not None:
+                expected = [expected]
+
+        gots = []
        if use_ort:
            assert isinstance(
                proto, onnx.ModelProto
            ), f"Unexpected type {type(proto)} for proto"
-            feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
            import onnxruntime

-            if verbose:
-                print(f"[{vname}] create onnxruntime.InferenceSession")
            options = onnxruntime.SessionOptions()
            if ort_optimized_graph:
                options.optimized_model_filepath = f"{name}.optort.onnx"
+            if "log_severity_level" in kwargs:
+                options.log_severity_level = kwargs["log_severity_level"]
+            if "log_verbosity_level" in kwargs:
+                options.log_verbosity_level = kwargs["log_verbosity_level"]
+            providers = kwargs.get("providers", ["CPUExecutionProvider"])
+            if verbose:
+                print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
            sess = onnxruntime.InferenceSession(
-                proto.SerializeToString(),
-                options,
-                providers=kwargs.get("providers", ["CPUExecutionProvider"]),
+                model_file or proto.SerializeToString(), options, providers=providers
            )
-            if verbose:
-                print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
-            got = sess.run(None, feeds)
+            for inp in inputs:
+                feeds = make_feeds(proto, inp, use_numpy=True, copy=True)
+                if verbose:
+                    print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
+                got = sess.run(None, feeds)
+                gots.append(got)
        else:
-            feeds = make_feeds(proto, inputs, copy=True)
            if verbose:
                print(f"[{vname}] create InferenceSessionForTorch")
            sess = InferenceSessionForTorch(proto, **kwargs)
-            if verbose:
-                print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
-            got = sess.run(None, feeds)
+            for inp in inputs:
+                feeds = make_feeds(proto, inp, copy=True)
+                if verbose:
+                    print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
+                got = sess.run(None, feeds)
+                gots.append(got)
        if verbose:
            print(f"[{vname}] compute expected values")
+
        if expected is None:
            if copy_inputs:
-                expected = (
-                    model(*copy.deepcopy(inputs))
-                    if isinstance(inputs, tuple)
-                    else model(**copy.deepcopy(inputs))
-                )
+                expected = [
+                    (
+                        model(*copy.deepcopy(inp))
+                        if isinstance(inp, tuple)
+                        else model(**copy.deepcopy(inp))
+                    )
+                    for inp in inputs
+                ]
            else:
-                expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
+                expected = [
+                    model(*inp) if isinstance(inp, tuple) else model(**inp) for inp in inputs
+                ]
+
        if verbose:
            print(f"[{vname}] expected {string_type(expected, **kws)}")
            print(f"[{vname}] obtained {string_type(got, **kws)}")
-        diff = max_diff(expected, got, flatten=True)
-        if verbose:
-            print(f"[{vname}] diff {string_diff(diff)}")
-        assert (
-            isinstance(diff["abs"], float)
-            and isinstance(diff["rel"], float)
-            and not numpy.isnan(diff["abs"])
-            and diff["abs"] <= atol
-            and not numpy.isnan(diff["rel"])
-            and diff["rel"] <= rtol
-        ), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
+
+        if ep:
+            if isinstance(ep, str):
+                if verbose:
+                    print(f"[{vname}] load exported program {ep!r}")
+                import torch
+
+                ep = torch.export.load(ep)
+
+            ep_model = ep.module()  # type: ignore[union-attr]
+            for expe, inp, got in zip(expected, inputs, gots):
+                ep_inputs = copy.deepcopy(inp) if copy_inputs else inp
+                ep_expected = (
+                    ep_model(*copy.deepcopy(ep_inputs))
+                    if isinstance(ep_inputs, tuple)
+                    else ep_model(**copy.deepcopy(ep_inputs))
+                )
+                if verbose:
+                    print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
+                ep_diff = max_diff(expe, ep_expected, hist=[0.1, 0.01])
+                if verbose:
+                    print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
+                assert (
+                    isinstance(ep_diff["abs"], float)
+                    and isinstance(ep_diff["rel"], float)
+                    and not numpy.isnan(ep_diff["abs"])
+                    and ep_diff["abs"] <= atol
+                    and not numpy.isnan(ep_diff["rel"])
+                    and ep_diff["rel"] <= rtol
+                ), (
+                    f"discrepancies in {test_name!r} between the exported program "
+                    f"and the exported model diff={string_diff(ep_diff)}"
+                )
+                ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
+                if verbose:
+                    print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
+
+        for expe, got in zip(expected, gots):
+            diff = max_diff(expe, got, flatten=True, hist=[0.1, 0.01])
+            if verbose:
+                print(f"[{vname}] diff {string_diff(diff)}")
+            assert (
+                isinstance(diff["abs"], float)
+                and isinstance(diff["rel"], float)
+                and not numpy.isnan(diff["abs"])
+                and diff["abs"] <= atol
+                and not numpy.isnan(diff["rel"])
+                and diff["rel"] <= rtol
+            ), (
+                f"discrepancies in {test_name!r} between the model and "
+                f"the onnx model diff={string_diff(diff)}"
+            )

    def _debug(self):
        "Tells if DEBUG=1 is set up."
@@ -1298,6 +1403,16 @@

        return string_type(*args, **kwargs)

+    def max_diff(self, *args, **kwargs):
+        from .helpers import max_diff
+
+        return max_diff(*args, **kwargs)
+
+    def use_dyn_not_str(self, *args, **kwargs):
+        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+        return use_dyn_not_str(*args, **kwargs)
+
    def subloop(self, *args, verbose: int = 0):
        "Loops over elements and calls :meth:`unittest.TestCase.subTest`."
        if len(args) == 1:
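
Both wrappers simply forward to existing helpers. A small sketch, assuming use_dyn_not_str keeps its documented behavior of replacing string axis names with dynamic export dimensions; the axis names are illustrative:

    # inside an ExtTestCase test method
    ds = {"input_ids": {0: "batch", 1: "seq_length"}}
    ds_dyn = self.use_dyn_not_str(ds)      # strings -> torch.export.Dim.DYNAMIC
    report = self.max_diff(expected, got)  # returns a dict with "abs" and "rel" keys
    self.assertLessEqual(report["abs"], 1e-5)
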

onnx_diagnostic/helpers/cache_helper.py
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
    start = 0
    end = 0
    subtrees = []
-    for subspec in spec.children_specs:
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
        end += subspec.num_leaves
        value = subspec.unflatten(flat[start:end])
        value = flatten_unflatten_for_dynamic_shapes(
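
The guarded access appears to work around a pytree API change: newer torch versions expose TreeSpec.children() while older ones only provide the children_specs attribute. The same pattern in isolation:

    import torch.utils._pytree as pytree

    _, spec = pytree.tree_flatten({"a": 0, "b": (1, 2)})
    # prefer the newer method, fall back to the older attribute
    children = spec.children() if hasattr(spec, "children") else spec.children_specs
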

onnx_diagnostic/helpers/dot_helper.py (new file)
@@ -0,0 +1,222 @@
+from typing import Dict
+import numpy as np
+import onnx
+import onnx.numpy_helper as onh
+from ..reference import ExtendedReferenceEvaluator as Inference
+from .onnx_helper import onnx_dtype_name, pretty_onnx, get_hidden_inputs
+
+
+def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
+    els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "\\n("]
+    ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
+    for att in node.attribute:
+        if att.name == "to":
+            ee.append(f"{att.name}={onnx_dtype_name(att.i)}")
+        elif att.name in {"to", "axis", "value_int", "stash_type", "start", "end"}:
+            ee.append(f"{att.name}={att.i}")
+        elif att.name in {"value_float"}:
+            ee.append(f"{att.name}={att.f}")
+        elif att.name in {"value_floats"}:
+            ee.append(f"{att.name}={att.floats}")
+        elif att.name in {"value_ints", "perm"}:
+            ee.append(f"{att.name}={att.ints}")
+    els.append(", ".join(ee))
+    els.append(")")
+    if node.op_type == "Constant":
+        els.extend([" -> ", node.output[0]])
+    res = "".join(els)
+    if len(res) < 40:
+        return res.replace("\\n(", "(")
+    return res
+
+
+def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
+    itype = value_info.type.tensor_type.elem_type
+    if itype == onnx.TensorProto.UNDEFINED:
+        return ""
+    shape = tuple(
+        d.dim_param if d.dim_param else d.dim_value
+        for d in value_info.type.tensor_type.shape.dim
+    )
+    res = [
+        str(a)
+        for a in [("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape]
+    ]
+    sshape = ",".join(res)
+    if multi_line and len(sshape) > 30:
+        sshape = ",\\n".join(res)
+    return f"{onnx_dtype_name(itype)}({sshape})"
+
+
+def to_dot(model: onnx.ModelProto) -> str:
+    """
+    Converts a model into a dot graph.
+    Here is an example:
+
+    .. gdot::
+        :script: DOT-SECTION
+        :process:
+
+        from onnx_diagnostic.helpers.dot_helper import to_dot
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
+        dot = to_dot(em.model_proto)
+        print("DOT-SECTION", dot)
+
+    Or this one obtained with :func:`torch.onnx.export`.
+
+    .. gdot::
+        :script: DOT-SECTION
+        :process:
+
+        from onnx_diagnostic.helpers.dot_helper import to_dot
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            em = to_onnx(model, kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo")
+        dot = to_dot(em.model_proto)
+        print("DOT-SECTION", dot)
+    """
+    _unique: Dict[int, int] = {}
+
+    def _mkn(obj: object) -> int:
+        id_obj = id(obj)
+        if id_obj in _unique:
+            return _unique[id_obj]
+        i = len(_unique)
+        _unique[id_obj] = i
+        return i
+
+    model = onnx.shape_inference.infer_shapes(model)
+
+    op_type_colors = {
+        "Shape": "#d2a81f",
+        "MatMul": "#ee9999",
+        "Transpose": "#ee99ee",
+        "Reshape": "#eeeeee",
+        "Squeeze": "#eeeeee",
+        "Unsqueeze": "#eeeeee",
+    }
+
+    edge_label = {}
+    for val in model.graph.value_info:
+        edge_label[val.name] = _make_edge_label(val, multi_line=True)
+
+    rows = [
+        "digraph {",
+        (
+            " graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, "
+            "ranksep=0.2, fontsize=8];"
+        ),
+        ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];',
+        " edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];",
+    ]
+    inputs = list(model.graph.input)
+    outputs = list(model.graph.output)
+    nodes = list(model.graph.node)
+    inits = list(model.graph.initializer)
+    tiny_inits = {}
+    name_to_ids = {}
+
+    for inp in inputs:
+        if not inp.name:
+            continue
+        lab = _make_edge_label(inp)
+        rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
+        name_to_ids[inp.name] = f"I_{_mkn(inp)}"
+        edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
+
+    # Small constant --> initializer
+    output_names = {n.name for n in outputs}
+    for node in nodes:
+        if node.op_type != "Constant" or node.output[0] in output_names:
+            continue
+        skip = False
+        for att in node.attribute:
+            if att.name == "value" and (
+                len(att.t.dims) > 1 or np.prod(tuple(att.t.dims)) > 10
+            ):
+                skip = True
+                break
+        if skip:
+            continue
+
+        sess = Inference(node)
+        value = sess.run(None, {})[0]
+        inits.append(onh.from_array(value, name=node.output[0]))
+
+    for init in inits:
+        if init.name in name_to_ids:
+            # hide optional inputs
+            continue
+        shape = tuple(init.dims)
+        if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
+            a = onh.to_array(init)
+            tiny_inits[init.name] = (
+                str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]"
+            )
+        else:
+            ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str, shape))})"
+            rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
+            name_to_ids[init.name] = f"i_{_mkn(init)}"
+            edge_label[init.name] = ls
+
+    for node in nodes:
+        if node.op_type == "Constant" and node.output[0] in tiny_inits:
+            continue
+        color = op_type_colors.get(node.op_type, "#cccccc")
+        label = _make_node_label(node, tiny_inits)
+        rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
+        name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o})
+
+    # edges
+    done = set()
+    for node in nodes:
+        names = list(node.input)
+        for i in names:
+            if not i or i in tiny_inits:
+                continue
+            if i not in name_to_ids:
+                raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")
+            edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
+            if edge in done:
+                continue
+            done.add(edge)
+            lab = edge_label.get(i, "")
+            if lab:
+                ls = ",".join([f'label="{lab}"'])
+                lab = f" [{ls}]"
+            rows.append(f" {edge[0]} -> {edge[1]}{lab};")
+        if node.op_type in {"Scan", "Loop", "If"}:
+            unique = set()
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    unique |= get_hidden_inputs(att.g)
+            for i in unique:
+                edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
+                if edge in done:
+                    continue
+                done.add(edge)
+                rows.append(f" {edge[0]} -> {edge[1]} [style=dotted];")
+
+    # outputs
+    for out in outputs:
+        if not out.name:
+            continue
+        lab = _make_edge_label(out)
+        rows.append(f' O_{_mkn(out)} [label="{out.name}\\n{lab}", fillcolor="#aaaaee"];')
+        edge = name_to_ids[out.name], f"O_{_mkn(out)}"
+        rows.append(f" {edge[0]} -> {edge[1]};")
+
+    rows.append("}")
+    return "\n".join(rows)
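
Beyond the documentation examples embedded in the docstring above, a short sketch of using the new helper directly; the file names are illustrative:

    import onnx
    from onnx_diagnostic.helpers.dot_helper import to_dot

    model = onnx.load("model.onnx")
    dot = to_dot(model)  # runs shape inference, then emits a graphviz digraph
    with open("model.dot", "w") as f:
        f.write(dot)
    # render with graphviz: dot -Tsvg model.dot -o model.svg
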