onnx-diagnostic 0.7.10__py3-none-any.whl → 0.7.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,17 +3,15 @@ import inspect
  import os
  import pprint
  import sys
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
  import time
  import numpy as np
  import onnx
- import onnxscript
- import onnxscript.rewriter.ort_fusions as ort_fusions
  import torch
  from ..export import CoupleInputsDynamicShapes
  from ..helpers import max_diff, string_type, string_diff
  from ..helpers.helper import flatten_object
- from ..helpers.rt_helper import make_feeds
+ from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
  from ..helpers.torch_helper import to_any, torch_deepcopy
  from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
  from ..tasks import random_input_kwargs
@@ -113,6 +111,8 @@ def _make_folder_name(
      dtype: Optional[Union[str, torch.dtype]] = None,
      device: Optional[Union[str, torch.device]] = None,
      subfolder: Optional[str] = None,
+     opset: Optional[int] = None,
+     drop_inputs: Optional[List[str]] = None,
  ) -> str:
      "Creates a unique filename based on the given options."
      els = [model_id.replace("/", "_")]
@@ -136,6 +136,11 @@ def _make_folder_name(
      else:
          raise AssertionError(f"unexpected value for device={device}, sdev={sdev!r}")
      els.append(sdev)
+     if opset is not None:
+         els.append(f"op{opset}")
+     if drop_inputs:
+         ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
+         els.append(f"I-{ii.upper()}")
      return "-".join(els)
@@ -246,6 +251,7 @@ def _quiet_or_not_quiet(
      summary[f"time_{suffix}_latency_std"] = a.std()
      summary[f"time_{suffix}_latency_min"] = a.min()
      summary[f"time_{suffix}_latency_max"] = a.max()
+     summary[f"time_{suffix}_n"] = len(a)
      return res
@@ -262,6 +268,20 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
      return new_cfg


+ def _preprocess_model_id(
+     model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
+ ) -> Tuple[str, Optional[str], bool, bool]:
+     if subfolder or "//" not in model_id:
+         return model_id, subfolder, same_as_pretrained, use_pretrained
+     spl = model_id.split("//")
+     if spl[-1] == "pretrained":
+         return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+     if spl[-1] in {"transformer", "vae"}:
+         # known subfolder
+         return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
+     return model_id, subfolder, same_as_pretrained, use_pretrained
+
+
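A hedged sketch of the new `//` suffix convention; the import path is an assumption (the helper is private and the module path is not shown in this diff):

```python
# Assumed import path, for illustration only.
from onnx_diagnostic.torch_models.validate import _preprocess_model_id

# A known subfolder suffix is split off into the subfolder argument.
print(_preprocess_model_id("black-forest-labs/FLUX.1-dev//transformer", None, False, False))
# ('black-forest-labs/FLUX.1-dev', 'transformer', False, False)

# "//pretrained" forces same_as_pretrained=True and use_pretrained=True.
print(_preprocess_model_id("microsoft/phi-2//pretrained", None, False, False))
# ('microsoft/phi-2', '', True, True)
```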
  def validate_model(
      model_id: str,
      task: Optional[str] = None,
@@ -290,6 +310,7 @@ def validate_model(
      warmup: int = 0,
      inputs2: int = 1,
      output_names: Optional[List[str]] = None,
+     ort_logs: bool = False,
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
      """
      Validates a model.
@@ -334,13 +355,15 @@ def validate_model(
      :param subfolder: version or subfolder to use when retrieving a model id
      :param opset: onnx opset to use for the conversion
      :param runtime: onnx runtime to use to check for discrepancies,
-         only if `do_run` is true
+         possible values ``onnxruntime``, ``torch``, ``orteval``,
+         ``orteval10``, ``ref``, only if `do_run` is true
      :param repeat: number of times to measure the model
      :param warmup: warmup the model first
      :param inputs2: checks that a second set of inputs runs as well,
          this ensures that the model does support dynamism, the value is used
          as an increment to the first set of values (added to dimensions)
      :param output_names: output names the onnx exporter should use
+     :param ort_logs: increases onnxruntime verbosity when creating the session
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces

@@ -361,14 +384,23 @@ def validate_model(

      The default runtime, :epkg:`onnxruntime`, is used to validate a model and check that the
      exported model returns the same outputs as the original one; otherwise,
-     :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
+     :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used
+     if ``runtime == 'torch'``,
+     :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
+     if ``runtime == 'orteval'``, or
+     :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
+     if ``runtime == 'ref'``;
+     ``orteval10`` is ``orteval`` with increased verbosity.
      """
+     model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+         model_id,
+         subfolder,
+         same_as_pretrained=same_as_pretrained,
+         use_pretrained=use_pretrained,
+     )
+     default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
      if isinstance(patch, bool):
-         patch_kwargs = (
-             dict(patch_transformers=True, patch_diffusers=True, patch=True)
-             if patch
-             else dict(patch=False)
-         )
+         patch_kwargs = default_patch if patch else dict(patch=False)
      elif isinstance(patch, str):
          patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
      else:
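A minimal usage sketch of the expanded `runtime` choices; the model id and the exporter value here are assumptions, not taken from this diff:

```python
from onnx_diagnostic.torch_models.validate import validate_model  # assumed path

summary, data = validate_model(
    "arnir0/Tiny-LLM",       # hypothetical model id
    do_run=True,
    exporter="onnx-dynamo",  # hypothetical exporter choice
    runtime="orteval",       # or "onnxruntime", "torch", "orteval10", "ref"
    verbose=1,
)
```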
@@ -377,11 +409,13 @@ def validate_model(
      if "patch" not in patch_kwargs:
          if any(patch_kwargs.values()):
              patch_kwargs["patch"] = True
+     elif len(patch) == 1 and patch.get("patch", False):
+         patch_kwargs.update(default_patch)

      assert not rewrite or patch_kwargs.get("patch", False), (
          f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
          f"patch must be True to enable rewriting, "
-         f"if --no-patch was specified on the command line, --no-rewrite must be added."
+         f"if --patch=0 was specified on the command line, rewrites are disabled."
      )
      summary = version_summary()
      summary.update(
@@ -412,7 +446,14 @@ def validate_model(
      folder_name = None
      if dump_folder:
          folder_name = _make_folder_name(
-             model_id, exporter, optimization, dtype=dtype, device=device, subfolder=subfolder
+             model_id,
+             exporter,
+             optimization,
+             dtype=dtype,
+             device=device,
+             subfolder=subfolder,
+             opset=opset,
+             drop_inputs=drop_inputs,
          )
          dump_folder = os.path.join(dump_folder, folder_name)
          if not os.path.exists(dump_folder):
@@ -508,6 +549,11 @@ def validate_model(
      if verbose:
          print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")

+     # modelbuilder sometimes needs different treatment, so we record the
+     # exporter for later use: for example, its past_kv ordering differs
+     # from the flattened cache object's.
+     data["exporter"] = exporter
      data["input_options"] = iop
      data["model_options"] = mop
      data["model_dump_folder"] = dump_folder
@@ -743,6 +789,7 @@ def validate_model(
          repeat=repeat,
          warmup=warmup,
          inputs2=inputs2,
+         ort_logs=ort_logs,
      )
      summary.update(summary_valid)
@@ -807,6 +854,8 @@ def validate_model(
      )
      summary.update(summary_valid)

+     _compute_final_statistics(summary)
+
      if verbose:
          print("[validate_model] -- done (final)")
      if dump_stats:
@@ -819,15 +868,24 @@ def validate_model(
  def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
      """Computes some statistics on the model itself."""
      onx = onnx.load(onnx_filename, load_external_data=False)
+     cache_functions = {(f.domain, f.name): f for f in onx.functions}
+     local_domains = set(f.domain for f in onx.functions)

      def node_iter(proto):
          if isinstance(proto, onnx.ModelProto):
-             yield from node_iter(proto.graph)
              for f in proto.functions:
                  yield from node_iter(f)
+             yield from node_iter(proto.graph)
          elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
              for node in proto.node:
                  yield node
+
+                 # inline the nodes of the local function this node calls
+                 key = node.domain, node.op_type
+                 if key in cache_functions:
+                     yield from node_iter(cache_functions[key])
+
                  for att in node.attribute:
                      if att.type == onnx.AttributeProto.GRAPH:
                          yield from node_iter(att.g)
@@ -837,15 +895,29 @@ def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
          raise NotImplementedError(f"Unexpected type={type(proto)}")

      counts: Dict[str, Union[float, int]] = {}
+     n_nodes = 0
+     n_nodes_nocst = 0
      for proto in node_iter(onx):
          if isinstance(proto, onnx.NodeProto):
              key = f"n_node_{proto.op_type}"
+             n_nodes += 1
+             if proto.op_type != "Constant":
+                 n_nodes_nocst += 1
+             if proto.domain in local_domains:
+                 # counted once by the shared increment below
+                 key = "n_node_local_function"
          else:
              key = f"n_node_initializer_{proto.data_type}"

          if key not in counts:
              counts[key] = 0
          counts[key] += 1
+
+     counts["n_node_nodes"] = n_nodes
+     counts["n_node_nodes_nocst"] = n_nodes_nocst
+     counts["n_node_functions"] = len(onx.functions)
      return counts
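To see what `compute_statistics` now reports, a self-contained sketch on a two-node model (the import path is an assumption):

```python
import onnx
from onnx import TensorProto, helper
from onnx_diagnostic.torch_models.validate import compute_statistics  # assumed path

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1])
graph = helper.make_graph(
    [helper.make_node("Neg", ["X"], ["t"]), helper.make_node("Abs", ["t"], ["Y"])],
    "tiny", [X], [Y],
)
onnx.save(helper.make_model(graph), "tiny.onnx")
print(compute_statistics("tiny.onnx"))
# {'n_node_Neg': 1, 'n_node_Abs': 1, 'n_node_nodes': 2,
#  'n_node_nodes_nocst': 2, 'n_node_functions': 0}
```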
@@ -922,6 +994,26 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
      )


+ _cache_export_times = []
+ _main_export_function = torch.export.export
+
+
+ def _torch_export_export(*args, _export=_main_export_function, **kwargs):
+     begin = time.perf_counter()
+     res = _export(*args, **kwargs)
+     duration = time.perf_counter() - begin
+     _cache_export_times.append(duration)
+     return res
+
+
+ def _restore_torch_export_export(summary):
+     torch.export.export = _main_export_function
+     if _cache_export_times:
+         summary["time_torch_export_export"] = sum(_cache_export_times)
+         summary["time_torch_export_export_n"] = len(_cache_export_times)
+     _cache_export_times.clear()
+
+
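The helpers above time every call to `torch.export.export` by temporarily swapping the symbol. The same wrap-and-restore pattern in a self-contained form (the toy module is illustrative):

```python
import time
import torch

_times = []
_original = torch.export.export

def _timed_export(*args, _export=_original, **kwargs):
    # Measure and record the duration of each torch.export.export call.
    begin = time.perf_counter()
    try:
        return _export(*args, **kwargs)
    finally:
        _times.append(time.perf_counter() - begin)

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

torch.export.export = _timed_export
try:
    ep = torch.export.export(M(), (torch.ones(2),))
finally:
    torch.export.export = _original  # always restore the original symbol
print(f"{sum(_times):.3f}s over {len(_times)} call(s)")
```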
  def call_exporter(
      data: Dict[str, Any],
      exporter: str,
@@ -947,6 +1039,9 @@ def call_exporter(
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
+     _cache_export_times.clear()
+     torch.export.export = _torch_export_export
+
      if exporter == "export" or exporter.startswith("export-"):
          # torch export
          summary, data = call_torch_export_export(
@@ -957,6 +1052,7 @@ def call_exporter(
              optimization=optimization,
              do_run=do_run,
          )
+         _restore_torch_export_export(summary)
          return summary, data
      if exporter.startswith("onnx-"):
          # torch export
@@ -968,6 +1064,7 @@ def call_exporter(
              optimization=optimization,
              output_names=output_names,
          )
+         _restore_torch_export_export(summary)
          return summary, data
      if exporter == "custom" or exporter.startswith("custom"):
          # torch export
@@ -980,6 +1077,7 @@ def call_exporter(
              dump_folder=dump_folder,
              output_names=output_names,
          )
+         _restore_torch_export_export(summary)
          return summary, data
      if exporter == "modelbuilder":
          # torch export
@@ -991,6 +1089,7 @@ def call_exporter(
              optimization=optimization,
              output_names=output_names,
          )
+         _restore_torch_export_export(summary)
          return summary, data
      raise NotImplementedError(
          f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
@@ -1134,6 +1233,7 @@ def validate_onnx_model(
      repeat: int = 1,
      warmup: int = 0,
      inputs2: int = 1,
+     ort_logs: bool = False,
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
      """
      Verifies that an onnx model produces the same
@@ -1146,12 +1246,13 @@ def validate_onnx_model(
      :param quiet: catch exceptions or not
      :param verbose: verbosity
      :param flavour: use a different version of the inputs
-     :param runtime: onnx runtime to use, onnxruntime or torch
+     :param runtime: onnx runtime to use: onnxruntime, torch, orteval, orteval10, ref
      :param repeat: number of times to run the model
      :param warmup: warmup the model
      :param inputs2: to validate the model on a second input set
          to make sure the exported model supports dynamism, the value is
          used as an increment added to the first set of inputs (added to dimensions)
+     :param ort_logs: enables verbose onnxruntime logs when creating the session
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
@@ -1193,23 +1294,71 @@ def validate_onnx_model(
          f"{providers}..., flavour={flavour!r}"
      )

-     if runtime != "onnxruntime":
+     if runtime == "onnxruntime":
+         if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
+             opts = onnxruntime.SessionOptions()
+             opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
+             if verbose:
+                 print(
+                     f"[validate_onnx_model] saved optimized onnxruntime "
+                     f"in {opts.optimized_model_filepath!r}"
+                 )
+             onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers)
+             if verbose:
+                 print("[validate_onnx_model] -- done")
+
+         if verbose:
+             print("[validate_onnx_model] runtime is onnxruntime")
+         sess_opts = onnxruntime.SessionOptions()
+         if ort_logs:
+             sess_opts.log_severity_level = 0
+             sess_opts.log_verbosity_level = 4
+         cls_runtime = lambda model, providers, _o=sess_opts: onnxruntime.InferenceSession(
+             (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
+             _o,
+             providers=providers,
+         )
+     elif runtime == "torch":
          from ..reference import TorchOnnxEvaluator

-         cls_runtime = (
-             (
-                 lambda model, providers: onnxruntime.InferenceSession(
-                     (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
-                     providers=providers,
+         if verbose:
+             print("[validate_onnx_model] runtime is TorchOnnxEvaluator")
+         cls_runtime = (
+             lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
+                 model, providers=providers, verbose=max(verbose - 1, 0)
              )
          )
-         if runtime == "onnxruntime"
-         else (
-             lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
+     elif runtime == "orteval":
+         from ..reference import OnnxruntimeEvaluator
+
+         if verbose:
+             print("[validate_onnx_model] runtime is OnnxruntimeEvaluator")
+         cls_runtime = (
+             lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
                  model, providers=providers, verbose=max(verbose - 1, 0)
              )
          )
-     )
+     elif runtime == "orteval10":
+         from ..reference import OnnxruntimeEvaluator
+
+         if verbose:
+             print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)")
+         cls_runtime = (
+             lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
+                 model, providers=providers, verbose=10
+             )
+         )
+     elif runtime == "ref":
+         from ..reference import ExtendedReferenceEvaluator
+
+         if verbose:
+             print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator")
+         cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_(  # type: ignore[misc]
+             model, verbose=max(verbose - 1, 0)
+         )
+     else:
+         raise ValueError(f"Unexpected runtime={runtime!r}")
+
      sess = _quiet_or_not_quiet(
          quiet,
@@ -1234,7 +1383,13 @@ def validate_onnx_model(
          print(
              f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
          )
-     feeds = make_feeds(sess, data[k_input], use_numpy=True, check_flatten=False)
+     feeds = make_feeds(
+         sess,
+         data[k_input],
+         use_numpy=True,
+         check_flatten=False,
+         is_modelbuilder=data["exporter"] == "modelbuilder",
+     )
      if verbose:
          print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
      summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
@@ -1254,6 +1409,13 @@ def validate_onnx_model(
          repeat=repeat,
          warmup=warmup,
      )
+     # NOTE: modelbuilder has a different ordering for past_kv outputs
+     if data["exporter"] == "modelbuilder":
+         logits = got[:1]
+         past_key_values = got[1:]
+         reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
+         got = logits + reorder_past_key_values
+
      if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
          return summary, data
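`reorder_modelbuilder_cache_to_torch` itself is not shown in this diff; the sketch below only illustrates the kind of permutation involved, assuming ModelBuilder emits interleaved `present.N.key`/`present.N.value` outputs while the torch-flattened cache lists all keys before all values (both orderings are assumptions):

```python
def _reorder_interleaved_to_torch(past):
    # Hypothetical: [k0, v0, k1, v1, ...] -> [k0, k1, ..., v0, v1, ...]
    return list(past[0::2]) + list(past[1::2])
```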
@@ -1294,7 +1456,7 @@ def call_torch_export_onnx(
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
-     available = {None, "", "ir", "os_ort"}
+     available = {None, "", "ir", "os_ort", "ir+default"}
      assert (
          optimization in available
      ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1384,12 +1546,34 @@ def call_torch_export_onnx(
          print(epo)
          print("[call_torch_export_onnx] -- End of ONNXProgram")

-     if optimization in {"ir", "os_ort"}:
+     if optimization in {"ir", "os_ort", "ir+default"}:
          if verbose:
              print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
          if optimization == "ir":
              label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
+         elif optimization == "ir+default":
+             import onnxscript
+             from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+             def _ir_default_opt(epo):
+                 onnxscript.optimizer.optimize_ir(epo.model)
+                 onx = epo.model_proto
+                 # not very efficient
+                 gr = GraphBuilder(
+                     onx,
+                     infer_shapes_options=True,
+                     optimization_options=OptimizationOptions(patterns="default"),
+                 )
+                 cont = gr.to_onnx(large_model=True)
+                 epo.model = cont.to_ir()
+
+             label, f_optim = "export_onnx_opt_ir_default", (
+                 lambda epo=epo: _ir_default_opt(epo)
+             )
+
          else:
+             import onnxscript
+             import onnxscript.rewriter.ort_fusions as ort_fusions

              def _os_ort_optim(epo):
                  onnxscript.optimizer.optimize_ir(epo.model)
@@ -1477,6 +1661,97 @@ def call_torch_export_model_builder(
      return summary, data


+ def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
+     """
+     Processes statistics coming from the exporters.
+     It takes a sequence of dictionaries (like a data frame)
+     and extracts some metrics.
+     """
+
+     def _simplify(p):
+         for s in [
+             "remove_unused",
+             "constant_folding",
+             "remove_identity",
+             "remove_duplicated_initializer",
+             "dynamic_dimension_naming",
+             "inline",
+             "check",
+             "build_graph_for_pattern",
+             "pattern_optimization",
+         ]:
+             if s in p or s.replace("_", "-") in p:
+                 return s
+         if p.startswith(("apply_", "match_")):
+             return p
+         return "other"
+
+     def _add(d, a, v, use_max=False):
+         if v:
+             if a not in d:
+                 d[a] = v
+             elif use_max:
+                 d[a] = max(d[a], v)
+             else:
+                 d[a] += v
+
+     counts: Dict[str, Any] = {}
+     applied_pattern_time: Dict[str, Any] = {}
+     applied_pattern_n: Dict[str, Any] = {}
+     matching_pattern_time: Dict[str, Any] = {}
+     matching_pattern_n: Dict[str, Any] = {}
+
+     for obs in data:
+         pattern = _simplify(obs["pattern"])
+         _add(counts, "opt_nodes_added", obs.get("added", 0))
+         _add(counts, "opt_nodes_removed", obs.get("removed", 0))
+         _add(counts, "opt_time_steps", obs.get("time_in", 0))
+         _add(counts, "opt_n_steps", 1)
+         _add(
+             counts,
+             "opt_n_iteration",
+             max(counts.get("opt_n_iteration", 0), obs.get("iteration", 0)),
+             use_max=True,
+         )
+
+         if pattern.startswith("apply_"):
+             _add(counts, "opt_n_applied_patterns", 1)
+             _add(counts, "opt_time_applied_patterns", obs.get("time_in", 0))
+             _add(applied_pattern_time, pattern, obs.get("time_in", 0))
+             _add(applied_pattern_n, pattern, 1)
+         elif pattern.startswith("match_"):
+             _add(counts, "opt_n_matching_patterns", 1)
+             _add(counts, "opt_time_matching_patterns", obs.get("time_in", 0))
+             _add(matching_pattern_time, pattern, obs.get("time_in", 0))
+             _add(matching_pattern_n, pattern, 1)
+         else:
+             _add(counts, f"opt_time_{pattern}", obs.get("time_in", 0))
+             _add(counts, f"opt_n_{pattern}", 1)
+             _add(counts, f"opt_nodes_added_{pattern}", obs.get("added", 0))
+             _add(counts, f"opt_nodes_removed_{pattern}", obs.get("removed", 0))
+
+     if applied_pattern_time:
+         longest = max((v, k) for k, v in applied_pattern_time.items())
+         counts["opt_top_time_applied_pattern"], counts["opt_top_time_applied_pattern_arg"] = (
+             longest
+         )
+         longest = max((v, k) for k, v in applied_pattern_n.items())
+         counts["opt_top_n_applied_pattern"], counts["opt_top_n_applied_pattern_arg"] = longest
+
+     if matching_pattern_time:
+         longest = max((v, k) for k, v in matching_pattern_time.items())
+         (
+             counts["opt_top_time_matching_pattern"],
+             counts["opt_top_time_matching_pattern_arg"],
+         ) = longest
+         longest = max((v, k) for k, v in matching_pattern_n.items())
+         counts["opt_top_n_matching_pattern"], counts["opt_top_n_matching_pattern_arg"] = (
+             longest
+         )
+     counts["onnx_opt_optimized"] = 1
+     return counts
+
+
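A toy call showing the shape of the input records and a few of the derived metrics (import path assumed):

```python
from onnx_diagnostic.torch_models.validate import process_statistics  # assumed path

rows = [
    {"pattern": "check_pattern_A", "time_in": 0.010, "iteration": 0},
    {"pattern": "match_Rotary", "time_in": 0.002, "iteration": 1},
    {"pattern": "apply_Rotary", "added": 1, "removed": 3, "time_in": 0.004, "iteration": 1},
]
stats = process_statistics(rows)
print(stats["opt_nodes_removed"])              # 3
print(stats["opt_top_n_applied_pattern_arg"])  # apply_Rotary
```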
  def call_torch_export_custom(
      data: Dict[str, Any],
      exporter: str,
@@ -1509,6 +1784,8 @@ def call_torch_export_custom(
          "default+onnxruntime+os_ort",
          None,
      }
+     if optimization == "none":
+         optimization = ""
      assert (
          optimization in available
      ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1604,67 +1881,10 @@ def call_torch_export_custom(
      if "ERR_export_onnx_c" in summary:
          return summary, data

-     new_stat = {}
+     new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
+     new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
      if "optimization" in opt_stats:
-         added, removed, time_in = 0, 0, 0.0
-         max_iter = 0
-         applied = {}
-         matched = set()
-         n_applied = 0
-         by_pattern = {}
-         by_pattern_n = {}
-         by_iter = {}
-         cst_added, cst_removed, cst_time_in = 0, 0, 0.0
-
-         for obs in opt_stats["optimization"]:
-             pattern = obs["pattern"]
-             if pattern == "constant_folding":
-                 cst_added += obs.get("added", 0)
-                 cst_removed += obs.get("removed", 0)
-                 cst_time_in += obs.get("time_in", 0)
-             if pattern not in by_pattern:
-                 by_pattern[pattern] = 0
-                 by_pattern_n[pattern] = 0
-                 by_iter[pattern] = 0
-             time_in += obs.get("time_in", 0)
-             added += obs.get("added", 0)
-             removed += obs.get("removed", 0)
-             max_iter = max(max_iter, obs.get("iteration", 0))
-             by_pattern[pattern] += obs.get("time_in", 0)
-             by_pattern_n[pattern] += obs.get("added", 0) - obs.get("removed", 0)
-             if not pattern.startswith("match"):
-                 by_iter[pattern] = max(by_iter[pattern], obs.get("iteration", 0))
-             p = obs["pattern"]
-             if p.startswith("match_"):
-                 matched.add(p)
-             elif p.startswith("apply_"):
-                 key = f"op_opt_{p}"
-                 key2 = f"op_opt_maxiter_{p}"
-                 if key not in applied:
-                     applied[key] = 1
-                     applied[key2] = obs["iteration"]
-                 else:
-                     applied[key] += 1
-                     applied[key2] = max(obs["iteration"], applied[key2])
-                 n_applied += 1
-
-         new_stat.update(
-             dict(
-                 onnx_opt_optimized=1,
-                 op_opt_all_time_in=time_in,
-                 op_opt_all_added=added,
-                 op_opt_all_removed=removed,
-                 op_opt_max_iter=max_iter,
-                 op_opt_unique_matched=len(matched),
-                 op_opt_unique_applied=len(applied),
-                 op_opt_n_applied=n_applied,
-                 time_export_optimization=time_in,
-                 op_opt_export_optimization=time_in,
-                 op_opt_cst_time_in=cst_time_in,
-                 op_opt_cst_added=cst_added,
-                 op_opt_cst_removed=cst_removed,
-             )
-         )
+         new_stat.update(process_statistics(opt_stats["optimization"]))

      summary.update(new_stat)
      assert epo is not None, "no onnx export was found"
@@ -1672,6 +1892,9 @@ def call_torch_export_custom(
          print("[call_torch_export_custom] done (export)")

      if os_ort:
+         import onnxscript
+         import onnxscript.rewriter.ort_fusions as ort_fusions
+
          if verbose:
              print("[call_torch_export_custom] conversion to IR...")
          begin = time.perf_counter()
@@ -1780,3 +2003,21 @@ def run_ort_fusion(
          f"opt_ort_{model_type}_duration": duration,
          f"opt_ort_{model_type}_duration_save": d,
      }, {f"opt_ort_{model_type}": output_path}
+
+
+ def _compute_final_statistics(summary: Dict[str, Any]):
+     """
+     Updates the statistics in place. It adds:
+
+     - speedup
+     """
+     stats = {}
+     if (
+         "time_run_latency" in summary
+         and "time_run_onnx_ort_latency" in summary
+         and summary["time_run_onnx_ort_latency"] > 0
+     ):
+         stats["stat_estimated_speedup_ort"] = (
+             summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
+         )
+     summary.update(stats)
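A worked example of the new statistic (assumed import path): a torch latency of 0.084 s against an onnxruntime latency of 0.042 s gives an estimated 2x speedup.

```python
from onnx_diagnostic.torch_models.validate import _compute_final_statistics  # assumed path

summary = {"time_run_latency": 0.084, "time_run_onnx_ort_latency": 0.042}
_compute_final_statistics(summary)
print(summary["stat_estimated_speedup_ort"])  # 2.0
```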