onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """
 
- __version__ = "0.8.1"
+ __version__ = "0.8.3"
 
  __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -1,14 +1,87 @@
  import argparse
+ import contextlib
  import json
  import os
  import re
  import sys
  import textwrap
+ import time
  import onnx
  from typing import Any, Dict, List, Optional, Union
  from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
 
 
+ def get_parser_dot() -> ArgumentParser:
+     parser = ArgumentParser(
+         prog="dot",
+         description=textwrap.dedent(
+             """
+             Converts an onnx model into a dot file that dot can render as a graph.
+             """
+         ),
+     )
+     parser.add_argument("input", type=str, help="onnx model to convert")
+     parser.add_argument(
+         "-o",
+         "--output",
+         default="",
+         type=str,
+         required=False,
+         help="dot file to output, or empty to print out the result",
+     )
+     parser.add_argument(
+         "-v",
+         "--verbose",
+         type=int,
+         default=0,
+         required=False,
+         help="verbosity",
+     )
+     parser.add_argument(
+         "-r",
+         "--run",
+         default="",
+         required=False,
+         help="run dot; in that case, a format must be given (svg, png)",
+     )
+     return parser
+
+
+ def _cmd_dot(argv: List[Any]):
+     import subprocess
+     from .helpers.dot_helper import to_dot
+
+     parser = get_parser_dot()
+     args = parser.parse_args(argv[1:])
+     if args.verbose:
+         print(f"-- loads {args.input!r}")
+     onx = onnx.load(args.input, load_external_data=False)
+     if args.verbose:
+         print("-- converts into dot")
+     dot = to_dot(onx)
+     if args.output:
+         if args.verbose:
+             print(f"-- saves into {args.output}")
+         with open(args.output, "w") as f:
+             f.write(dot)
+     else:
+         print(dot)
+     if args.run:
+         assert args.output, "Cannot run dot without an output file."
+         cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
+         if args.verbose:
+             print(f"-- run {' '.join(cmds)}")
+         p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+         out, err = p.communicate()
+         if out:
+             print("--")
+             print(out)
+         if err:
+             print("--")
+             print(err)
+
+
  def get_parser_lighten() -> ArgumentParser:
      parser = ArgumentParser(
          prog="lighten",
@@ -624,6 +697,18 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
          ),
          action=_ParseDict,
      )
+     parser.add_argument(
+         "--save-ep",
+         default="",
+         help=textwrap.dedent(
+             """
+             saves the exported program with torch.export.save
+             and the input sets with torch.save,
+             so that the command line sbs can then be used to look for discrepancies
+             """
+         ),
+     )
+
      return parser
 
 
@@ -690,6 +775,7 @@ def _cmd_validate(argv: List[Any]):
              None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
          ),
          exporter_options=args.expop,
+         save_ep=args.save_ep,
      )
      print("")
      print("-- summary --")
@@ -1104,6 +1190,287 @@ def _cmd_agg(argv: List[Any]):
      print(f"Wrote {args.output!r}")
 
 
+ def get_parser_sbs() -> ArgumentParser:
+     parser = ArgumentParser(
+         prog="side-by-side (sbs)",
+         description=textwrap.dedent(
+             """
+             Compares the intermediate outputs of the exported program and
+             the exported onnx model. It assumes some names are common.
+             The exported program and the onnx model are executed in parallel.
+             The device is the one used to store the model and the inputs.
+             Where do discrepancies start? This command tries to answer that question.
+             """
+         ),
+         epilog=textwrap.dedent(
+             """
+             The command line expects the following files, saved with
+             the following functions; inputs is a dictionary of the model inputs.
+
+             - torch.export.save(ep: torch.export.ExportedProgram)
+             - torch.save(**inputs)
+             - onnx.save(...)
+
+             The Replay functionality is just a way to investigate a part of a model.
+             It saves the torch and onnx inputs, the torch outputs, and the minimal onnx model
+             which shares its inputs with the exported program.
+             This is used to investigate the discrepancies between the torch
+             model (through the exported program) and its onnx conversion.
+             This functionality dumps everything it can to disk
+             so that it can be replayed in a separate process.
+             """
+         ),
+     )
+     parser.add_argument(
+         "-i",
+         "--inputs",
+         type=str,
+         required=True,
+         help="model inputs saved with torch.save",
+     )
+     parser.add_argument(
+         "-e",
+         "--ep",
+         type=str,
+         required=True,
+         help=textwrap.dedent(
+             """
+             exported program saved with torch.export.save
+             """
+         ),
+     )
+     parser.add_argument(
+         "-m",
+         "--onnx",
+         type=str,
+         required=True,
+         help="exported model in onnx format",
+     )
+     parser.add_argument(
+         "-o",
+         "--output",
+         type=str,
+         required=True,
+         help="output file to store what the command line produces, "
+         "it should be an excel file",
+     )
+     parser.add_argument(
+         "--atol",
+         default=1e-5,
+         required=False,
+         help="absolute tolerance",
+     )
+     parser.add_argument(
+         "--rtol",
+         default=1e-5,
+         required=False,
+         help="relative tolerance",
+     )
+     parser.add_argument(
+         "-v",
+         "--verbose",
+         default=0,
+         required=False,
+         help="verbosity",
+     )
+     parser.add_argument(
+         "-r",
+         "--ratio",
+         default=100,
+         required=False,
+         help="Saves the result in an excel file every <ratio> nodes, default is 100.",
+     )
+     parser.add_argument(
+         "--first",
+         action=BooleanOptionalAction,
+         default=False,
+         help="First runs the whole model.",
+     )
+     parser.add_argument(
+         "-2",
+         "--second-run",
+         action=BooleanOptionalAction,
+         default=False,
+         help=textwrap.dedent(
+             """
+             Tries to run all onnx nodes with the torch results produced by the exported
+             program, then measures the discrepancies again. It can be used
+             to distinguish kernels introducing discrepancies from those just propagating them.
+             """
+         ),
+     )
+     parser.add_argument(
+         "--reset",
+         required=False,
+         default="",
+         help=textwrap.dedent(
+             """
+             List of result names separated by a comma. For those results,
+             the side-by-side will take torch results instead of onnx results
+             to compute the rest of the onnx model.
+             """
+         ),
+     )
+     parser.add_argument(
+         "-s",
+         "--replay-threshold",
+         type=float,
+         required=False,
+         default=1e18,
+         help="Triggers the replay if the discrepancies are higher than this value.",
+     )
+     parser.add_argument(
+         "-n",
+         "--replay-names",
+         required=False,
+         default="",
+         help="Triggers the replay if a result name is in this set of values (comma separated)",
+     )
+     parser.add_argument(
+         "-t",
+         "--replay-op-types",
+         required=False,
+         default="",
+         help="Triggers the replay if an onnx type is in this set of values (comma separated)",
+     )
+     parser.add_argument(
+         "-f",
+         "--replay-folder",
+         required=False,
+         default="replay",
+         help="If the replay is triggered, this defines the folder where everything is dumped.",
+     )
+
+     return parser
+
+
+ def _cmd_sbs(argv: List[Any]):
+     import pandas
+     import torch
+     from .helpers import flatten_object, max_diff, string_diff, string_type
+     from .torch_onnx.sbs import run_aligned
+     from .torch_onnx.sbs_dataclasses import ReplayConfiguration
+     from .reference import OnnxruntimeEvaluator
+
+     parser = get_parser_sbs()
+     args = parser.parse_args(argv[1:])
+
+     def _size(name):
+         s = os.stat(name).st_size
+         return f"{s / 2**20:1.3f} Mb"
+
+     print("-- side by side")
+     print(f"-- ep: {_size(args.ep)}: {args.ep}")
+     print(f"-- inputs: {_size(args.inputs)}: {args.inputs}")
+     print(f"-- onnx: {_size(args.onnx)}: {args.onnx}")
+     print(f"-- output: {args.output}")
+
+     print(f"-- load inputs {args.inputs!r}")
+     begin = time.perf_counter()
+     inputs = torch.load(args.inputs)
+     s = string_type(inputs, with_shape=True, with_device=True)
+     print(f"-- done in {time.perf_counter() - begin:1.1f}s - {s}")
+
+     if isinstance(inputs, dict) and len(inputs) == 2 and set(inputs) == {"args", "kwargs"}:
+         margs = inputs["args"]
+         mkwargs = inputs["kwargs"]
+     elif isinstance(inputs, tuple):
+         margs = inputs
+         mkwargs = {}
+     elif isinstance(inputs, dict):
+         margs = tuple()
+         mkwargs = inputs
+     else:
+         raise ValueError(
+             f"Unable to infer args, kwargs from inputs {string_type(inputs, with_shape=True)}"
+         )
+
+     print("-- import transformers.modeling_outputs to register serialization functions")
+     with contextlib.suppress(ImportError):
+         import transformers.modeling_outputs  # noqa: F401
+     print(f"-- load ep {args.ep!r}")
+     begin = time.perf_counter()
+     # We need to load the plugs.
+     from .torch_export_patches.patches.patch_transformers import get_transformers_plugs
+
+     plugs = get_transformers_plugs()
+     assert plugs, "Missing PLUGS for Qwen2.5"
+     ep = torch.export.load(args.ep)
+     print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+     if args.first:
+         print("-- compare first, run ep")
+         print(f"-- args: {string_type(margs, with_shape=True, with_device=True)}")
+         print(f"-- mkwargs: {string_type(mkwargs, with_shape=True, with_device=True)}")
+         expected = ep.module()(*margs, **mkwargs)
+         print(f"-- expected: {string_type(expected, with_shape=True, with_device=True)}")
+         sess = OnnxruntimeEvaluator(args.onnx, whole=True)
+         onx_inputs = flatten_object([margs, mkwargs], drop_keys=True)
+         feeds = dict(zip(sess.input_names, onx_inputs))
+         print(f"-- feeds: {string_type(feeds, with_shape=True, with_device=True)}")
+         got = sess.run(None, feeds)
+         print(f"-- got: {string_type(got, with_shape=True, with_device=True)}")
+         diff = max_diff(expected, got, hist=[0.1])
+         print(f"-- diff: {string_diff(diff)}")
+         print("-- done")
+         del sess
+
+     print(f"-- load onnx {args.onnx!r}")
+     begin = time.perf_counter()
+     onx = onnx.load(args.onnx)
+     print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+     replay_configuration = None
+     if args.replay_threshold < 1e6 or args.replay_names or args.replay_op_types:
+         replay_configuration = ReplayConfiguration(
+             threshold=args.replay_threshold,
+             selected_names=set(args.replay_names.split(",")) if args.replay_names else None,
+             selected_op_types=(
+                 set(args.replay_op_types.split(",")) if args.replay_op_types else None
+             ),
+             dump_folder=args.replay_folder,
+         )
+
+     print("-- starts side-by-side")
+     ratio = int(args.ratio)
+     data = []
+     for obs in run_aligned(
+         ep,
+         onx,
+         run_cls=OnnxruntimeEvaluator,  # type: ignore[arg-type]
+         atol=float(args.atol),
+         rtol=float(args.rtol),
+         verbose=int(args.verbose),
+         args=margs,
+         kwargs=mkwargs,
+         use_tensor=True,
+         reset_names=args.reset.split(","),
+         exc=False,
+         replay_configuration=replay_configuration,
+         run_onnx_with_torch_inputs=args.second_run,
+     ):
+         data.append(obs)
+         if (
+             obs.onnx_op_type != "initializer"
+             and obs.ep_target != "placeholder"
+             and len(data) % ratio == 0
+         ):
+             df = pandas.DataFrame(data).apply(
+                 lambda col: col.fillna("") if col.dtype == "object" else col
+             )
+             df.to_excel(args.output)
+     print(f"-- final saves into {args.output!r}")
+     df = (
+         pandas.DataFrame(data)
+         .apply(lambda col: col.fillna("") if col.dtype == "object" else col)
+         .dropna(axis=1, how="all")
+     )
+     df.to_excel(args.output, index=False)
+     print("-- done")
+
+
  def get_main_parser() -> ArgumentParser:
      parser = ArgumentParser(
          prog="onnx_diagnostic",
@@ -1116,10 +1483,12 @@ def get_main_parser() -> ArgumentParser:
 
          agg - aggregates statistics from multiple files
          config - prints a configuration for a model id
+         dot - converts an onnx model into dot format
          exportsample - produces code to export a model
          find - finds the node consuming or producing a result
          lighten - makes an onnx model lighter by removing the weights
          print - prints the model on standard output
+         sbs - compares an exported program and an onnx model
          stats - produces statistics on a model
          unlighten - restores an onnx model produced by the previous command
          validate - validates a model
@@ -1131,10 +1500,12 @@ def get_main_parser() -> ArgumentParser:
          choices=[
              "agg",
              "config",
+             "dot",
              "exportsample",
              "find",
              "lighten",
              "print",
+             "sbs",
              "stats",
              "unlighten",
              "validate",
@@ -1146,15 +1517,17 @@
 
  def main(argv: Optional[List[Any]] = None):
      fcts = dict(
+         agg=_cmd_agg,
+         config=_cmd_config,
+         dot=_cmd_dot,
+         exportsample=_cmd_export_sample,
+         find=_cmd_find,
          lighten=_cmd_lighten,
-         unlighten=_cmd_unlighten,
          print=_cmd_print,
-         find=_cmd_find,
-         config=_cmd_config,
-         validate=_cmd_validate,
+         sbs=_cmd_sbs,
          stats=_cmd_stats,
-         agg=_cmd_agg,
-         exportsample=_cmd_export_sample,
+         unlighten=_cmd_unlighten,
+         validate=_cmd_validate,
      )
 
      if argv is None:
@@ -1169,15 +1542,17 @@ def main(argv: Optional[List[Any]] = None):
          parser.parse_args(argv)
      else:
          parsers = dict(
+             agg=get_parser_agg,
+             config=get_parser_config,
+             dot=get_parser_dot,
+             exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
+             find=get_parser_find,
              lighten=get_parser_lighten,
-             unlighten=get_parser_unlighten,
              print=get_parser_print,
-             find=get_parser_find,
-             config=get_parser_config,
-             validate=get_parser_validate,
+             sbs=get_parser_sbs,
              stats=get_parser_stats,
-             agg=get_parser_agg,
-             exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
+             unlighten=get_parser_unlighten,
+             validate=get_parser_validate,
          )
          cmd = argv[0]
          if cmd not in parsers:
onnx_diagnostic/export/api.py
@@ -1,5 +1,6 @@
- from typing import Any, Dict, List, Sequence, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
  import torch
+ from .onnx_plug import EagerDirectReplacementWithOnnx
 
 
  def to_onnx(
@@ -14,6 +15,12 @@ def to_onnx(
      output_names: Optional[List[str]] = None,
      output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
      exporter: str = "onnx-dynamo",
+     exporter_kwargs: Optional[Dict[str, Any]] = None,
+     save_ep: Optional[str] = None,
+     optimize: bool = True,
+     use_control_flow_dispatcher: bool = False,
+     onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+     inline: bool = True,
  ) -> Any:
      """
      Common API for exporters. By default, the models are optimized to use the
@@ -32,6 +39,13 @@ def to_onnx(
      :param output_names: to change the output of the onnx model
      :param output_dynamic_shapes: to overwrite the dynamic shapes names
      :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
+     :param exporter_kwargs: additional parameters sent to the exporter
+     :param save_ep: saves the exported program
+     :param optimize: optimizes the model
+     :param use_control_flow_dispatcher: uses the dispatcher created to support
+         custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
+     :param onnx_plugs: parts of the code rewritten to be replaced by their onnx translation
+     :param inline: inlines local functions
      :return: the output of the selected exporter, usually a structure including
          an onnx model
 
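A usage sketch of the extended signature (toy module, illustrative file names; the comments restate the parameter documentation above):

    import torch
    from onnx_diagnostic.export.api import to_onnx

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x.sigmoid() * x

    epo = to_onnx(
        Tiny(),
        (torch.randn(4, 8),),
        filename="tiny.onnx",
        exporter="onnx-dynamo",
        save_ep="tiny_ep",   # the exported program is also written to tiny_ep.pt2
        optimize=True,       # run the onnxruntime fusions on the exported model
        inline=True,         # inline local functions before optimizing
    )
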
@@ -46,11 +60,74 @@ def to_onnx(
              exporter=exporter,
              filename=filename,
          )
+
+     Some examples using control flow are available in
+     :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
+     :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
      """
+     if exporter_kwargs and "inline" in exporter_kwargs:
+         assert (
+             inline == exporter_kwargs["inline"]
+         ), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
+         exporter_kwargs.pop("inline")
      if exporter == "custom":
-         from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
+         from experimental_experiment.torch_interpreter import (
+             to_onnx as _to_onnx,
+             ExportOptions,
+         )
          from experimental_experiment.xbuilder import OptimizationOptions
 
+         options = None
+         if exporter_kwargs is not None:
+             options = exporter_kwargs.pop("options", None)
+         if options is None:
+             options = OptimizationOptions(patterns="default+onnxruntime")
+         if onnx_plugs or use_control_flow_dispatcher:
+             from experimental_experiment.torch_interpreter import Dispatcher
+
+             if use_control_flow_dispatcher:
+                 from .control_flow_onnx import create_global_dispatcher
+
+                 control_flow_dispatcher = create_global_dispatcher()
+             else:
+                 control_flow_dispatcher = None
+
+             class MainDispatcher(Dispatcher):
+                 def __init__(self, previous_dispatcher=None):
+                     super().__init__({})
+                     self.previous_dispatcher = previous_dispatcher
+
+                 @property
+                 def supported(self):
+                     if self.previous_dispatcher:
+                         return (
+                             set(self.registered_functions) | self.previous_dispatcher.supported
+                         )
+                     return set(self.registered_functions)
+
+                 def find_function(self, name: Any):
+                     if self.previous_dispatcher:
+                         find = self.previous_dispatcher.find_function(name)
+                         if find:
+                             return find
+                     return Dispatcher.find_function(self, name)
+
+                 def find_method(self, name: Any):
+                     if self.previous_dispatcher:
+                         find = self.previous_dispatcher.find_method(name)
+                         if find:
+                             return find
+                     return Dispatcher.find_method(self, name)
+
+             main_dispatcher = MainDispatcher(control_flow_dispatcher)
+             if onnx_plugs:
+                 for plug in onnx_plugs:
+                     main_dispatcher.registered_functions[plug.target_name] = (
+                         plug.custom_converter()
+                     )
+         else:
+             main_dispatcher = None
+
          return _to_onnx(
              mod,
              args=args,
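
For clarity, the lookup order implemented by MainDispatcher (ask the previous dispatcher first, fall back to the local registry) in a standalone sketch, independent of experimental_experiment:

    from typing import Callable, Dict, Optional

    class ChainedLookup:
        # Mimics MainDispatcher.find_function: the previous lookup wins,
        # the local table is the fallback.
        def __init__(self, table: Dict[str, Callable], previous: Optional["ChainedLookup"] = None):
            self.table = table
            self.previous = previous

        def find(self, name: str) -> Optional[Callable]:
            if self.previous:
                found = self.previous.find(name)
                if found:
                    return found
            return self.table.get(name)

    base = ChainedLookup({"loop_for_onnx": lambda: "control-flow converter"})
    chained = ChainedLookup({"my_plug": lambda: "plug converter"}, previous=base)
    assert chained.find("loop_for_onnx")() == "control-flow converter"
    assert chained.find("my_plug")() == "plug converter"
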
@@ -63,14 +140,24 @@
              dynamic_shapes=dynamic_shapes,
              large_model=True,
              output_dynamic_shapes=output_dynamic_shapes,
-             options=OptimizationOptions(patterns="default+onnxruntime"),
+             export_options=ExportOptions(save_ep=save_ep),
+             options=options,
+             inline=inline,
+             dispatcher=main_dispatcher,
+             **(exporter_kwargs or {}),
          )
+
      if exporter in ("dynamo", "onnx-dynamo"):
+         import os
          import onnxscript.rewriter.ort_fusions as ort_fusions
 
          assert (
              not output_dynamic_shapes
          ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
+         custom_translation_table = {}
+         if onnx_plugs:
+             for plug in onnx_plugs:
+                 custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
          epo = torch.onnx.export(
              mod,
              args=args or tuple(),
@@ -80,9 +167,34 @@
              opset_version=target_opset,
              dynamic_shapes=dynamic_shapes,
              dynamo=True,
+             verbose=verbose,
+             dump_exported_program=bool(save_ep),
+             artifacts_dir=os.path.dirname(filename) if filename else ".",
+             custom_translation_table=custom_translation_table,
+             **(exporter_kwargs or {}),
          )
-         ort_fusions.optimize_for_ort(epo.model)
-         epo.save(filename)
+         if not inline and optimize:
+             ort_fusions.optimize_for_ort(epo.model)
+
+         if onnx_plugs:
+             import onnx_ir as ir
+             import onnx_ir.passes.common as common_passes
+
+             irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
+             for func in irfunctions:
+                 epo.model.functions[func.identifier()] = func
+             if inline:
+                 common_passes.InlinePass()(epo.model)
+                 common_passes.RemoveUnusedOpsetsPass()(epo.model)
+
+         if inline and optimize:
+             ort_fusions.optimize_for_ort(epo.model)
+         if filename:
+             epo.save(filename, external_data=True)
+         if save_ep:
+             if isinstance(save_ep, tuple):
+                 save_ep = save_ep[0]
+             torch.export.save(epo.exported_program, f"{save_ep}.pt2")
          return epo
 
      if exporter == "modelbuilder":
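
The plug handling in the branch above registers each plug's FunctionProto on the model and inlines it. A standalone sketch of those same onnx_ir calls on a toy function (the Double function and custom_domain are purely illustrative):

    import onnx
    import onnx.helper as oh
    import onnx_ir as ir
    import onnx_ir.passes.common as common_passes

    # A toy FunctionProto standing in for plug.function_proto.
    double = oh.make_function(
        "custom_domain", "Double", ["x"], ["y"],
        [oh.make_node("Add", ["x", "x"], ["y"])],
        opset_imports=[oh.make_opsetid("", 18)],
    )
    graph = oh.make_graph(
        [oh.make_node("Double", ["x"], ["y"], domain="custom_domain")],
        "g",
        [oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2])],
        [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2])],
    )
    model = ir.from_proto(oh.make_model(
        graph,
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("custom_domain", 1)],
    ))

    # Same steps as above: register the function, inline it, drop the unused opset.
    func = ir.from_proto(double)
    model.functions[func.identifier()] = func
    common_passes.InlinePass()(model)
    common_passes.RemoveUnusedOpsetsPass()(model)
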
@@ -117,6 +229,7 @@
              precision=str(first_float[0].dtype).split(".")[-1],
              execution_provider="cuda" if first.is_cuda else "cpu",
              cache_dir=os.path.dirname(filename),
+             **(exporter_kwargs or {}),
          )
          save_model_builder(onx, os.path.dirname(filename))
          return onx