onnx-diagnostic 0.7.11__py3-none-any.whl → 0.7.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +5 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +65 -12
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +55 -37
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +15 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +40 -14
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_api.py +4 -10
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -4
  24. onnx_diagnostic/torch_models/validate.py +337 -113
  25. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  26. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
  27. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,10 @@
1
+ import gc
1
2
  import datetime
2
3
  import inspect
3
4
  import os
4
5
  import pprint
5
6
  import sys
6
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
7
8
  import time
8
9
  import numpy as np
9
10
  import onnx
@@ -11,7 +12,7 @@ import torch
11
12
  from ..export import CoupleInputsDynamicShapes
12
13
  from ..helpers import max_diff, string_type, string_diff
13
14
  from ..helpers.helper import flatten_object
14
- from ..helpers.rt_helper import make_feeds
15
+ from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
15
16
  from ..helpers.torch_helper import to_any, torch_deepcopy
16
17
  from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
17
18
  from ..tasks import random_input_kwargs
@@ -112,6 +113,9 @@ def _make_folder_name(
112
113
  device: Optional[Union[str, torch.device]] = None,
113
114
  subfolder: Optional[str] = None,
114
115
  opset: Optional[int] = None,
116
+ drop_inputs: Optional[List[str]] = None,
117
+ same_as_pretrained: bool = False,
118
+ use_pretrained: bool = False,
115
119
  ) -> str:
116
120
  "Creates a filename unique based on the given options."
117
121
  els = [model_id.replace("/", "_")]
@@ -137,6 +141,13 @@ def _make_folder_name(
137
141
  els.append(sdev)
138
142
  if opset is not None:
139
143
  els.append(f"op{opset}")
144
+ if drop_inputs:
145
+ ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
146
+ els.append(f"I-{ii.upper()}")
147
+ if use_pretrained:
148
+ els.append("TRAINED")
149
+ elif same_as_pretrained:
150
+ els.append("SAMESIZE")
140
151
  return "-".join(els)
141
152
 
142
153
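
Note: the three new options now contribute to the dump-folder name. A minimal standalone sketch of the naming rule added above (hypothetical helper written for illustration, not the package function itself):

    def folder_suffix(drop_inputs, same_as_pretrained, use_pretrained):
        els = []
        if drop_inputs:
            # first and last character of every dropped input name
            ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
            els.append(f"I-{ii.upper()}")
        if use_pretrained:
            els.append("TRAINED")
        elif same_as_pretrained:
            els.append("SAMESIZE")
        return "-".join(els)

    folder_suffix(["position_ids", "token_type_ids"], True, False)
    # -> 'I-PS-TS-SAMESIZE'
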
 
@@ -233,21 +244,35 @@ def _quiet_or_not_quiet(
233
244
  summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
234
245
  summary[f"{suffix}_warmup"] = warmup
235
246
  summary[f"{suffix}_repeat"] = repeat
236
- for _w in range(max(0, warmup - 1)):
247
+ last_ = None
248
+ end_w = max(0, warmup - 1)
249
+ for _w in range(end_w):
237
250
  t = fct()
238
- summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
251
+ _ = string_type(t, with_shape=True, with_min_max=True)
252
+ if _ != last_ or _w == end_w - 1:
253
+ summary[f"io_{suffix}_{_w+1}"] = _
254
+ last_ = _
239
255
  summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
240
256
  times = []
241
257
  for _r in range(repeat):
242
258
  begin = time.perf_counter()
243
259
  t = fct()
244
260
  times.append(time.perf_counter() - begin)
245
- a = np.array(times)
261
+ a = np.array(times, dtype=np.float64)
262
+ a.sort()
263
+ i5 = max(1, a.shape[0] * 5 // 100)
264
+ i2 = max(1, a.shape[0] * 2 // 100)
246
265
  summary[f"time_{suffix}_latency"] = a.mean()
247
266
  summary[f"time_{suffix}_latency_std"] = a.std()
248
267
  summary[f"time_{suffix}_latency_min"] = a.min()
249
- summary[f"time_{suffix}_latency_min"] = a.max()
268
+ summary[f"time_{suffix}_latency_max"] = a.max()
269
+ summary[f"time_{suffix}_latency_098"] = a[-i2]
270
+ summary[f"time_{suffix}_latency_095"] = a[-i5]
271
+ summary[f"time_{suffix}_latency_005"] = a[i5]
272
+ summary[f"time_{suffix}_latency_002"] = a[i2]
250
273
  summary[f"time_{suffix}_n"] = len(a)
274
+ summary[f"time_{suffix}_latency_m98"] = a[i2:-i2].mean()
275
+
251
276
  return res
252
277
 
253
278
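
Note: warmup outputs are now only recorded when they change, and the timing loop reports trimmed percentiles in addition to mean/std/min/max (the old code also overwrote the min with the max, fixed above). A self-contained sketch of the new index arithmetic, using fake timings for illustration:

    import numpy as np

    a = np.abs(np.random.default_rng(0).normal(0.01, 0.002, 100))  # fake runtimes (s)
    a.sort()
    i5 = max(1, a.shape[0] * 5 // 100)   # ~5% tail index
    i2 = max(1, a.shape[0] * 2 // 100)   # ~2% tail index
    stats = {
        "latency": a.mean(),
        "latency_098": a[-i2],           # ~98th percentile
        "latency_095": a[-i5],           # ~95th percentile
        "latency_005": a[i5],
        "latency_002": a[i2],
        "latency_m98": a[i2:-i2].mean(), # mean with both 2% tails trimmed
    }
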
 
@@ -264,14 +289,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264
289
  return new_cfg
265
290
 
266
291
 
267
- def _preprocess_model_id(model_id, subfolder):
292
+ def _preprocess_model_id(
293
+ model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
294
+ ) -> Tuple[str, Optional[str], bool, bool]:
268
295
  if subfolder or "//" not in model_id:
269
- return model_id, subfolder
296
+ return model_id, subfolder, same_as_pretrained, use_pretrained
270
297
  spl = model_id.split("//")
298
+ if spl[-1] == "pretrained":
299
+ return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
271
300
  if spl[-1] in {"transformer", "vae"}:
272
301
  # known subfolder
273
- return "//".join(spl[:-1]), spl[-1]
274
- return model_id, subfolder
302
+ return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
303
+ return model_id, subfolder, same_as_pretrained, use_pretrained
275
304
 
276
305
 
277
306
  def validate_model(
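
Note: a model id may carry a `//` suffix; `//transformer` and `//vae` still select a subfolder, and the new `//pretrained` suffix strips the suffix and switches both `same_as_pretrained` and `use_pretrained` on. Expected behaviour sketched from the hunk above, with example ids and the private helper imported only for illustration:

    from onnx_diagnostic.torch_models.validate import _preprocess_model_id

    _preprocess_model_id("some-org/some-model//pretrained", None, False, False)
    # -> ('some-org/some-model', '', True, True)

    _preprocess_model_id("some-org/some-diffuser//vae", None, False, False)
    # -> ('some-org/some-diffuser', 'vae', False, False)
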
@@ -351,9 +380,10 @@ def validate_model(
351
380
  ``orteval10``, ``ref`` only if `do_run` is true
352
381
  :param repeat: number of time to measure the model
353
382
  :param warmup: warmup the model first
354
- :param inputs2: checks that the second set of inputs is reunning as well,
383
+ :param inputs2: checks that other sets of inputs are running as well,
355
384
  this ensures that the model does support dynamism, the value is used
356
- as an increment to the first set of values (added to dimensions)
385
+ as an increment to the first set of values (added to dimensions),
386
+ or an empty cache for example
357
387
  :param output_names: output names the onnx exporter should use
358
388
  :param ort_logs: increases onnxruntime verbosity when creating the session
359
389
  :return: two dictionaries, one with some metrics,
@@ -383,14 +413,23 @@ def validate_model(
383
413
  :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
384
414
  if ``runtime == 'ref'``,
385
415
  ``orteval10`` increases the verbosity.
416
+
417
+ .. versionchanged:: 0.7.13
418
+ *inputs2* not only means a second set of inputs but many
419
+ such as ``input_empty_cache``
420
+ which refers to a set of inputs using an empty cache.
386
421
  """
387
- model_id, subfolder = _preprocess_model_id(model_id, subfolder)
422
+ validation_begin = time.perf_counter()
423
+ model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
424
+ model_id,
425
+ subfolder,
426
+ same_as_pretrained=same_as_pretrained,
427
+ use_pretrained=use_pretrained,
428
+ )
429
+ time_preprocess_model_id = time.perf_counter() - validation_begin
430
+ default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
388
431
  if isinstance(patch, bool):
389
- patch_kwargs = (
390
- dict(patch_transformers=True, patch_diffusers=True, patch=True)
391
- if patch
392
- else dict(patch=False)
393
- )
432
+ patch_kwargs = default_patch if patch else dict(patch=False)
394
433
  elif isinstance(patch, str):
395
434
  patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
396
435
  else:
@@ -399,11 +438,13 @@ def validate_model(
399
438
  if "patch" not in patch_kwargs:
400
439
  if any(patch_kwargs.values()):
401
440
  patch_kwargs["patch"] = True
441
+ elif len(patch) == 1 and patch.get("patch", False):
442
+ patch_kwargs.update(default_patch)
402
443
 
403
444
  assert not rewrite or patch_kwargs.get("patch", False), (
404
445
  f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
405
446
  f"patch must be True to enable rewriting, "
406
- f"if --no-patch was specified on the command line, --no-rewrite must be added."
447
+ f"if --patch=0 was specified on the command line, rewrites are disabled."
407
448
  )
408
449
  summary = version_summary()
409
450
  summary.update(
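
Note: `patch` accepts a bool, a comma-separated string, or a dict; the refactoring above factors the defaults into `default_patch` and, for a dict reduced to `{"patch": True}`, expands it to the full default. A hedged standalone sketch of the normalization (the dict branch is simplified, not the package code verbatim):

    DEFAULT_PATCH = dict(patch_transformers=True, patch_diffusers=True, patch=True)

    def normalize_patch(patch):
        if isinstance(patch, bool):
            return dict(DEFAULT_PATCH) if patch else dict(patch=False)
        if isinstance(patch, str):
            # e.g. "patch_transformers,patch_diffusers"
            return {"patch": True, **{p: True for p in patch.split(",")}}
        kwargs = dict(patch)
        if "patch" not in kwargs and any(kwargs.values()):
            kwargs["patch"] = True
        elif len(kwargs) == 1 and kwargs.get("patch", False):
            kwargs.update(DEFAULT_PATCH)
        return kwargs

    normalize_patch("patch_transformers")   # {'patch': True, 'patch_transformers': True}
    normalize_patch({"patch": True})        # expands to DEFAULT_PATCH
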
@@ -426,6 +467,7 @@ def validate_model(
426
467
  version_exporter=exporter or "",
427
468
  version_runtime=runtime,
428
469
  version_inputs2=inputs2,
470
+ time_preprocess_model_id=time_preprocess_model_id,
429
471
  )
430
472
  )
431
473
  if opset:
@@ -441,6 +483,9 @@ def validate_model(
441
483
  device=device,
442
484
  subfolder=subfolder,
443
485
  opset=opset,
486
+ drop_inputs=drop_inputs,
487
+ use_pretrained=use_pretrained,
488
+ same_as_pretrained=same_as_pretrained,
444
489
  )
445
490
  dump_folder = os.path.join(dump_folder, folder_name)
446
491
  if not os.path.exists(dump_folder):
@@ -473,7 +518,7 @@ def validate_model(
473
518
  mop = model_options or {}
474
519
  data = _quiet_or_not_quiet(
475
520
  quiet,
476
- "create",
521
+ "create_torch_model",
477
522
  summary,
478
523
  None,
479
524
  (
@@ -492,10 +537,9 @@ def validate_model(
492
537
  )
493
538
  ),
494
539
  )
495
- assert not inputs2 or "inputs2" in data, (
496
- f"inputs2 is True but second set is missing in data for "
497
- f"model id {model_id!r}: {sorted(data)}"
498
- )
540
+
541
+ second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
542
+
499
543
  if dump_folder:
500
544
  with open(os.path.join(dump_folder, "model_config.txt"), "w") as f:
501
545
  f.write(f"model_id: {model_id}\n------\n")
@@ -536,6 +580,11 @@ def validate_model(
536
580
  if verbose:
537
581
  print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
538
582
 
583
+ # modelbuilder needs different treatments sometimes, so
584
+ # we mark it for later usage.
585
+ # for example, it has different past_kv ordering than
586
+ # flattened CacheObject
587
+ data["exporter"] = exporter
539
588
  data["input_options"] = iop
540
589
  data["model_options"] = mop
541
590
  data["model_dump_folder"] = dump_folder
@@ -583,16 +632,14 @@ def validate_model(
583
632
  if verbose:
584
633
  print(f"[validate_model] new inputs: {string_type(data['inputs'])}")
585
634
  print(f"[validate_model] new dynamic_hapes: {string_type(data['dynamic_shapes'])}")
586
- if inputs2:
587
- assert (
588
- "inputs2" in data
589
- ), "Cannot test a second set of inputs as it was not defined."
590
- data["inputs2"], _ = filter_inputs(
591
- data["inputs2"],
592
- drop_names=drop_inputs,
593
- model=data["model"],
594
- dynamic_shapes=data["dynamic_shapes"],
595
- )
635
+ if second_input_keys:
636
+ for k in second_input_keys:
637
+ data[k], _ = filter_inputs(
638
+ data[k],
639
+ drop_names=drop_inputs,
640
+ model=data["model"],
641
+ dynamic_shapes=data["dynamic_shapes"],
642
+ )
596
643
 
597
644
  if not empty(dtype):
598
645
  if isinstance(dtype, str):
@@ -602,8 +649,9 @@ def validate_model(
602
649
  data["model"] = to_any(data["model"], dtype) # type: ignore
603
650
  data["inputs"] = to_any(data["inputs"], dtype) # type: ignore
604
651
  summary["model_dtype"] = str(dtype)
605
- if "inputs2" in data:
606
- data["inputs2"] = to_any(data["inputs2"], dtype) # type: ignore
652
+ if second_input_keys:
653
+ for k in second_input_keys:
654
+ data[k] = to_any(data[k], dtype) # type: ignore
607
655
 
608
656
  if not empty(device):
609
657
  if verbose:
@@ -611,11 +659,13 @@ def validate_model(
611
659
  data["model"] = to_any(data["model"], device) # type: ignore
612
660
  data["inputs"] = to_any(data["inputs"], device) # type: ignore
613
661
  summary["model_device"] = str(device)
614
- if "inputs2" in data:
615
- data["inputs2"] = to_any(data["inputs2"], device) # type: ignore
662
+ if second_input_keys:
663
+ for k in second_input_keys:
664
+ data[k] = to_any(data[k], device) # type: ignore
616
665
 
617
666
  for k in ["task", "size", "n_weights"]:
618
667
  summary[f"model_{k.replace('_','')}"] = data[k]
668
+ summary["second_input_keys"] = ",".join(second_input_keys)
619
669
  summary["model_inputs_options"] = str(input_options or "")
620
670
  summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
621
671
  summary["model_shapes"] = string_type(data["dynamic_shapes"])
@@ -642,22 +692,37 @@ def validate_model(
642
692
  print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}")
643
693
  for k, v in data["dynamic_shapes"].items():
644
694
  print(f"[validate_model] +SHAPE {k}={string_type(v)}")
695
+ print(f"[validate_model] second_input_keys={second_input_keys}")
645
696
  print("[validate_model] --")
646
697
 
647
698
  if do_run:
699
+ validation_begin = time.perf_counter()
700
+
648
701
  _validate_do_run_model(
649
702
  data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
650
703
  )
651
- if inputs2:
652
- _validate_do_run_model(
653
- data, summary, "inputs2", "run2", "run_expected2", verbose, 1, 0, quiet
654
- )
704
+ if second_input_keys:
705
+ for k in second_input_keys:
706
+ _validate_do_run_model(
707
+ data,
708
+ summary,
709
+ k,
710
+ f"run2{k[6:]}",
711
+ f"run_expected2{k[6:]}",
712
+ verbose,
713
+ 1,
714
+ 0,
715
+ quiet,
716
+ )
717
+
718
+ summary["time_total_validation_torch"] = time.perf_counter() - validation_begin
655
719
 
656
720
  if exporter:
657
721
  print(
658
722
  f"[validate_model] -- export the model with {exporter!r}, "
659
723
  f"optimization={optimization!r}"
660
724
  )
725
+ exporter_begin = time.perf_counter()
661
726
  if patch_kwargs:
662
727
  if verbose:
663
728
  print(
@@ -700,7 +765,9 @@ def validate_model(
700
765
  dump_folder=dump_folder,
701
766
  output_names=output_names,
702
767
  )
768
+
703
769
  summary.update(summary_export)
770
+ summary["time_total_exporter"] = time.perf_counter() - exporter_begin
704
771
 
705
772
  dump_stats = None
706
773
  if dump_folder:
@@ -741,6 +808,8 @@ def validate_model(
741
808
  data["onnx_filename"] = onnx_filename
742
809
  summary["time_onnx_save"] = duration
743
810
  summary.update(compute_statistics(onnx_filename))
811
+ del epo
812
+
744
813
  if verbose:
745
814
  print(f"[validate_model] dumps statistics in {dump_folder!r}...")
746
815
  dump_stats = os.path.join(dump_folder, f"{folder_name}.stats")
@@ -763,6 +832,20 @@ def validate_model(
763
832
  return summary, data
764
833
 
765
834
  if do_run:
835
+ # Let's move the model to CPU to make sure it frees GPU memory.
836
+ if verbose:
837
+ # It does not really work for the time being and the model
838
+ # gets loaded twice, one by torch, one by onnxruntime
839
+ print("[validation_model] -- delete the model")
840
+ for key in ["model", "onnx_program", "config"]:
841
+ if key in data:
842
+ del data[key]
843
+ if device is not None and "cuda" in str(device).lower():
844
+ torch.cuda.empty_cache()
845
+ gc.collect()
846
+ print("[validation_model] -- done")
847
+
848
+ validation_begin = time.perf_counter()
766
849
  summary_valid, data = validate_onnx_model(
767
850
  data=data,
768
851
  quiet=quiet,
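
Note: before the ONNX validation starts, the torch model is released so the weights are not held twice (once by torch, once by onnxruntime). The pattern from the block above, condensed into a hypothetical helper:

    import gc
    import torch

    def release_torch_model(data: dict, device) -> None:
        # Drop the references that keep the torch model alive, then reclaim memory
        # before onnxruntime loads its own copy of the weights.
        for key in ("model", "onnx_program", "config"):
            data.pop(key, None)
        if device is not None and "cuda" in str(device).lower():
            torch.cuda.empty_cache()
        gc.collect()
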
@@ -770,10 +853,11 @@ def validate_model(
770
853
  runtime=runtime,
771
854
  repeat=repeat,
772
855
  warmup=warmup,
773
- inputs2=inputs2,
856
+ second_input_keys=second_input_keys,
774
857
  ort_logs=ort_logs,
775
858
  )
776
859
  summary.update(summary_valid)
860
+ summary["time_total_validation_onnx"] = time.perf_counter() - validation_begin
777
861
 
778
862
  if ortfusiontype and "onnx_filename" in data:
779
863
  assert (
@@ -832,13 +916,17 @@ def validate_model(
832
916
  runtime=runtime,
833
917
  repeat=repeat,
834
918
  warmup=warmup,
835
- inputs2=inputs2,
919
+ second_input_keys=second_input_keys,
836
920
  )
837
921
  summary.update(summary_valid)
838
922
 
923
+ _compute_final_statistics(summary)
924
+ summary["time_total"] = time.perf_counter() - validation_begin
925
+
839
926
  if verbose:
840
927
  print("[validate_model] -- done (final)")
841
928
  if dump_stats:
929
+ # Dumps again the statistics.
842
930
  with open(dump_stats, "w") as f:
843
931
  for k, v in sorted(summary.items()):
844
932
  f.write(f":{k}:{v};\n")
@@ -848,15 +936,24 @@ def validate_model(
848
936
  def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
849
937
  """Computes some statistics on the model itself."""
850
938
  onx = onnx.load(onnx_filename, load_external_data=False)
939
+ cache_functions = {(f.domain, f.name): f for f in onx.functions}
940
+ local_domains = set(f.domain for f in onx.functions)
851
941
 
852
942
  def node_iter(proto):
853
943
  if isinstance(proto, onnx.ModelProto):
854
- yield from node_iter(proto.graph)
855
944
  for f in proto.functions:
856
945
  yield from node_iter(f)
946
+ yield from node_iter(proto.graph)
857
947
  elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
858
948
  for node in proto.node:
859
949
  yield node
950
+
951
+ # Let's inline the function
952
+ key = node.domain, node.op_type
953
+ if key in cache_functions:
954
+ yield from node_iter(cache_functions[key])
955
+
956
+ # Let's continue
860
957
  for att in node.attribute:
861
958
  if att.type == onnx.AttributeProto.GRAPH:
862
959
  yield from node_iter(att.g)
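
Note: `compute_statistics` now also walks local functions and counts a function body again each time a node calls it, so the node counts reflect the "inlined" size of the model. A standalone sketch of that traversal, assuming only the standard onnx protos:

    import onnx

    def iter_nodes(model: onnx.ModelProto):
        # Walk local functions, the main graph, subgraphs, and "inline" a function
        # body every time a node calls it (sketch of the logic in the hunk above).
        functions = {(f.domain, f.name): f for f in model.functions}

        def _walk(proto):
            for node in proto.node:
                yield node
                key = (node.domain, node.op_type)
                if key in functions:
                    yield from _walk(functions[key])
                for att in node.attribute:
                    if att.type == onnx.AttributeProto.GRAPH:
                        yield from _walk(att.g)

        for f in model.functions:
            yield from _walk(f)
        yield from _walk(model.graph)
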
@@ -874,6 +971,11 @@ def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
874
971
  n_nodes += 1
875
972
  if proto.op_type != "Constant":
876
973
  n_nodes_nocst += 1
974
+ if proto.domain in local_domains:
975
+ key = "n_node_local_function"
976
+ if key not in counts:
977
+ counts[key] = 0
978
+ counts[key] += 1
877
979
  else:
878
980
  key = f"n_node_initializer_{proto.data_type}"
879
981
 
@@ -960,6 +1062,26 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
960
1062
  )
961
1063
 
962
1064
 
1065
+ _cache_export_times = []
1066
+ _main_export_function = torch.export.export
1067
+
1068
+
1069
+ def _torch_export_export(*args, _export=_main_export_function, **kwargs):
1070
+ begin = time.perf_counter()
1071
+ res = _export(*args, **kwargs)
1072
+ duration = time.perf_counter() - begin
1073
+ _cache_export_times.append(duration)
1074
+ return res
1075
+
1076
+
1077
+ def _restore_torch_export_export(summary):
1078
+ torch.export.export = _main_export_function
1079
+ if _cache_export_times:
1080
+ summary["time_torch_export_export"] = sum(_cache_export_times)
1081
+ summary["time_torch_export_export_n"] = len(_cache_export_times)
1082
+ _cache_export_times.clear()
1083
+
1084
+
963
1085
  def call_exporter(
964
1086
  data: Dict[str, Any],
965
1087
  exporter: str,
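
Note: `call_exporter` temporarily replaces `torch.export.export` with a timed wrapper so the summary can report how much time is spent inside `torch.export.export` and how many times it is called. The pattern, reduced to its core with hypothetical names:

    import time
    import torch

    _export_times = []
    _original_export = torch.export.export

    def _timed_export(*args, _export=_original_export, **kwargs):
        begin = time.perf_counter()
        try:
            return _export(*args, **kwargs)
        finally:
            _export_times.append(time.perf_counter() - begin)

    torch.export.export = _timed_export
    try:
        pass  # run the exporter here; it may call torch.export.export several times
    finally:
        torch.export.export = _original_export
        total_export_time, n_calls = sum(_export_times), len(_export_times)
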
@@ -985,6 +1107,9 @@ def call_exporter(
985
1107
  :return: two dictionaries, one with some metrics,
986
1108
  another one with whatever the function produces
987
1109
  """
1110
+ _cache_export_times.clear()
1111
+ torch.export.export = _torch_export_export
1112
+
988
1113
  if exporter == "export" or exporter.startswith("export-"):
989
1114
  # torch export
990
1115
  summary, data = call_torch_export_export(
@@ -995,6 +1120,7 @@ def call_exporter(
995
1120
  optimization=optimization,
996
1121
  do_run=do_run,
997
1122
  )
1123
+ _restore_torch_export_export(summary)
998
1124
  return summary, data
999
1125
  if exporter.startswith("onnx-"):
1000
1126
  # torch export
@@ -1006,6 +1132,7 @@ def call_exporter(
1006
1132
  optimization=optimization,
1007
1133
  output_names=output_names,
1008
1134
  )
1135
+ _restore_torch_export_export(summary)
1009
1136
  return summary, data
1010
1137
  if exporter == "custom" or exporter.startswith("custom"):
1011
1138
  # torch export
@@ -1018,6 +1145,7 @@ def call_exporter(
1018
1145
  dump_folder=dump_folder,
1019
1146
  output_names=output_names,
1020
1147
  )
1148
+ _restore_torch_export_export(summary)
1021
1149
  return summary, data
1022
1150
  if exporter == "modelbuilder":
1023
1151
  # torch export
@@ -1029,6 +1157,7 @@ def call_exporter(
1029
1157
  optimization=optimization,
1030
1158
  output_names=output_names,
1031
1159
  )
1160
+ _restore_torch_export_export(summary)
1032
1161
  return summary, data
1033
1162
  raise NotImplementedError(
1034
1163
  f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
@@ -1171,7 +1300,7 @@ def validate_onnx_model(
1171
1300
  runtime: str = "onnxruntime",
1172
1301
  repeat: int = 1,
1173
1302
  warmup: int = 0,
1174
- inputs2: int = 1,
1303
+ second_input_keys: Optional[List[str]] = None,
1175
1304
  ort_logs: bool = False,
1176
1305
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1177
1306
  """
@@ -1188,7 +1317,7 @@ def validate_onnx_model(
1188
1317
  :param runtime: onnx runtime to use, onnxruntime, torch, orteval, ref
1189
1318
  :param repeat: run that number of times the model
1190
1319
  :param warmup: warmup the model
1191
- :param inputs2: to validate the model on the second input set
1320
+ :param second_input_keys: to validate the model on other input sets
1192
1321
  to make sure the exported model supports dynamism, the value is
1193
1322
  used as an increment added to the first set of inputs (added to dimensions)
1194
1323
  :param ort_logs: triggers the logs for onnxruntime
@@ -1313,16 +1442,24 @@ def validate_onnx_model(
1313
1442
  print(f"[validate_onnx_model] done (ort_session) flavour={flavour!r}")
1314
1443
 
1315
1444
  keys = [("inputs", "run_expected", "")]
1316
- if inputs2:
1317
- keys.append(("inputs2", "run_expected2", "2"))
1445
+ if second_input_keys:
1446
+ keys.extend([(k, f"run_expected2{k[6:]}", f"2{k[6:]}") for k in second_input_keys])
1318
1447
  for k_input, k_expected, suffix in keys:
1319
1448
  # make_feeds
1449
+ assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
1450
+ assert k_expected in data, f"Unable to find {k_expected!r} in {sorted(data)}"
1320
1451
  if verbose:
1321
1452
  print(f"[validate_onnx_model] -- make_feeds for {k_input!r}...")
1322
1453
  print(
1323
1454
  f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
1324
1455
  )
1325
- feeds = make_feeds(sess, data[k_input], use_numpy=True, check_flatten=False)
1456
+ feeds = make_feeds(
1457
+ sess,
1458
+ data[k_input],
1459
+ use_numpy=True,
1460
+ check_flatten=False,
1461
+ is_modelbuilder=data["exporter"] == "modelbuilder",
1462
+ )
1326
1463
  if verbose:
1327
1464
  print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
1328
1465
  summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
@@ -1342,6 +1479,13 @@ def validate_onnx_model(
1342
1479
  repeat=repeat,
1343
1480
  warmup=warmup,
1344
1481
  )
1482
+ # NOTE: modelbuilder has different order on past_kv outputs
1483
+ if data["exporter"] == "modelbuilder":
1484
+ logits = got[:1]
1485
+ past_key_values = got[1:]
1486
+ reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
1487
+ got = logits + reorder_past_key_values
1488
+
1345
1489
  if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
1346
1490
  return summary, data
1347
1491
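
Note: ModelBuilder models emit the past key/value outputs in a different order than the flattened torch cache, so the outputs are reordered before the discrepancies are computed; the actual mapping lives in `onnx_diagnostic.helpers.rt_helper.reorder_modelbuilder_cache_to_torch` (added in this release). A hedged sketch only, assuming for illustration that one side interleaves key/value per layer while the other groups all keys before all values:

    def reorder_interleaved_to_grouped(past):
        # (key_0, value_0, key_1, value_1, ...) -> (key_0, key_1, ..., value_0, value_1, ...)
        keys, values = past[0::2], past[1::2]
        return list(keys) + list(values)

    reorder_interleaved_to_grouped(["k0", "v0", "k1", "v1"])
    # -> ['k0', 'k1', 'v0', 'v1']
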
 
@@ -1382,7 +1526,7 @@ def call_torch_export_onnx(
1382
1526
  :return: two dictionaries, one with some metrics,
1383
1527
  another one with whatever the function produces
1384
1528
  """
1385
- available = {None, "", "ir", "os_ort"}
1529
+ available = {None, "", "ir", "os_ort", "ir+default"}
1386
1530
  assert (
1387
1531
  optimization in available
1388
1532
  ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1472,11 +1616,31 @@ def call_torch_export_onnx(
1472
1616
  print(epo)
1473
1617
  print("[call_torch_export_onnx] -- End of ONNXProgram")
1474
1618
 
1475
- if optimization in {"ir", "os_ort"}:
1619
+ if optimization in {"ir", "os_ort", "ir+default"}:
1476
1620
  if verbose:
1477
1621
  print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
1478
1622
  if optimization == "ir":
1479
1623
  label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
1624
+ elif optimization == "ir+default":
1625
+ import onnxscript
1626
+ from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
1627
+
1628
+ def _ir_default_opt(epo):
1629
+ onnxscript.optimizer.optimize_ir(epo.model)
1630
+ onx = epo.model_proto
1631
+ # not very efficient
1632
+ gr = GraphBuilder(
1633
+ onx,
1634
+ infer_shapes_options=True,
1635
+ optimization_options=OptimizationOptions(patterns="default"),
1636
+ )
1637
+ cont = gr.to_onnx(large_model=True)
1638
+ epo.model = cont.to_ir()
1639
+
1640
+ label, f_optim = "export_onnx_opt_ir_default", (
1641
+ lambda epo=epo: _ir_default_opt(epo)
1642
+ )
1643
+
1480
1644
  else:
1481
1645
  import onnxscript
1482
1646
  import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -1567,6 +1731,98 @@ def call_torch_export_model_builder(
1567
1731
  return summary, data
1568
1732
 
1569
1733
 
1734
+ def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
1735
+ """
1736
+ Processes statistics coming from the exporters.
1737
+ It takes a sequence of dictionaries (like a data frame)
1738
+ and extracts some metrics.
1739
+ """
1740
+
1741
+ def _simplify(p):
1742
+ for s in [
1743
+ "remove_unused",
1744
+ "constant_folding",
1745
+ "remove_identity",
1746
+ "remove_duplicated_initializer",
1747
+ "dynamic_dimension_naming",
1748
+ "inline",
1749
+ "check",
1750
+ "build_graph_for_pattern",
1751
+ "pattern_optimization",
1752
+ "topological_sort",
1753
+ ]:
1754
+ if s in p or s.replace("_", "-") in p:
1755
+ return s
1756
+ if p.startswith(("apply_", "match_")):
1757
+ return p
1758
+ return "other"
1759
+
1760
+ def _add(d, a, v, use_max=False):
1761
+ if v:
1762
+ if a not in d:
1763
+ d[a] = v
1764
+ elif use_max:
1765
+ d[a] = max(d[a], v)
1766
+ else:
1767
+ d[a] += v
1768
+
1769
+ counts: Dict[str, Any] = {}
1770
+ applied_pattern_time: Dict[str, Any] = {}
1771
+ applied_pattern_n: Dict[str, Any] = {}
1772
+ matching_pattern_time: Dict[str, Any] = {}
1773
+ matching_pattern_n: Dict[str, Any] = {}
1774
+
1775
+ for obs in data:
1776
+ pattern = _simplify(obs["pattern"])
1777
+ _add(counts, "opt_nodes_added", obs.get("added", 0))
1778
+ _add(counts, "opt_nodes_removed", obs.get("removed", 0))
1779
+ _add(counts, "opt_time_steps", obs.get("time_in", 0))
1780
+ _add(counts, "opt_n_steps", 1)
1781
+ _add(
1782
+ counts,
1783
+ "opt_n_iteration",
1784
+ max(counts.get("opt_n_iteration", 0), obs.get("iteration", 0)),
1785
+ use_max=True,
1786
+ )
1787
+
1788
+ if pattern.startswith("apply_"):
1789
+ _add(counts, "opt_n_applied_patterns", 1)
1790
+ _add(counts, "opt_time_applied_patterns", obs.get("time_in", 0))
1791
+ _add(applied_pattern_time, pattern, obs.get("time_in", 0))
1792
+ _add(applied_pattern_n, pattern, 1)
1793
+ elif pattern.startswith("match_"):
1794
+ _add(counts, "opt_n_matching_patterns", 1)
1795
+ _add(counts, "opt_time_matching_patterns", obs.get("time_in", 0))
1796
+ _add(matching_pattern_time, pattern, obs.get("time_in", 0))
1797
+ _add(matching_pattern_n, pattern, 1)
1798
+ else:
1799
+ _add(counts, f"opt_time_{pattern}", obs.get("time_in", 0))
1800
+ _add(counts, f"opt_n_{pattern}", 1)
1801
+ _add(counts, f"opt_nodes_added_{pattern}", obs.get("added", 0))
1802
+ _add(counts, f"opt_nodes_removed_{pattern}", obs.get("removed", 0))
1803
+
1804
+ if applied_pattern_time:
1805
+ longest = max((v, k) for k, v in applied_pattern_time.items())
1806
+ counts["opt_top_time_applied_pattern"], counts["opt_top_time_applied_pattern_arg"] = (
1807
+ longest
1808
+ )
1809
+ longest = max((v, k) for k, v in applied_pattern_n.items())
1810
+ counts["opt_top_n_applied_pattern"], counts["opt_top_n_applied_pattern_arg"] = longest
1811
+
1812
+ if matching_pattern_time:
1813
+ longest = max((v, k) for k, v in matching_pattern_time.items())
1814
+ (
1815
+ counts["opt_top_time_matching_pattern"],
1816
+ counts["opt_top_time_matching_pattern_arg"],
1817
+ ) = longest
1818
+ longest = max((v, k) for k, v in matching_pattern_n.items())
1819
+ counts["opt_top_n_matching_pattern"], counts["opt_top_n_matching_pattern_arg"] = (
1820
+ longest
1821
+ )
1822
+ counts["onnx_opt_optimized"] = 1
1823
+ return counts
1824
+
1825
+
1570
1826
  def call_torch_export_custom(
1571
1827
  data: Dict[str, Any],
1572
1828
  exporter: str,
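
Note: the new `process_statistics` replaces the ad-hoc aggregation that used to live in `call_torch_export_custom` (removed further below). Fed a small optimization log, made-up entries for illustration, it aggregates times and node counts per kind of step:

    from onnx_diagnostic.torch_models.validate import process_statistics

    stats = process_statistics(
        [
            {"pattern": "remove_identity", "time_in": 0.004, "removed": 3, "iteration": 0},
            {"pattern": "match_CastPattern", "time_in": 0.002, "iteration": 1},
            {"pattern": "apply_CastPattern", "time_in": 0.003, "added": 1, "removed": 2, "iteration": 1},
        ]
    )
    # e.g. stats["opt_nodes_removed"] == 5, stats["opt_n_iteration"] == 1,
    #      stats["opt_top_time_applied_pattern_arg"] == "apply_CastPattern"
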
@@ -1619,6 +1875,8 @@ def call_torch_export_custom(
1619
1875
  "custom-nostrict-noinline",
1620
1876
  "custom-nostrict-default-noinline",
1621
1877
  "custom-nostrict-all-noinline",
1878
+ "custom-dec",
1879
+ "custom-decall",
1622
1880
  }
1623
1881
  assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
1624
1882
  assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -1655,7 +1913,9 @@ def call_torch_export_custom(
1655
1913
  export_options = ExportOptions(
1656
1914
  strict=strict,
1657
1915
  decomposition_table=(
1658
- "default" if "-default" in exporter else ("all" if "-all" in exporter else None)
1916
+ "default"
1917
+ if ("-default" in exporter or "-dec" in exporter)
1918
+ else ("all" if ("-all" in exporter or "-decall" in exporter) else None)
1659
1919
  ),
1660
1920
  save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
1661
1921
  )
@@ -1696,67 +1956,10 @@ def call_torch_export_custom(
1696
1956
  if "ERR_export_onnx_c" in summary:
1697
1957
  return summary, data
1698
1958
 
1699
- new_stat = {}
1959
+ new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
1960
+ new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
1700
1961
  if "optimization" in opt_stats:
1701
- added, removed, time_in = 0, 0, 0.0
1702
- max_iter = 0
1703
- applied = {}
1704
- matched = set()
1705
- n_applied = 0
1706
- by_pattern = {}
1707
- by_pattern_n = {}
1708
- by_iter = {}
1709
- cst_added, cst_removed, cst_time_in = 0, 0, 0.0
1710
-
1711
- for obs in opt_stats["optimization"]:
1712
- pattern = obs["pattern"]
1713
- if pattern == "constant_folding":
1714
- cst_added += obs.get("added", 0)
1715
- cst_removed += obs.get("removed", 0)
1716
- cst_time_in += obs.get("time_in", 0)
1717
- if pattern not in by_pattern:
1718
- by_pattern[pattern] = 0
1719
- by_pattern_n[pattern] = 0
1720
- by_iter[pattern] = 0
1721
- time_in += obs.get("time_in", 0)
1722
- added += obs.get("added", 0)
1723
- removed += obs.get("removed", 0)
1724
- max_iter = max(max_iter, obs.get("iteration", 0))
1725
- by_pattern[pattern] += obs.get("time_in", 0)
1726
- by_pattern_n[pattern] += obs.get("added", 0) - obs.get("removed", 0)
1727
- if not pattern.startswith("match"):
1728
- by_iter[pattern] = max(by_iter[pattern], obs.get("iteration", 0))
1729
- p = obs["pattern"]
1730
- if p.startswith("match_"):
1731
- matched.add(p)
1732
- elif p.startswith("apply_"):
1733
- key = f"op_opt_{p}"
1734
- key2 = f"op_opt_maxiter_{p}"
1735
- if key not in applied:
1736
- applied[key] = 1
1737
- applied[key2] = obs["iteration"]
1738
- else:
1739
- applied[key] += 1
1740
- applied[key2] = max(obs["iteration"], applied[key2])
1741
- n_applied += 1
1742
-
1743
- new_stat.update(
1744
- dict(
1745
- onnx_opt_optimized=1,
1746
- op_opt_all_time_in=time_in,
1747
- op_opt_all_added=added,
1748
- op_opt_all_removed=removed,
1749
- op_opt_max_iter=max_iter,
1750
- op_opt_unique_matched=len(matched),
1751
- op_opt_unique_applied=len(applied),
1752
- op_opt_n_applied=n_applied,
1753
- time_export_optimization=time_in,
1754
- op_opt_export_optimization=time_in,
1755
- op_opt_cst_time_in=cst_time_in,
1756
- op_opt_cst_added=cst_added,
1757
- op_opt_cst_removed=cst_removed,
1758
- )
1759
- )
1962
+ new_stat.update(process_statistics(opt_stats["optimization"]))
1760
1963
 
1761
1964
  summary.update(new_stat)
1762
1965
  assert epo is not None, "no onnx export was found"
@@ -1875,3 +2078,24 @@ def run_ort_fusion(
1875
2078
  f"opt_ort_{model_type}_duration": duration,
1876
2079
  f"opt_ort_{model_type}_duration_save": d,
1877
2080
  }, {f"opt_ort_{model_type}": output_path}
2081
+
2082
+
2083
+ def _compute_final_statistics(summary: Dict[str, Any]):
2084
+ """
2085
+ Updates inline the list of statistics. It adds:
2086
+
2087
+ - speedup
2088
+ """
2089
+ stats = {}
2090
+ if (
2091
+ "time_run_latency" in summary
2092
+ and "time_run_onnx_ort_latency" in summary
2093
+ and summary["time_run_onnx_ort_latency"] > 0
2094
+ ):
2095
+ stats["stat_estimated_speedup_ort"] = (
2096
+ summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
2097
+ )
2098
+ stats["stat_estimated_speedup_ort_m98"] = (
2099
+ summary["time_run_latency_m98"] / summary["time_run_onnx_ort_latency_m98"]
2100
+ )
2101
+ summary.update(stats)
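
Note: the final statistics derive an estimated speedup from latencies already present in the summary. With made-up numbers for illustration:

    summary = {"time_run_latency": 0.012, "time_run_onnx_ort_latency": 0.008,
               "time_run_latency_m98": 0.011, "time_run_onnx_ort_latency_m98": 0.008}
    # stat_estimated_speedup_ort     = 0.012 / 0.008 = 1.5
    # stat_estimated_speedup_ort_m98 = 0.011 / 0.008 = 1.375
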