onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_models/code_sample.py

@@ -236,7 +236,7 @@ def code_sample(
         )
     )
     """
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
         model_kwargs=mop,
         subfolder=subfolder,
         add_second_input=False,
+        submodule=submodule,
     )
     if drop_inputs:
         update = {}
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:


 def _preprocess_model_id(
-    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
-) -> Tuple[str, Optional[str], bool, bool]:
+    model_id: str,
+    subfolder: Optional[str],
+    same_as_pretrained: bool,
+    use_pretrained: bool,
+    submodule: Optional[str] = None,
+) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
+    if "::" in model_id:
+        assert (
+            not submodule
+        ), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
+        model_id, submodule = model_id.split("::", maxsplit=1)
     if subfolder or "//" not in model_id:
-        return model_id, subfolder, same_as_pretrained, use_pretrained
+        return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
     spl = model_id.split("//")
     if spl[-1] == "pretrained":
-        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
-    return model_id, subfolder, same_as_pretrained, use_pretrained
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained, submodule
+    return model_id, subfolder, same_as_pretrained, use_pretrained, submodule


 def get_untrained_model_with_inputs(
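The `::` syntax is new in 0.8.4: a caller can select a submodule directly in the model id instead of passing `submodule=` explicitly. A minimal sketch of the parsing behaviour, derived from the hunk above (the model id is only an example):

    from onnx_diagnostic.torch_models.hghub.model_inputs import _preprocess_model_id

    # "model_id::submodule" is split once on "::"; everything after the first
    # "::" becomes the submodule path, so "id::model::decoder" keeps "model::decoder".
    res = _preprocess_model_id("arnir0/Tiny-LLM::model", None, False, False)
    print(res)  # ('arnir0/Tiny-LLM', None, False, False, 'model')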
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
+    submodule: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
         <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
         this function takes a configuration and a task (string)
         as arguments
+    :param submodule: use a submodule instead of the main model
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
             f"model_id={model_id!r}, preinstalled model is only available "
             f"if use_only_preinstalled is False."
         )
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     if verbose:
         print(
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+        if submodule:
+            print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
     if task is None:
         task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
     if verbose:
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)

+    if submodule:
+        path = submodule.split("::") if "::" in submodule else [submodule]
+        for p in path:
+            assert hasattr(model, p), (
+                f"Unable to find submodule {p!r} in in class {type(model)}, "
+                f"submodule={submodule!r}, possible candidates: "
+                f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
+            )
+            model = getattr(model, p)
+
+    if verbose:
+        print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
+
     sizes = compute_model_size(model)
     res["model"] = model
     res["configuration"] = config
onnx_diagnostic/torch_models/validate.py

@@ -349,13 +349,15 @@ def _prepare_validation(
     verbose,
     output_names,
     dump_folder,
+    submodule,
 ):
     main_validation_begin = time.perf_counter()
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     time_preprocess_model_id = time.perf_counter() - main_validation_begin
     patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
@@ -364,6 +366,7 @@ def _prepare_validation(
     summary.update(
         dict(
             version_model_id=model_id,
+            version_submodule=submodule,
             version_do_run=str(do_run),
             version_dtype=str(dtype or ""),
             version_device=str(device or ""),
@@ -444,6 +447,7 @@ def _prepare_validation(
         dump_folder,
         folder_name,
         patch_kwargs,
+        submodule,
     )


@@ -460,6 +464,7 @@ def _get_untrained_model_with_inputs(
     inputs2,
     quiet,
     dump_folder,
+    submodule,
 ):
     iop = input_options or {}
     mop = model_options or {}
@@ -480,6 +485,7 @@ def _get_untrained_model_with_inputs(
                 model_kwargs=mop,
                 subfolder=sub,
                 add_second_input=i2,
+                submodule=submodule,
             )
         )
     ),
@@ -671,7 +677,16 @@ def _call_exporter(
     do_run,
     output_names,
     exporter_options,
+    save_ep,
 ):
+    if save_ep and dump_folder:
+        for name in data:
+            if name.startswith("inputs"):
+                if verbose:
+                    print(f"[validate_model] -- dump {name!r}")
+                filename = os.path.join(dump_folder, f"{save_ep}.{name}.pt")
+                torch.save(data[name], filename)
+
     if exporter:
         expop = exporter_options or {}
         if verbose:
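With `save_ep` set, every entry of `data` whose name starts with "inputs" is written with torch.save before the export starts, so a failing export still leaves the input sets on disk. A hedged sketch of reading them back (folder and prefix are made-up values matching the pattern above):

    import glob, os
    import torch

    dump_folder, save_ep = "dump_test", "tinyllm_block"  # hypothetical values
    for filename in sorted(glob.glob(os.path.join(dump_folder, f"{save_ep}.inputs*.pt"))):
        # torch>=2.6 defaults to weights_only=True; these files hold nested
        # containers (dicts, caches), so weights_only=False may be required.
        inputs = torch.load(filename, weights_only=False)
        print(filename, type(inputs))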
@@ -711,6 +726,7 @@ def _call_exporter(
             dump_folder=dump_folder,
             output_names=output_names,
             exporter_options=expop,
+            save_ep=save_ep,
         )
     else:
         data["inputs_export"] = data["inputs"]
@@ -831,6 +847,8 @@ def validate_model(
     output_names: Optional[List[str]] = None,
     ort_logs: bool = False,
     quiet_input_sets: Optional[Set[str]] = None,
+    save_ep: Optional[str] = None,
+    submodule: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -889,6 +907,9 @@ def validate_model(
     :param ort_logs: increases onnxruntime verbosity when creating the session
     :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
         even if quiet is False
+    :param save_ep: if not empty, this can be used to save the input sets and
+        the exported program
+    :param submodule: to test not the model but a submodule of this model
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

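Put together, a single call can now validate one block of a model and keep both the exported program and its input sets. A usage sketch, assuming the documented parameters above (the model id is illustrative):

    from onnx_diagnostic.torch_models.validate import validate_model

    summary, data = validate_model(
        "arnir0/Tiny-LLM",        # illustrative model id
        do_run=True,
        dump_folder="dump_test",  # save_ep only writes when dump_folder is set
        save_ep="tinyllm_block",  # prefix for the .pt input sets and the .pt2 program
        submodule="model",        # validate model.model instead of the full model
    )
    print(summary.get("version_submodule"))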
@@ -952,6 +973,8 @@ def validate_model(
         subfolder=subfolder,
         use_pretrained=use_pretrained,
         same_as_pretrained=same_as_pretrained,
+        save_ep=save_ep,
+        submodule=submodule,
     )
     if dump_folder:
         with open(dump_stats, "w") as f:
@@ -1038,6 +1061,8 @@ def _validate_model_step1(
     subfolder,
     use_pretrained,
     same_as_pretrained,
+    save_ep,
+    submodule,
 ):
     assert not do_same or do_run, (
         f"Discrepancies cannot be measured if the model is not run, "
@@ -1052,6 +1077,7 @@ def _validate_model_step1(
         dump_folder,
         folder_name,
         patch_kwargs,
+        submodule,
     ) = _prepare_validation(
         model_id=model_id,
         subfolder=subfolder,
@@ -1078,6 +1104,7 @@ def _validate_model_step1(
         verbose=verbose,
         output_names=output_names,
         dump_folder=dump_folder,
+        submodule=submodule,
     )

     data, iop, mop = _get_untrained_model_with_inputs(
@@ -1093,6 +1120,7 @@ def _validate_model_step1(
         inputs2=inputs2,
         quiet=quiet,
         dump_folder=dump_folder,
+        submodule=submodule,
     )

     second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
@@ -1153,6 +1181,7 @@ def _validate_model_step1(
         do_run=do_run,
         output_names=output_names,
         exporter_options=exporter_options,
+        save_ep=save_ep,
     )

     cont, dump_stats = _dump_onnx_model(
@@ -1426,6 +1455,7 @@ def call_exporter(
     dump_folder: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Calls an exporter on a model;
@@ -1440,6 +1470,7 @@ def call_exporter(
     :param dump_folder: to dump additional information
     :param output_names: list of output names to use with the onnx exporter
     :param exporter_options: exporter options
+    :param save_ep: saves the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1456,6 +1487,8 @@ def call_exporter(
             optimization=optimization,
             do_run=do_run,
             exporter_options=exporter_options,
+            save_ep=save_ep,
+            dump_folder=dump_folder,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1469,6 +1502,8 @@ def call_exporter(
             optimization=optimization,
             output_names=output_names,
             exporter_options=exporter_options,
+            dump_folder=dump_folder,
+            save_ep=save_ep,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1483,6 +1518,7 @@ def call_exporter(
             dump_folder=dump_folder,
             output_names=output_names,
             exporter_options=exporter_options,
+            save_ep=save_ep,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1516,6 +1552,8 @@ def call_torch_export_export(
     optimization: Optional[str] = None,
     do_run: bool = False,
     exporter_options: Optional[Dict[str, Any]] = None,
+    dump_folder: Optional[str] = None,
+    save_ep: Optional[str] = None,
 ):
     """
     Exports a model with :func:`torch.export.export`.
@@ -1529,6 +1567,8 @@ def call_torch_export_export(
     :param optimization: optimization to do
     :param do_run: runs and compute discrepancies
     :param exporter_options: additional options given to the exporter
+    :param dump_folder: folder where to dump the exported program
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1604,6 +1644,12 @@ def call_torch_export_export(
         print(ep)
         print("[call_torch_export_export] -- End of ExportedProgram")

+    if dump_folder and save_ep:
+        fname = f"{save_ep}.pt2"
+        if verbose:
+            print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
+        torch.export.save(ep, os.path.join(dump_folder, fname))
+
     if do_run:
         # We check for discrepancies.
         if verbose:
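The `.pt2` file written here is a regular torch.export archive, so it can be reloaded outside the validation pipeline. A minimal sketch (the path follows the `<dump_folder>/<save_ep>.pt2` pattern above):

    import torch

    ep = torch.export.load("dump_test/tinyllm_block.pt2")  # hypothetical path
    print(ep.graph)    # inspect the captured FX graph
    mod = ep.module()  # callable module, replays the exported computation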
@@ -1880,6 +1926,8 @@ def call_torch_export_onnx(
     optimization: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    dump_folder: Optional[str] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
@@ -1893,6 +1941,8 @@ def call_torch_export_onnx(
     :param optimization: optimization to do
     :param output_names: output names to use
     :param exporter_options: additional options to give the exporter
+    :param dump_folder: to know where to dump the exported program
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1986,6 +2036,12 @@ def call_torch_export_onnx(
         return summary, data

     assert epo is not None, "no onnx export was found"
+    if dump_folder and save_ep:
+        fname = f"{save_ep}.pt2"
+        if verbose:
+            print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
+        torch.export.save(epo.exported_program, os.path.join(dump_folder, fname))
+
     if verbose:
         print("[call_torch_export_onnx] done (export)")
     data["onnx_program"] = epo
@@ -2219,6 +2275,7 @@ def call_torch_export_custom(
     dump_folder: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
@@ -2233,6 +2290,7 @@ def call_torch_export_custom(
     :param dump_folder: to store additional information
     :param output_names: list of output names to use
     :param exporter_options: additional exporter options
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -2345,7 +2403,11 @@ def call_torch_export_custom(
         export_options = ExportOptions(
             strict=strict,
             decomposition_table=decomposition_table,
-            save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
+            save_ep=(
+                (os.path.join(dump_folder, f"{exporter}.ep"), 2**35 if save_ep else 2**18)
+                if dump_folder
+                else None
+            ),
             **exporter_options,
         )
         options = OptimizationOptions(patterns=optimization) if optimization else None
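A reading of the new tuple (an interpretation, not stated in the diff): `save_ep` passed to ExportOptions now carries a destination path plus what looks like a size budget in bytes, generous when the caller explicitly asked to keep the exported program and small otherwise:

    print(2**35)  # 34359738368 bytes, i.e. 32 GiB, when save_ep is requested
    print(2**18)  # 262144 bytes, i.e. 256 KiB, otherwise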
onnx_diagnostic/torch_onnx/runtime_info.py

@@ -4,6 +4,7 @@ import onnx
 import torch
 from ..api import TensorLike
 from ..helpers import string_type
+from ..helpers.onnx_helper import get_hidden_inputs


 class RuntimeValueKind(enum.IntEnum):
@@ -151,30 +152,6 @@ class RuntimeValue:
         return self.kind == RuntimeValueKind.INITIALIZER


-def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
-    """
-    Returns the hidden inputs (inputs coming from an upper context)
-    used by a subgraph.
-    """
-    hidden = set()
-    memo = (
-        set(i.name for i in graph.initializer)
-        | set(i.name for i in graph.sparse_initializer)
-        | set(i.name for i in graph.input)
-    )
-    for node in graph.node:
-        for i in node.input:
-            if i not in memo:
-                hidden.add(i)
-        for att in node.attribute:
-            if att.type == onnx.AttributeProto.GRAPH and att.g:
-                hid = get_hidden_inputs(att.g)
-                less = set(h for h in hid if h not in memo)
-                hidden |= less
-        memo |= set(node.output)
-    return hidden
-
-
 def set_is_shape(
     node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
 ) -> List[str]:
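`get_hidden_inputs` now lives in onnx_diagnostic.helpers.onnx_helper (part of the +531 lines on onnx_helper.py in the file list above); runtime_info.py imports it from there. A small check of what it reports, following the logic of the removed code above (the graph is a made-up If-style branch):

    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

    # A subgraph reading "outer", which is defined in the parent graph:
    # no local input or initializer provides it, so it is a hidden input.
    sub = oh.make_graph(
        [oh.make_node("Add", ["outer", "one"], ["y"])],
        "branch",
        [],
        [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1])],
        [oh.make_tensor("one", onnx.TensorProto.FLOAT, [1], [1.0])],
    )
    print(get_hidden_inputs(sub))  # {'outer'}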