onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.8.2"
+__version__ = "0.8.4"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -1,14 +1,87 @@
 import argparse
+import contextlib
 import json
 import os
 import re
 import sys
 import textwrap
+import time
 import onnx
 from typing import Any, Dict, List, Optional, Union
 from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction


+def get_parser_dot() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="dot",
+        description=textwrap.dedent(
+            """
+            Converts an onnx model into a dot file that dot can render as a graph.
+            """
+        ),
+    )
+    parser.add_argument("input", type=str, help="onnx model to convert")
+    parser.add_argument(
+        "-o",
+        "--output",
+        default="",
+        type=str,
+        required=False,
+        help="dot file to output, or empty to print out the result",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        type=int,
+        default=0,
+        required=False,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "-r",
+        "--run",
+        default="",
+        required=False,
+        help="run dot; in that case, a format must be given (svg, png)",
+    )
+    return parser
+
+
+def _cmd_dot(argv: List[Any]):
+    import subprocess
+    from .helpers.dot_helper import to_dot
+
+    parser = get_parser_dot()
+    args = parser.parse_args(argv[1:])
+    if args.verbose:
+        print(f"-- loads {args.input!r}")
+    onx = onnx.load(args.input, load_external_data=False)
+    if args.verbose:
+        print("-- converts into dot")
+    dot = to_dot(onx)
+    if args.output:
+        if args.verbose:
+            print(f"-- saves into {args.output}")
+        with open(args.output, "w") as f:
+            f.write(dot)
+    else:
+        print(dot)
+    if args.run:
+        assert args.output, "Cannot run dot without an output file."
+        cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
+        if args.verbose:
+            print(f"-- run {' '.join(cmds)}")
+        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        res = p.communicate()
+        out, err = res
+        if out:
+            print("--")
+            print(out)
+        if err:
+            print("--")
+            print(err)
+
+
 def get_parser_lighten() -> ArgumentParser:
     parser = ArgumentParser(
         prog="lighten",
@@ -624,6 +697,18 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
         ),
         action=_ParseDict,
     )
+    parser.add_argument(
+        "--save-ep",
+        default="",
+        help=textwrap.dedent(
+            """
+            saves the exported program with torch.export.save
+            and the input sets with torch.save,
+            so that the command line sbs can then be used to look for discrepancies.
+            """
+        ),
+    )
+
     return parser


@@ -690,6 +775,7 @@ def _cmd_validate(argv: List[Any]):
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
         exporter_options=args.expop,
+        save_ep=args.save_ep,
     )
     print("")
     print("-- summary --")
@@ -1104,6 +1190,312 @@ def _cmd_agg(argv: List[Any]):
     print(f"Wrote {args.output!r}")


+def get_parser_sbs() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="side-by-side (sbs)",
+        description=textwrap.dedent(
+            """
+            Compares the intermediate outputs between the exported program and
+            the exported onnx model. It assumes some names are common.
+            The exported program and the onnx model are executed in parallel.
+            The device is the one used to store the model and the inputs.
+            Where do discrepancies start? This command tries to answer that question.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            The command line expects the following files, saved with
+            the following functions; inputs is a dictionary of the inputs of the model.
+
+            - torch.export.save(ep: torch.export.ExportedProgram)
+            - torch.save(**inputs)
+            - onnx.save(...)
+
+            The replay functionality is a way to investigate a part of a model.
+            It saves the torch and onnx inputs, the torch outputs, and the minimal onnx model
+            which shares its inputs with the exported program.
+            This is used to investigate the discrepancies between the torch
+            model (through the exported program) and its onnx conversion.
+            This functionality dumps everything it can to disk
+            so that it can be replayed in a separate process.
+            """
+        ),
+    )
+    parser.add_argument(
+        "-i",
+        "--inputs",
+        type=str,
+        required=True,
+        help="model inputs saved with torch.save",
+    )
+    parser.add_argument(
+        "-e",
+        "--ep",
+        type=str,
+        required=True,
+        help=textwrap.dedent(
+            """
+            exported program saved with torch.export.save,
+            input sets saved with torch.save,
+            """
+        ),
+    )
+    parser.add_argument(
+        "-m",
+        "--onnx",
+        type=str,
+        required=True,
+        help="exported model in onnx format",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=True,
+        help="output name to store what the command line produces, "
+        "it should be an excel file",
+    )
+    parser.add_argument(
+        "--atol",
+        default=1e-5,
+        required=False,
+        help="absolute tolerance",
+    )
+    parser.add_argument(
+        "--rtol",
+        default=1e-5,
+        required=False,
+        help="relative tolerance",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=0,
+        required=False,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "-r",
+        "--ratio",
+        default=100,
+        required=False,
+        help="Saves the result in an excel file every <ratio> nodes, default is 100.",
+    )
+    parser.add_argument(
+        "--first",
+        action=BooleanOptionalAction,
+        default=False,
+        help="First runs the whole model (default is False).",
+    )
+    parser.add_argument(
+        "--sbs",
+        action=BooleanOptionalAction,
+        default=True,
+        help="Runs the side-by-side (default is True).",
+    )
+    parser.add_argument(
+        "-2",
+        "--second-run",
+        action=BooleanOptionalAction,
+        default=False,
+        help=textwrap.dedent(
+            """
+            Tries to run all onnx nodes with the torch results produced by the exported
+            program, then measures the discrepancies again. It can be used
+            to tell kernels introducing discrepancies from kernels just propagating them.
+            """
+        ),
+    )
+    parser.add_argument(
+        "--reset",
+        required=False,
+        default="",
+        help=textwrap.dedent(
+            """
+            List of result names separated by a comma. For those results,
+            the side-by-side takes the torch results instead of the onnx results
+            to compute the rest of the onnx model.
+            """
+        ),
+    )
+    parser.add_argument(
+        "-s",
+        "--replay-threshold",
+        type=float,
+        required=False,
+        default=1e18,
+        help="Triggers the replay if the discrepancies are higher than this value.",
+    )
+    parser.add_argument(
+        "-n",
+        "--replay-names",
+        required=False,
+        default="",
+        help="Triggers the replay if a result name is in this set of values (comma separated)",
+    )
+    parser.add_argument(
+        "-t",
+        "--replay-op-types",
+        required=False,
+        default="",
+        help="Triggers the replay if an onnx type is in this set of values (comma separated)",
+    )
+    parser.add_argument(
+        "-f",
+        "--replay-folder",
+        required=False,
+        default="replay",
+        help="If the replay is triggered, this defines the folder where everything is dumped.",
+    )
+    parser.add_argument(
+        "-p",
+        "--replay-prefix-model",
+        action=BooleanOptionalAction,
+        default=False,
+        help=textwrap.dedent(
+            """
+            There are two ways to recompute an intermediate output. The first one is to
+            produce the minimal model between torch and onnx.
+            The second one is to dump onnx models from the inputs
+            to the considered intermediate results. This enables the second one.
+            """
+        ),
+    )

+    return parser
+
+
+def _cmd_sbs(argv: List[Any]):
+    import pandas
+    import torch
+    from .helpers import flatten_object, max_diff, string_diff, string_type
+    from .torch_onnx.sbs import run_aligned
+    from .torch_onnx.sbs_dataclasses import ReplayConfiguration
+    from .reference import OnnxruntimeEvaluator
+
+    parser = get_parser_sbs()
+    args = parser.parse_args(argv[1:])
+
+    def _size(name):
+        s = os.stat(name).st_size
+        return f"{s / 2**20:1.3f} Mb"
+
+    print("-- side by side")
+    print(f"-- ep: {_size(args.ep)}: {args.ep}")
+    print(f"-- inputs: {_size(args.inputs)}: {args.inputs}")
+    print(f"-- onnx: {_size(args.onnx)}: {args.onnx}")
+    print(f"-- output: {args.output}")
+
+    print(f"-- load inputs {args.inputs!r}")
+    begin = time.perf_counter()
+    inputs = torch.load(args.inputs)
+    s = string_type(inputs, with_shape=True, with_device=True)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s - {s}")
+
+    if isinstance(inputs, dict) and len(inputs) == 2 and set(inputs) == {"args", "kwargs"}:
+        margs = inputs["args"]
+        mkwargs = inputs["kwargs"]
+    elif isinstance(inputs, tuple):
+        margs = inputs
+        mkwargs = {}
+    elif isinstance(inputs, dict):
+        margs = tuple()
+        mkwargs = inputs
+    else:
+        raise ValueError(
+            f"Unable to infer args, kwargs from inputs {string_type(inputs, with_shape=True)}"
+        )
+
+    print("-- import transformers.modeling_outputs to register serialization functions")
+    with contextlib.suppress(ImportError):
+        import transformers.modeling_outputs  # noqa: F401
+    print(f"-- load ep {args.ep!r}")
+    begin = time.perf_counter()
+    # We need to load the plugs.
+    from .torch_export_patches.patches.patch_transformers import get_transformers_plugs
+
+    plugs = get_transformers_plugs()
+    assert plugs, "Missing PLUGS for Qwen2.5"
+    ep = torch.export.load(args.ep)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+    if args.first:
+        print("-- compare first, run ep")
+        print(f"-- args: {string_type(margs, with_shape=True, with_device=True)}")
+        print(f"-- mkwargs: {string_type(mkwargs, with_shape=True, with_device=True)}")
+        expected = ep.module()(*margs, **mkwargs)
+        print(f"-- expected: {string_type(expected, with_shape=True, with_device=True)}")
+        sess = OnnxruntimeEvaluator(args.onnx, whole=True)
+        onx_inputs = flatten_object([margs, mkwargs], drop_keys=True)
+        feeds = dict(zip(sess.input_names, onx_inputs))
+        print(f"-- feeds: {string_type(feeds, with_shape=True, with_device=True)}")
+        got = sess.run(None, feeds)
+        print(f"-- got: {string_type(got, with_shape=True, with_device=True)}")
+        diff = max_diff(expected, got, hist=[0.1])
+        print(f"-- diff: {string_diff(diff)}")
+        print("-- done")
+        del sess
+
+    if not args.sbs:
+        print("-- done")
+        return
+
+    print(f"-- load onnx {args.onnx!r}")
+    begin = time.perf_counter()
+    onx = onnx.load(args.onnx)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+    replay_configuration = None
+    if args.replay_threshold < 1e6 or args.replay_names or args.replay_op_types:
+        replay_configuration = ReplayConfiguration(
+            threshold=args.replay_threshold,
+            selected_names=set(args.replay_names.split(",")) if args.replay_names else None,
+            selected_op_types=(
+                set(args.replay_op_types.split(",")) if args.replay_op_types else None
+            ),
+            dump_folder=args.replay_folder,
+            dump_prefix_model=args.replay_prefix_model,
+        )
+
+    print("-- starts side-by-side")
+    ratio = int(args.ratio)
+    data = []
+    for obs in run_aligned(
+        ep,
+        onx,
+        run_cls=OnnxruntimeEvaluator,  # type: ignore[arg-type]
+        atol=float(args.atol),
+        rtol=float(args.rtol),
+        verbose=int(args.verbose),
+        args=margs,
+        kwargs=mkwargs,
+        use_tensor=True,
+        reset_names=args.reset.split(","),
+        exc=False,
+        replay_configuration=replay_configuration,
+        run_onnx_with_torch_inputs=args.second_run,
+    ):
+        data.append(obs)
+        if (
+            obs.onnx_op_type != "initializer"
+            and obs.ep_target != "placeholder"
+            and len(data) % ratio == 0
+        ):
+            df = pandas.DataFrame(data).apply(
+                lambda col: col.fillna("") if col.dtype == "object" else col
+            )
+            df.to_excel(args.output)
+    print(f"-- final saves into {args.output!r}")
+    df = (
+        pandas.DataFrame(data)
+        .apply(lambda col: col.fillna("") if col.dtype == "object" else col)
+        .dropna(axis=1, how="all")
+    )
+    df.to_excel(args.output, index=False)
+    print("-- done")
+
+
 def get_main_parser() -> ArgumentParser:
     parser = ArgumentParser(
         prog="onnx_diagnostic",
@@ -1116,10 +1508,12 @@ def get_main_parser() -> ArgumentParser:

     agg - aggregates statistics from multiple files
     config - prints a configuration for a model id
+    dot - converts an onnx model into dot format
     exportsample - produces a code to export a model
     find - find node consuming or producing a result
     lighten - makes an onnx model lighter by removing the weights,
     print - prints the model on standard output
+    sbs - compares an exported program and an onnx model
     stats - produces statistics on a model
     unlighten - restores an onnx model produced by the previous experiment
     validate - validate a model
@@ -1131,10 +1525,12 @@ def get_main_parser() -> ArgumentParser:
         choices=[
             "agg",
             "config",
+            "dot",
             "exportsample",
             "find",
             "lighten",
             "print",
+            "sbs",
             "stats",
             "unlighten",
             "validate",
@@ -1146,15 +1542,17 @@

 def main(argv: Optional[List[Any]] = None):
     fcts = dict(
+        agg=_cmd_agg,
+        config=_cmd_config,
+        dot=_cmd_dot,
+        exportsample=_cmd_export_sample,
+        find=_cmd_find,
         lighten=_cmd_lighten,
-        unlighten=_cmd_unlighten,
         print=_cmd_print,
-        find=_cmd_find,
-        config=_cmd_config,
-        validate=_cmd_validate,
+        sbs=_cmd_sbs,
         stats=_cmd_stats,
-        agg=_cmd_agg,
-        exportsample=_cmd_export_sample,
+        unlighten=_cmd_unlighten,
+        validate=_cmd_validate,
     )

     if argv is None:
@@ -1169,15 +1567,17 @@ def main(argv: Optional[List[Any]] = None):
         parser.parse_args(argv)
     else:
         parsers = dict(
+            agg=get_parser_agg,
+            config=get_parser_config,
+            dot=get_parser_dot,
+            exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
+            find=get_parser_find,
             lighten=get_parser_lighten,
-            unlighten=get_parser_unlighten,
             print=get_parser_print,
-            find=get_parser_find,
-            config=get_parser_config,
-            validate=get_parser_validate,
+            sbs=get_parser_sbs,
             stats=get_parser_stats,
-            agg=get_parser_agg,
-            exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
+            unlighten=get_parser_unlighten,
+            validate=get_parser_validate,
         )
         cmd = argv[0]
         if cmd not in parsers:
onnx_diagnostic/export/api.py
@@ -1,5 +1,52 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import torch
+from .onnx_plug import EagerDirectReplacementWithOnnx
+
+
+def get_main_dispatcher(
+    use_control_flow_dispatcher: bool = False,
+    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+) -> Any:  # Dispatcher
+    """Creates a custom dispatcher for the custom exporter."""
+    from experimental_experiment.torch_interpreter import Dispatcher
+
+    if use_control_flow_dispatcher:
+        from .control_flow_onnx import create_global_dispatcher
+
+        control_flow_dispatcher = create_global_dispatcher()
+    else:
+        control_flow_dispatcher = None
+
+    class MainDispatcher(Dispatcher):
+        def __init__(self, previous_dispatcher=None):
+            super().__init__({})
+            self.previous_dispatcher = previous_dispatcher
+
+        @property
+        def supported(self):
+            if self.previous_dispatcher:
+                return set(self.registered_functions) | self.previous_dispatcher.supported
+            return set(self.registered_functions)
+
+        def find_function(self, name: Any):
+            if self.previous_dispatcher:
+                find = self.previous_dispatcher.find_function(name)
+                if find:
+                    return find
+            return Dispatcher.find_function(self, name)
+
+        def find_method(self, name: Any):
+            if self.previous_dispatcher:
+                find = self.previous_dispatcher.find_method(name)
+                if find:
+                    return find
+            return Dispatcher.find_method(self, name)
+
+    main_dispatcher = MainDispatcher(control_flow_dispatcher)
+    if onnx_plugs:
+        for plug in onnx_plugs:
+            main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
+    return main_dispatcher


 def to_onnx(
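MainDispatcher chains lookups: find_function and find_method consult the wrapped control-flow dispatcher first and fall back to the converters registered for the plugs. A short sketch of the resulting behaviour (assumes experimental_experiment is installed and plugs is a list of EagerDirectReplacementWithOnnx instances):

    from onnx_diagnostic.export.api import get_main_dispatcher

    dispatcher = get_main_dispatcher(use_control_flow_dispatcher=True, onnx_plugs=plugs)
    # every plug converter is registered under its target name
    # and reported through the supported property
    assert all(p.target_name in dispatcher.supported for p in plugs)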
@@ -18,6 +65,8 @@ def to_onnx(
     save_ep: Optional[str] = None,
     optimize: bool = True,
     use_control_flow_dispatcher: bool = False,
+    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+    inline: bool = True,
 ) -> Any:
     """
     Common API for exporters. By default, the models are optimized to use the
@@ -40,7 +89,9 @@ def to_onnx(
     :param save_ep: saves the exported program
     :param optimize: optimizes the model
     :param use_control_flow_dispatcher: use the dispatcher created to support
-        custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
+        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
+    :param onnx_plugs: plugs for parts of the code replaced with an onnx translation
+    :param inline: inline local functions
     :return: the output of the selected exporter, usually a structure including
         an onnx model

@@ -55,7 +106,16 @@ def to_onnx(
             exporter=exporter,
             filename=filename,
         )
+
+    Some examples using control flows are available in
+    :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
+    :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
     """
+    if exporter_kwargs and "inline" in exporter_kwargs:
+        assert (
+            inline == exporter_kwargs["inline"]
+        ), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
+        exporter_kwargs.pop("inline")
     if exporter == "custom":
         from experimental_experiment.torch_interpreter import (
             to_onnx as _to_onnx,
@@ -63,16 +123,16 @@ def to_onnx(
         )
         from experimental_experiment.xbuilder import OptimizationOptions

-        if use_control_flow_dispatcher:
-            from .control_flow import create_global_dispatcher
-
-            dispatcher = create_global_dispatcher()
-
         options = None
         if exporter_kwargs is not None:
             options = exporter_kwargs.pop("options", None)
         if options is None:
             options = OptimizationOptions(patterns="default+onnxruntime")
+        main_dispatcher = (
+            get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
+            if onnx_plugs or use_control_flow_dispatcher
+            else None
+        )

         return _to_onnx(
             mod,
@@ -88,15 +148,23 @@ def to_onnx(
             output_dynamic_shapes=output_dynamic_shapes,
             export_options=ExportOptions(save_ep=save_ep),
             options=options,
+            inline=inline,
+            dispatcher=main_dispatcher,
             **(exporter_kwargs or {}),
-            dispatcher=dispatcher if use_control_flow_dispatcher else None,
         )
+
     if exporter in ("dynamo", "onnx-dynamo"):
+        import os
+        from ..helpers import flatten_object
         import onnxscript.rewriter.ort_fusions as ort_fusions

         assert (
             not output_dynamic_shapes
         ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
+        custom_translation_table = {}
+        if onnx_plugs:
+            for plug in onnx_plugs:
+                custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
         epo = torch.onnx.export(
             mod,
             args=args or tuple(),
@@ -106,12 +174,47 @@ def to_onnx(
             opset_version=target_opset,
             dynamic_shapes=dynamic_shapes,
             dynamo=True,
+            verbose=verbose,
+            dump_exported_program=bool(save_ep),
+            artifacts_dir=os.path.dirname(filename) if filename else ".",
+            custom_translation_table=custom_translation_table,
             **(exporter_kwargs or {}),
         )
-        if optimize:
+        if not inline and optimize:
+            ort_fusions.optimize_for_ort(epo.model)
+
+        if onnx_plugs:
+            import onnx_ir as ir
+            import onnx_ir.passes.common as common_passes
+
+            opset = (
+                18
+                if target_opset is None
+                else (target_opset if isinstance(target_opset, int) else target_opset[""])
+            )
+
+            irfunctions = [
+                ir.from_proto(
+                    plug.get_function_proto(
+                        opset, *flatten_object((args, kwargs), drop_keys=True)
+                    )
+                )
+                for plug in onnx_plugs
+            ]
+            for func in irfunctions:
+                epo.model.functions[func.identifier()] = func
+            if inline:
+                common_passes.InlinePass()(epo.model)
+                common_passes.RemoveUnusedOpsetsPass()(epo.model)
+
+        if inline and optimize:
             ort_fusions.optimize_for_ort(epo.model)
         if filename:
             epo.save(filename, external_data=True)
+        if save_ep:
+            if isinstance(save_ep, tuple):
+                save_ep = save_ep[0]
+            torch.export.save(epo.exported_program, f"{save_ep}.pt2")
         return epo

     if exporter == "modelbuilder":
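A sketch of the extended to_onnx API showing the new onnx_plugs and inline arguments (MyModel and x are placeholders; the "custom" exporter additionally requires the experimental_experiment package):

    import torch
    from onnx_diagnostic.export.api import to_onnx

    model, x = MyModel(), torch.randn(2, 8)
    to_onnx(
        model,
        args=(x,),
        exporter="onnx-dynamo",
        filename="model.onnx",
        inline=True,         # inline local functions, new in 0.8.4
        onnx_plugs=None,     # or a list of EagerDirectReplacementWithOnnx
        save_ep="model.ep",  # also saves the exported program as model.ep.pt2
    )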