onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +91 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +3 -3
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +92 -23
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +90 -26
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +103 -1
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +103 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -1,14 +1,87 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
import contextlib
|
|
2
3
|
import json
|
|
3
4
|
import os
|
|
4
5
|
import re
|
|
5
6
|
import sys
|
|
6
7
|
import textwrap
|
|
8
|
+
import time
|
|
7
9
|
import onnx
|
|
8
10
|
from typing import Any, Dict, List, Optional, Union
|
|
9
11
|
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
|
|
10
12
|
|
|
11
13
|
|
|
14
|
+
def get_parser_dot() -> ArgumentParser:
    """Builds the parser for the ``dot`` command line.

    The command converts an onnx model into the dot language (graphviz).
    It can also call the ``dot`` executable to render the saved file.

    :return: the parser
    """
    parser = ArgumentParser(
        prog="dot",
        description=textwrap.dedent(
            """
            Converts an onnx model into a dot file which graphviz (dot) can render as a graph.
            """
        ),
    )
    parser.add_argument("input", type=str, help="onnx model to convert into dot")
    parser.add_argument(
        "-o",
        "--output",
        default="",
        type=str,
        required=False,
        help="dot model to output or empty to print out the result",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        type=int,
        default=0,
        required=False,
        help="verbosity",
    )
    parser.add_argument(
        "-r",
        "--run",
        default="",
        required=False,
        help="run dot, in that case, format must be given (svg, png), "
        "and --output must be specified",
    )
    return parser
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _cmd_dot(argv: List[Any]):
    """Implements the ``dot`` command line.

    Loads an onnx model without its external data and converts it with
    ``to_dot``. The result is printed, or written into ``--output`` when
    given. With ``--run <format>``, the ``dot`` executable renders the saved
    file into ``<output>.<format>``.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    import subprocess
    from .helpers.dot_helper import to_dot

    parser = get_parser_dot()
    args = parser.parse_args(argv[1:])
    if args.run and not args.output:
        # Checked before doing any work: dot needs a file to read from.
        parser.error("Cannot run dot without an output file.")
    if args.verbose:
        print(f"-- loads {args.input!r}")
    # The weights are not needed to draw the graph.
    onx = onnx.load(args.input, load_external_data=False)
    if args.verbose:
        print("-- converts into dot")
    dot = to_dot(onx)
    if args.output:
        if args.verbose:
            print(f"-- saves into {args.output}")
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(dot)
    else:
        print(dot)
    if args.run:
        cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
        if args.verbose:
            print(f"-- run {' '.join(cmds)}")
        # text mode so that messages are printed as strings and not as b'...'
        res = subprocess.run(cmds, capture_output=True, text=True, errors="replace", check=False)
        if res.stdout:
            print("--")
            print(res.stdout)
        if res.stderr:
            print("--")
            print(res.stderr)
        if res.returncode != 0:
            print(f"-- dot failed with return code {res.returncode}")
|
|
83
|
+
|
|
84
|
+
|
|
12
85
|
def get_parser_lighten() -> ArgumentParser:
|
|
13
86
|
parser = ArgumentParser(
|
|
14
87
|
prog="lighten",
|
|
@@ -624,6 +697,18 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
624
697
|
),
|
|
625
698
|
action=_ParseDict,
|
|
626
699
|
)
|
|
700
|
+
parser.add_argument(
|
|
701
|
+
"--save-ep",
|
|
702
|
+
default="",
|
|
703
|
+
help=textwrap.dedent(
|
|
704
|
+
"""
|
|
705
|
+
saves the exported program with torch.export.save
|
|
706
|
+
and the inputs sets with torch.save,
|
|
707
|
+
then command line sbs can be used to look for discrepancies.
|
|
708
|
+
"""
|
|
709
|
+
),
|
|
710
|
+
)
|
|
711
|
+
|
|
627
712
|
return parser
|
|
628
713
|
|
|
629
714
|
|
|
@@ -690,6 +775,7 @@ def _cmd_validate(argv: List[Any]):
|
|
|
690
775
|
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
|
|
691
776
|
),
|
|
692
777
|
exporter_options=args.expop,
|
|
778
|
+
save_ep=args.save_ep,
|
|
693
779
|
)
|
|
694
780
|
print("")
|
|
695
781
|
print("-- summary --")
|
|
@@ -1104,6 +1190,287 @@ def _cmd_agg(argv: List[Any]):
|
|
|
1104
1190
|
print(f"Wrote {args.output!r}")
|
|
1105
1191
|
|
|
1106
1192
|
|
|
1193
|
+
def get_parser_sbs() -> ArgumentParser:
    """Builds the parser for the ``sbs`` (side-by-side) command line.

    Numerical options (``--atol``, ``--rtol``, ``--verbose``, ``--ratio``)
    are converted by argparse so that the parsed values have the same type
    as their defaults.

    :return: the parser
    """
    parser = ArgumentParser(
        prog="side-by-side (sbs)",
        description=textwrap.dedent(
            """
            Compares the intermediate outputs between the exported program and
            the exported onnx model. It assumes some names are common.
            The execution of the exported program and the onnx model
            are done in parallel. The device is the one used to store the
            model and the inputs.
            Where do discrepancies start? This function tries to answer that question.
            """
        ),
        epilog=textwrap.dedent(
            """
            The command line expects the following files to be saved with
            the following functions. inputs is a dictionary of the inputs of the model.

            - torch.export.save(ep: torch.export.ExportedProgram, ...)
            - torch.save(inputs, ...)
            - onnx.save(...)

            The Replay functionality is just a way to investigate a part of a model.
            It saves torch and onnx inputs, the torch outputs, and the minimal onnx model
            which shares its inputs with the exported program.
            This is used to investigate the discrepancies between the torch
            model (through the exported program) and its onnx conversion.
            This functionality dumps everything it can to disk
            so that it can be replayed in a separate process.
            """
        ),
    )
    parser.add_argument(
        "-i",
        "--inputs",
        type=str,
        required=True,
        help="model inputs saved with torch.save",
    )
    parser.add_argument(
        "-e",
        "--ep",
        type=str,
        required=True,
        help="exported program saved with torch.export.save",
    )
    parser.add_argument(
        "-m",
        "--onnx",
        type=str,
        required=True,
        help="exported model in onnx format",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="output name to store what the command line produces, "
        "it should be an excel file",
    )
    parser.add_argument(
        "--atol",
        type=float,
        default=1e-5,
        required=False,
        help="absolute tolerance",
    )
    parser.add_argument(
        "--rtol",
        type=float,
        default=1e-5,
        required=False,
        help="relative tolerance",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        type=int,
        default=0,
        required=False,
        help="verbosity",
    )
    parser.add_argument(
        "-r",
        "--ratio",
        type=int,
        default=100,
        required=False,
        help="Saves the result in an excel file every <ratio> nodes, default is 100.",
    )
    parser.add_argument(
        "--first",
        action=BooleanOptionalAction,
        default=False,
        help="First runs the whole model.",
    )
    parser.add_argument(
        "-2",
        "--second-run",
        action=BooleanOptionalAction,
        default=False,
        help=textwrap.dedent(
            """
            Tries to run all onnx nodes with torch results produced by the exported
            program. It then measures the discrepancies again. It can be used
            to distinguish kernels introducing discrepancies from others only
            propagating them.
            """
        ),
    )
    parser.add_argument(
        "--reset",
        required=False,
        default="",
        help=textwrap.dedent(
            """
            List of result names separated by a comma. For those results,
            the side-by-side will take torch results instead of onnx results
            to compute the rest of the onnx model.
            """
        ),
    )
    parser.add_argument(
        "-s",
        "--replay-threshold",
        type=float,
        required=False,
        default=1e18,
        help="Triggers the replay if the discrepancies are higher than this value.",
    )
    parser.add_argument(
        "-n",
        "--replay-names",
        required=False,
        default="",
        help="Triggers the replay if a result name is in this set of values (comma separated)",
    )
    parser.add_argument(
        "-t",
        "--replay-op-types",
        required=False,
        default="",
        help="Triggers the replay if an onnx type is in this set of values (comma separated)",
    )
    parser.add_argument(
        "-f",
        "--replay-folder",
        required=False,
        default="replay",
        help="If the replay is triggered, this defines the folder where everything is dumped.",
    )

    return parser
|
|
1347
|
+
|
|
1348
|
+
|
|
1349
|
+
def _cmd_sbs(argv: List[Any]):
    """Implements the ``sbs`` (side-by-side) command line.

    Loads the saved inputs, the exported program and the onnx model. It then
    iterates over the observations produced by ``run_aligned`` and stores
    them into the excel file ``--output``. The file is refreshed every
    ``--ratio`` observations so that partial results can be inspected while
    the comparison is still running.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    import pandas
    import torch
    from .helpers import flatten_object, max_diff, string_diff, string_type
    from .torch_onnx.sbs import run_aligned
    from .torch_onnx.sbs_dataclasses import ReplayConfiguration
    from .reference import OnnxruntimeEvaluator

    parser = get_parser_sbs()
    args = parser.parse_args(argv[1:])

    def _size(name):
        # file size in megabytes, only used for logging
        s = os.stat(name).st_size
        return f"{s / 2**20:1.3f} Mb"

    def _split_names(value):
        # "a, b,,c" -> ["a", "b", "c"]; "" -> [] (and not [""])
        return [n.strip() for n in value.split(",") if n.strip()]

    def _to_frame(rows):
        # object columns get "" instead of NaN to keep the excel file readable,
        # columns which are entirely empty are dropped
        return (
            pandas.DataFrame(rows)
            .apply(lambda col: col.fillna("") if col.dtype == "object" else col)
            .dropna(axis=1, how="all")
        )

    print("-- side by side")
    print(f"-- ep: {_size(args.ep)}: {args.ep}")
    print(f"-- inputs: {_size(args.inputs)}: {args.inputs}")
    print(f"-- onnx: {_size(args.onnx)}: {args.onnx}")
    print(f"-- output: {args.output}")

    print(f"-- load inputs {args.inputs!r}")
    begin = time.perf_counter()
    # NOTE(review): torch.load unpickles the file, only use it on trusted files.
    inputs = torch.load(args.inputs)
    s = string_type(inputs, with_shape=True, with_device=True)
    print(f"-- done in {time.perf_counter() - begin:1.1f}s - {s}")

    if isinstance(inputs, dict) and set(inputs) == {"args", "kwargs"}:
        margs = inputs["args"]
        mkwargs = inputs["kwargs"]
    elif isinstance(inputs, tuple):
        margs = inputs
        mkwargs = {}
    elif isinstance(inputs, dict):
        margs = tuple()
        mkwargs = inputs
    else:
        raise ValueError(
            f"Unable to infer args, kwargs from inputs {string_type(inputs, with_shape=True)}"
        )

    print("-- import transformers.modeling_outputs to register serialization functions")
    with contextlib.suppress(ImportError):
        import transformers.modeling_outputs  # noqa: F401
    print(f"-- load ep {args.ep!r}")
    begin = time.perf_counter()
    # We need to load the plugs.
    from .torch_export_patches.patches.patch_transformers import get_transformers_plugs

    plugs = get_transformers_plugs()
    assert plugs, "Missing PLUGS for Qwen2.5"
    ep = torch.export.load(args.ep)
    print(f"-- done in {time.perf_counter() - begin:1.1f}s")

    if args.first:
        print("-- compare first, run ep")
        print(f"-- args: {string_type(margs, with_shape=True, with_device=True)}")
        print(f"-- mkwargs: {string_type(mkwargs, with_shape=True, with_device=True)}")
        expected = ep.module()(*margs, **mkwargs)
        print(f"-- expected: {string_type(expected, with_shape=True, with_device=True)}")
        sess = OnnxruntimeEvaluator(args.onnx, whole=True)
        onx_inputs = flatten_object([margs, mkwargs], drop_keys=True)
        feeds = dict(zip(sess.input_names, onx_inputs))
        print(f"-- feeds: {string_type(feeds, with_shape=True, with_device=True)}")
        got = sess.run(None, feeds)
        print(f"-- got: {string_type(got, with_shape=True, with_device=True)}")
        diff = max_diff(expected, got, hist=[0.1])
        print(f"-- diff: {string_diff(diff)}")
        print("-- done")
        del sess

    print(f"-- load onnx {args.onnx!r}")
    begin = time.perf_counter()
    onx = onnx.load(args.onnx)
    print(f"-- done in {time.perf_counter() - begin:1.1f}s")

    replay_names = _split_names(args.replay_names)
    replay_op_types = _split_names(args.replay_op_types)
    # The parser default means "no threshold", any lower value requests a replay.
    no_replay_threshold = parser.get_default("replay_threshold")
    replay_configuration = None
    if args.replay_threshold < no_replay_threshold or replay_names or replay_op_types:
        replay_configuration = ReplayConfiguration(
            threshold=args.replay_threshold,
            selected_names=set(replay_names) if replay_names else None,
            selected_op_types=set(replay_op_types) if replay_op_types else None,
            dump_folder=args.replay_folder,
        )

    print("-- starts side-by-side")
    ratio = int(args.ratio)
    data = []
    for obs in run_aligned(
        ep,
        onx,
        run_cls=OnnxruntimeEvaluator,  # type: ignore[arg-type]
        atol=float(args.atol),
        rtol=float(args.rtol),
        verbose=int(args.verbose),
        args=margs,
        kwargs=mkwargs,
        use_tensor=True,
        reset_names=_split_names(args.reset),
        exc=False,
        replay_configuration=replay_configuration,
        run_onnx_with_torch_inputs=args.second_run,
    ):
        data.append(obs)
        if (
            ratio > 0
            and obs.onnx_op_type != "initializer"
            and obs.ep_target != "placeholder"
            and len(data) % ratio == 0
        ):
            # intermediate save, same layout as the final one
            _to_frame(data).to_excel(args.output, index=False)
    print(f"-- final saves into {args.output!r}")
    _to_frame(data).to_excel(args.output, index=False)
    print("-- done")
|
|
1472
|
+
|
|
1473
|
+
|
|
1107
1474
|
def get_main_parser() -> ArgumentParser:
|
|
1108
1475
|
parser = ArgumentParser(
|
|
1109
1476
|
prog="onnx_diagnostic",
|
|
@@ -1116,10 +1483,12 @@ def get_main_parser() -> ArgumentParser:
|
|
|
1116
1483
|
|
|
1117
1484
|
agg - aggregates statistics from multiple files
|
|
1118
1485
|
config - prints a configuration for a model id
|
|
1486
|
+
dot - converts an onnx model into dot format
|
|
1119
1487
|
exportsample - produces a code to export a model
|
|
1120
1488
|
find - find node consuming or producing a result
|
|
1121
1489
|
lighten - makes an onnx model lighter by removing the weights,
|
|
1122
1490
|
print - prints the model on standard output
|
|
1491
|
+
sbs - compares an exported program and a onnx model
|
|
1123
1492
|
stats - produces statistics on a model
|
|
1124
1493
|
unlighten - restores an onnx model produces by the previous experiment
|
|
1125
1494
|
validate - validate a model
|
|
@@ -1131,10 +1500,12 @@ def get_main_parser() -> ArgumentParser:
|
|
|
1131
1500
|
choices=[
|
|
1132
1501
|
"agg",
|
|
1133
1502
|
"config",
|
|
1503
|
+
"dot",
|
|
1134
1504
|
"exportsample",
|
|
1135
1505
|
"find",
|
|
1136
1506
|
"lighten",
|
|
1137
1507
|
"print",
|
|
1508
|
+
"sbs",
|
|
1138
1509
|
"stats",
|
|
1139
1510
|
"unlighten",
|
|
1140
1511
|
"validate",
|
|
@@ -1146,15 +1517,17 @@ def get_main_parser() -> ArgumentParser:
|
|
|
1146
1517
|
|
|
1147
1518
|
def main(argv: Optional[List[Any]] = None):
|
|
1148
1519
|
fcts = dict(
|
|
1520
|
+
agg=_cmd_agg,
|
|
1521
|
+
config=_cmd_config,
|
|
1522
|
+
dot=_cmd_dot,
|
|
1523
|
+
exportsample=_cmd_export_sample,
|
|
1524
|
+
find=_cmd_find,
|
|
1149
1525
|
lighten=_cmd_lighten,
|
|
1150
|
-
unlighten=_cmd_unlighten,
|
|
1151
1526
|
print=_cmd_print,
|
|
1152
|
-
|
|
1153
|
-
config=_cmd_config,
|
|
1154
|
-
validate=_cmd_validate,
|
|
1527
|
+
sbs=_cmd_sbs,
|
|
1155
1528
|
stats=_cmd_stats,
|
|
1156
|
-
|
|
1157
|
-
|
|
1529
|
+
unlighten=_cmd_unlighten,
|
|
1530
|
+
validate=_cmd_validate,
|
|
1158
1531
|
)
|
|
1159
1532
|
|
|
1160
1533
|
if argv is None:
|
|
@@ -1169,15 +1542,17 @@ def main(argv: Optional[List[Any]] = None):
|
|
|
1169
1542
|
parser.parse_args(argv)
|
|
1170
1543
|
else:
|
|
1171
1544
|
parsers = dict(
|
|
1545
|
+
agg=get_parser_agg,
|
|
1546
|
+
config=get_parser_config,
|
|
1547
|
+
dot=get_parser_dot,
|
|
1548
|
+
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
|
|
1549
|
+
find=get_parser_find,
|
|
1172
1550
|
lighten=get_parser_lighten,
|
|
1173
|
-
unlighten=get_parser_unlighten,
|
|
1174
1551
|
print=get_parser_print,
|
|
1175
|
-
|
|
1176
|
-
config=get_parser_config,
|
|
1177
|
-
validate=get_parser_validate,
|
|
1552
|
+
sbs=get_parser_sbs,
|
|
1178
1553
|
stats=get_parser_stats,
|
|
1179
|
-
|
|
1180
|
-
|
|
1554
|
+
unlighten=get_parser_unlighten,
|
|
1555
|
+
validate=get_parser_validate,
|
|
1181
1556
|
)
|
|
1182
1557
|
cmd = argv[0]
|
|
1183
1558
|
if cmd not in parsers:
|
onnx_diagnostic/export/api.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
2
2
|
import torch
|
|
3
|
+
from .onnx_plug import EagerDirectReplacementWithOnnx
|
|
3
4
|
|
|
4
5
|
|
|
5
6
|
def to_onnx(
|
|
@@ -18,6 +19,8 @@ def to_onnx(
|
|
|
18
19
|
save_ep: Optional[str] = None,
|
|
19
20
|
optimize: bool = True,
|
|
20
21
|
use_control_flow_dispatcher: bool = False,
|
|
22
|
+
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
|
|
23
|
+
inline: bool = True,
|
|
21
24
|
) -> Any:
|
|
22
25
|
"""
|
|
23
26
|
Common API for exporters. By default, the models are optimized to use the
|
|
@@ -40,7 +43,9 @@ def to_onnx(
|
|
|
40
43
|
:param save_ep: saves the exported program
|
|
41
44
|
:param optimize: optimizes the model
|
|
42
45
|
:param use_control_flow_dispatcher: use the dispatcher created to supported
|
|
43
|
-
custom loops (see :func:`onnx_diagnostic.export.
|
|
46
|
+
custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
|
|
47
|
+
:param onnx_plugs: the code was modified to replace some parts with onnx translation
|
|
48
|
+
:param inline: inline local functions
|
|
44
49
|
:return: the output of the selected exporter, usually a structure including
|
|
45
50
|
an onnx model
|
|
46
51
|
|
|
@@ -55,7 +60,16 @@ def to_onnx(
|
|
|
55
60
|
exporter=exporter,
|
|
56
61
|
filename=filename,
|
|
57
62
|
)
|
|
63
|
+
|
|
64
|
+
Some examples using control flows are available in
|
|
65
|
+
:func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
|
|
66
|
+
:class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
|
|
58
67
|
"""
|
|
68
|
+
if exporter_kwargs and "inline" in exporter_kwargs:
|
|
69
|
+
assert (
|
|
70
|
+
inline == exporter_kwargs["inline"]
|
|
71
|
+
), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
|
|
72
|
+
exporter_kwargs.pop("inline")
|
|
59
73
|
if exporter == "custom":
|
|
60
74
|
from experimental_experiment.torch_interpreter import (
|
|
61
75
|
to_onnx as _to_onnx,
|
|
@@ -63,16 +77,56 @@ def to_onnx(
|
|
|
63
77
|
)
|
|
64
78
|
from experimental_experiment.xbuilder import OptimizationOptions
|
|
65
79
|
|
|
66
|
-
if use_control_flow_dispatcher:
|
|
67
|
-
from .control_flow import create_global_dispatcher
|
|
68
|
-
|
|
69
|
-
dispatcher = create_global_dispatcher()
|
|
70
|
-
|
|
71
80
|
options = None
|
|
72
81
|
if exporter_kwargs is not None:
|
|
73
82
|
options = exporter_kwargs.pop("options", None)
|
|
74
83
|
if options is None:
|
|
75
84
|
options = OptimizationOptions(patterns="default+onnxruntime")
|
|
85
|
+
if onnx_plugs or use_control_flow_dispatcher:
|
|
86
|
+
from experimental_experiment.torch_interpreter import Dispatcher
|
|
87
|
+
|
|
88
|
+
if use_control_flow_dispatcher:
|
|
89
|
+
from .control_flow_onnx import create_global_dispatcher
|
|
90
|
+
|
|
91
|
+
control_flow_dispatcher = create_global_dispatcher()
|
|
92
|
+
else:
|
|
93
|
+
control_flow_dispatcher = None
|
|
94
|
+
|
|
95
|
+
class MainDispatcher(Dispatcher):
|
|
96
|
+
def __init__(self, previous_dispatcher=None):
|
|
97
|
+
super().__init__({})
|
|
98
|
+
self.previous_dispatcher = previous_dispatcher
|
|
99
|
+
|
|
100
|
+
@property
|
|
101
|
+
def supported(self):
|
|
102
|
+
if self.previous_dispatcher:
|
|
103
|
+
return (
|
|
104
|
+
set(self.registered_functions) | self.previous_dispatcher.supported
|
|
105
|
+
)
|
|
106
|
+
return set(self.registered_functions)
|
|
107
|
+
|
|
108
|
+
def find_function(self, name: Any):
|
|
109
|
+
if self.previous_dispatcher:
|
|
110
|
+
find = self.previous_dispatcher.find_function(name)
|
|
111
|
+
if find:
|
|
112
|
+
return find
|
|
113
|
+
return Dispatcher.find_function(self, name)
|
|
114
|
+
|
|
115
|
+
def find_method(self, name: Any):
|
|
116
|
+
if self.previous_dispatcher:
|
|
117
|
+
find = self.previous_dispatcher.find_method(name)
|
|
118
|
+
if find:
|
|
119
|
+
return find
|
|
120
|
+
return Dispatcher.find_method(self, name)
|
|
121
|
+
|
|
122
|
+
main_dispatcher = MainDispatcher(control_flow_dispatcher)
|
|
123
|
+
if onnx_plugs:
|
|
124
|
+
for plug in onnx_plugs:
|
|
125
|
+
main_dispatcher.registered_functions[plug.target_name] = (
|
|
126
|
+
plug.custom_converter()
|
|
127
|
+
)
|
|
128
|
+
else:
|
|
129
|
+
main_dispatcher = None
|
|
76
130
|
|
|
77
131
|
return _to_onnx(
|
|
78
132
|
mod,
|
|
@@ -88,15 +142,22 @@ def to_onnx(
|
|
|
88
142
|
output_dynamic_shapes=output_dynamic_shapes,
|
|
89
143
|
export_options=ExportOptions(save_ep=save_ep),
|
|
90
144
|
options=options,
|
|
145
|
+
inline=inline,
|
|
146
|
+
dispatcher=main_dispatcher,
|
|
91
147
|
**(exporter_kwargs or {}),
|
|
92
|
-
dispatcher=dispatcher if use_control_flow_dispatcher else None,
|
|
93
148
|
)
|
|
149
|
+
|
|
94
150
|
if exporter in ("dynamo", "onnx-dynamo"):
|
|
151
|
+
import os
|
|
95
152
|
import onnxscript.rewriter.ort_fusions as ort_fusions
|
|
96
153
|
|
|
97
154
|
assert (
|
|
98
155
|
not output_dynamic_shapes
|
|
99
156
|
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
|
|
157
|
+
custom_translation_table = {}
|
|
158
|
+
if onnx_plugs:
|
|
159
|
+
for plug in onnx_plugs:
|
|
160
|
+
custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
|
|
100
161
|
epo = torch.onnx.export(
|
|
101
162
|
mod,
|
|
102
163
|
args=args or tuple(),
|
|
@@ -106,12 +167,34 @@ def to_onnx(
|
|
|
106
167
|
opset_version=target_opset,
|
|
107
168
|
dynamic_shapes=dynamic_shapes,
|
|
108
169
|
dynamo=True,
|
|
170
|
+
verbose=verbose,
|
|
171
|
+
dump_exported_program=bool(save_ep),
|
|
172
|
+
artifacts_dir=os.path.dirname(filename) if filename else ".",
|
|
173
|
+
custom_translation_table=custom_translation_table,
|
|
109
174
|
**(exporter_kwargs or {}),
|
|
110
175
|
)
|
|
111
|
-
if optimize:
|
|
176
|
+
if not inline and optimize:
|
|
177
|
+
ort_fusions.optimize_for_ort(epo.model)
|
|
178
|
+
|
|
179
|
+
if onnx_plugs:
|
|
180
|
+
import onnx_ir as ir
|
|
181
|
+
import onnx_ir.passes.common as common_passes
|
|
182
|
+
|
|
183
|
+
irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
|
|
184
|
+
for func in irfunctions:
|
|
185
|
+
epo.model.functions[func.identifier()] = func
|
|
186
|
+
if inline:
|
|
187
|
+
common_passes.InlinePass()(epo.model)
|
|
188
|
+
common_passes.RemoveUnusedOpsetsPass()(epo.model)
|
|
189
|
+
|
|
190
|
+
if inline and optimize:
|
|
112
191
|
ort_fusions.optimize_for_ort(epo.model)
|
|
113
192
|
if filename:
|
|
114
193
|
epo.save(filename, external_data=True)
|
|
194
|
+
if save_ep:
|
|
195
|
+
if isinstance(save_ep, tuple):
|
|
196
|
+
save_ep = save_ep[0]
|
|
197
|
+
torch.export.save(epo.exported_program, f"{save_ep}.pt2")
|
|
115
198
|
return epo
|
|
116
199
|
|
|
117
200
|
if exporter == "modelbuilder":
|