onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/_command_lines_parser.py
CHANGED
@@ -1,14 +1,87 @@
 import argparse
+import contextlib
 import json
 import os
 import re
 import sys
 import textwrap
+import time
 import onnx
 from typing import Any, Dict, List, Optional, Union
 from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction


+def get_parser_dot() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="dot",
+        description=textwrap.dedent(
+            """
+            Converts a model into a dot file that dot can render as a graph.
+            """
+        ),
+    )
+    parser.add_argument("input", type=str, help="onnx model to convert")
+    parser.add_argument(
+        "-o",
+        "--output",
+        default="",
+        type=str,
+        required=False,
+        help="dot file to output or empty to print out the result",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        type=int,
+        default=0,
+        required=False,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "-r",
+        "--run",
+        default="",
+        required=False,
+        help="run dot, in that case, format must be given (svg, png)",
+    )
+    return parser
+
+
+def _cmd_dot(argv: List[Any]):
+    import subprocess
+    from .helpers.dot_helper import to_dot
+
+    parser = get_parser_dot()
+    args = parser.parse_args(argv[1:])
+    if args.verbose:
+        print(f"-- loads {args.input!r}")
+    onx = onnx.load(args.input, load_external_data=False)
+    if args.verbose:
+        print("-- converts into dot")
+    dot = to_dot(onx)
+    if args.output:
+        if args.verbose:
+            print(f"-- saves into {args.output}")
+        with open(args.output, "w") as f:
+            f.write(dot)
+    else:
+        print(dot)
+    if args.run:
+        assert args.output, "Cannot run dot without an output file."
+        cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
+        if args.verbose:
+            print(f"-- run {' '.join(cmds)}")
+        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        res = p.communicate()
+        out, err = res
+        if out:
+            print("--")
+            print(out)
+        if err:
+            print("--")
+            print(err)
+
+
 def get_parser_lighten() -> ArgumentParser:
     parser = ArgumentParser(
         prog="lighten",
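A usage sketch (not part of the diff): the new dot command is reachable through main(), which dispatches on argv[0]; the file names are placeholders and graphviz's dot executable is assumed to be installed.

    from onnx_diagnostic._command_lines_parser import main

    # converts model.onnx into DOT, writes model.dot,
    # then renders model.dot.svg with graphviz
    main(["dot", "model.onnx", "-o", "model.dot", "-r", "svg", "-v", "1"])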
@@ -624,6 +697,18 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
         ),
         action=_ParseDict,
     )
+    parser.add_argument(
+        "--save-ep",
+        default="",
+        help=textwrap.dedent(
+            """
+            saves the exported program with torch.export.save
+            and the input sets with torch.save,
+            then the command line sbs can be used to look for discrepancies.
+            """
+        ),
+    )
+
     return parser

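A sketch (not part of the diff) of the artifacts --save-ep refers to, following the help text above; model, args, kwargs and onnx_model are assumed to be defined beforehand:

    import onnx
    import torch

    ep = torch.export.export(model, args, kwargs)
    torch.export.save(ep, "model.ep.pt2")                            # exported program
    torch.save({"args": args, "kwargs": kwargs}, "model.inputs.pt")  # input sets
    onnx.save(onnx_model, "model.onnx")                              # onnx conversion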
@@ -690,6 +775,7 @@ def _cmd_validate(argv: List[Any]):
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
         exporter_options=args.expop,
+        save_ep=args.save_ep,
     )
     print("")
     print("-- summary --")
@@ -1104,6 +1190,312 @@ def _cmd_agg(argv: List[Any]):
     print(f"Wrote {args.output!r}")


+def get_parser_sbs() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="side-by-side (sbs)",
+        description=textwrap.dedent(
+            """
+            Compares the intermediate outputs between the exported program and
+            the exported onnx model. It assumes some names are common.
+            The execution of the exported program and the onnx model
+            are done in parallel. The device is the one used to store the
+            model and the inputs.
+            Where do discrepancies start? This function tries to answer that question.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            The command line expects the following files to be saved with
+            the following functions. inputs is a dictionary of the inputs of the model.
+
+            - torch.export.save(ep: torch.export.ExportedProgram)
+            - torch.save(**inputs)
+            - onnx.save(...)
+
+            The Replay functionality is just a way to investigate a part of a model.
+            It saves torch and onnx inputs, the torch outputs, and the minimal onnx model
+            which shares its inputs with the exported program.
+            This is used to investigate the discrepancies between the torch
+            model (through the exported program) and its onnx conversion.
+            This functionality dumps everything it can to disk
+            so that it can be replayed in a separate process.
+            """
+        ),
+    )
+    parser.add_argument(
+        "-i",
+        "--inputs",
+        type=str,
+        required=True,
+        help="model inputs saved with torch.save",
+    )
+    parser.add_argument(
+        "-e",
+        "--ep",
+        type=str,
+        required=True,
+        help=textwrap.dedent(
+            """
+            exported program saved with torch.export.save,
+            input sets saved with torch.save,
+            """
+        ),
+    )
+    parser.add_argument(
+        "-m",
+        "--onnx",
+        type=str,
+        required=True,
+        help="exported model in onnx format",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=True,
+        help="output name to store what the command line produces, "
+        "it should be an excel file",
+    )
+    parser.add_argument(
+        "--atol",
+        default=1e-5,
+        required=False,
+        help="absolute tolerance",
+    )
+    parser.add_argument(
+        "--rtol",
+        default=1e-5,
+        required=False,
+        help="relative tolerance",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=0,
+        required=False,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "-r",
+        "--ratio",
+        default=100,
+        required=False,
+        help="Saves the result in an excel file every <ratio> nodes, default is 100.",
+    )
+    parser.add_argument(
+        "--first",
+        action=BooleanOptionalAction,
+        default=False,
+        help="First runs the whole model (default is False).",
+    )
+    parser.add_argument(
+        "--sbs",
+        action=BooleanOptionalAction,
+        default=True,
+        help="Runs the side-by-side (default is True).",
+    )
+    parser.add_argument(
+        "-2",
+        "--second-run",
+        action=BooleanOptionalAction,
+        default=False,
+        help=textwrap.dedent(
+            """
+            Tries to run all onnx nodes with torch results produced by the exported
+            program. It then measures the discrepancies again. It can be used
+            to identify kernels that introduce discrepancies from those just propagating them.
+            """
+        ),
+    )
+    parser.add_argument(
+        "--reset",
+        required=False,
+        default="",
+        help=textwrap.dedent(
+            """
+            List of result names separated by a comma. For those results,
+            the side-by-side will take torch results instead of onnx results
+            to compute the rest of the onnx model.
+            """
+        ),
+    )
+    parser.add_argument(
+        "-s",
+        "--replay-threshold",
+        type=float,
+        required=False,
+        default=1e18,
+        help="Triggers the replay if the discrepancies are higher than this value.",
+    )
+    parser.add_argument(
+        "-n",
+        "--replay-names",
+        required=False,
+        default="",
+        help="Triggers the replay if a result name is in this set of values (comma separated)",
+    )
+    parser.add_argument(
+        "-t",
+        "--replay-op-types",
+        required=False,
+        default="",
+        help="Triggers the replay if an onnx type is in this set of values (comma separated)",
+    )
+    parser.add_argument(
+        "-f",
+        "--replay-folder",
+        required=False,
+        default="replay",
+        help="If the replay is triggered, this defines the folder where everything is dumped.",
+    )
+    parser.add_argument(
+        "-p",
+        "--replay-prefix-model",
+        action=BooleanOptionalAction,
+        default=False,
+        help=textwrap.dedent(
+            """
+            There are two ways to recompute an intermediate output, the first one is to
+            produce the minimal model between torch and onnx.
+            The second one is to dump onnx models from the inputs
+            to the considered intermediate results. This enables the second one.
+            """
+        ),
+    )
+
+    return parser
+
+
+def _cmd_sbs(argv: List[Any]):
+    import pandas
+    import torch
+    from .helpers import flatten_object, max_diff, string_diff, string_type
+    from .torch_onnx.sbs import run_aligned
+    from .torch_onnx.sbs_dataclasses import ReplayConfiguration
+    from .reference import OnnxruntimeEvaluator
+
+    parser = get_parser_sbs()
+    args = parser.parse_args(argv[1:])
+
+    def _size(name):
+        s = os.stat(name).st_size
+        return f"{s / 2**20:1.3f} Mb"
+
+    print("-- side by side")
+    print(f"-- ep: {_size(args.ep)}: {args.ep}")
+    print(f"-- inputs: {_size(args.inputs)}: {args.inputs}")
+    print(f"-- onnx: {_size(args.onnx)}: {args.onnx}")
+    print(f"-- output: {args.output}")
+
+    print(f"-- load inputs {args.inputs!r}")
+    begin = time.perf_counter()
+    inputs = torch.load(args.inputs)
+    s = string_type(inputs, with_shape=True, with_device=True)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s - {s}")
+
+    if isinstance(inputs, dict) and len(inputs) == 2 and set(inputs) == {"args", "kwargs"}:
+        margs = inputs["args"]
+        mkwargs = inputs["kwargs"]
+    elif isinstance(inputs, tuple):
+        margs = inputs
+        mkwargs = {}
+    elif isinstance(inputs, dict):
+        margs = tuple()
+        mkwargs = inputs
+    else:
+        raise ValueError(
+            f"Unable to infer args, kwargs from inputs {string_type(inputs, with_shape=True)}"
+        )
+
+    print("-- import transformers.modeling_outputs to register serialization functions")
+    with contextlib.suppress(ImportError):
+        import transformers.modeling_outputs  # noqa: F401
+    print(f"-- load ep {args.ep!r}")
+    begin = time.perf_counter()
+    # We need to load the plugs.
+    from .torch_export_patches.patches.patch_transformers import get_transformers_plugs
+
+    plugs = get_transformers_plugs()
+    assert plugs, "Missing PLUGS for Qwen2.5"
+    ep = torch.export.load(args.ep)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+    if args.first:
+        print("-- compare first, run ep")
+        print(f"-- args: {string_type(margs, with_shape=True, with_device=True)}")
+        print(f"-- mkwargs: {string_type(mkwargs, with_shape=True, with_device=True)}")
+        expected = ep.module()(*margs, **mkwargs)
+        print(f"-- expected: {string_type(expected, with_shape=True, with_device=True)}")
+        sess = OnnxruntimeEvaluator(args.onnx, whole=True)
+        onx_inputs = flatten_object([margs, mkwargs], drop_keys=True)
+        feeds = dict(zip(sess.input_names, onx_inputs))
+        print(f"-- feeds: {string_type(feeds, with_shape=True, with_device=True)}")
+        got = sess.run(None, feeds)
+        print(f"-- got: {string_type(got, with_shape=True, with_device=True)}")
+        diff = max_diff(expected, got, hist=[0.1])
+        print(f"-- diff: {string_diff(diff)}")
+        print("-- done")
+        del sess
+
+    if not args.sbs:
+        print("-- done")
+        return
+
+    print(f"-- load onnx {args.onnx!r}")
+    begin = time.perf_counter()
+    onx = onnx.load(args.onnx)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+    replay_configuration = None
+    if args.replay_threshold < 1e6 or args.replay_names or args.replay_op_types:
+        replay_configuration = ReplayConfiguration(
+            threshold=args.replay_threshold,
+            selected_names=set(args.replay_names.split(",")) if args.replay_names else None,
+            selected_op_types=(
+                set(args.replay_op_types.split(",")) if args.replay_op_types else None
+            ),
+            dump_folder=args.replay_folder,
+            dump_prefix_model=args.replay_prefix_model,
+        )
+
+    print("-- starts side-by-side")
+    ratio = int(args.ratio)
+    data = []
+    for obs in run_aligned(
+        ep,
+        onx,
+        run_cls=OnnxruntimeEvaluator,  # type: ignore[arg-type]
+        atol=float(args.atol),
+        rtol=float(args.rtol),
+        verbose=int(args.verbose),
+        args=margs,
+        kwargs=mkwargs,
+        use_tensor=True,
+        reset_names=args.reset.split(","),
+        exc=False,
+        replay_configuration=replay_configuration,
+        run_onnx_with_torch_inputs=args.second_run,
+    ):
+        data.append(obs)
+        if (
+            obs.onnx_op_type != "initializer"
+            and obs.ep_target != "placeholder"
+            and len(data) % ratio == 0
+        ):
+            df = pandas.DataFrame(data).apply(
+                lambda col: col.fillna("") if col.dtype == "object" else col
+            )
+            df.to_excel(args.output)
+    print(f"-- final saves into {args.output!r}")
+    df = (
+        pandas.DataFrame(data)
+        .apply(lambda col: col.fillna("") if col.dtype == "object" else col)
+        .dropna(axis=1, how="all")
+    )
+    df.to_excel(args.output, index=False)
+    print("-- done")
+
+
 def get_main_parser() -> ArgumentParser:
     parser = ArgumentParser(
         prog="onnx_diagnostic",
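A possible invocation of the new sbs command (not part of the diff), using only flags defined by get_parser_sbs above; the file names are placeholders and match the artifacts sketched earlier:

    from onnx_diagnostic._command_lines_parser import main

    main([
        "sbs",
        "-i", "model.inputs.pt",  # saved with torch.save
        "-e", "model.ep.pt2",     # saved with torch.export.save
        "-m", "model.onnx",       # saved with onnx.save
        "-o", "report.xlsx",      # excel report, rewritten every --ratio nodes
        "--atol", "1e-4",
        "--second-run",           # re-run onnx nodes with torch intermediate results
    ])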
@@ -1116,10 +1508,12 @@ def get_main_parser() -> ArgumentParser:

         agg - aggregates statistics from multiple files
         config - prints a configuration for a model id
+        dot - converts an onnx model into dot format
         exportsample - produces a code to export a model
         find - find node consuming or producing a result
         lighten - makes an onnx model lighter by removing the weights,
         print - prints the model on standard output
+        sbs - compares an exported program and an onnx model
         stats - produces statistics on a model
         unlighten - restores an onnx model produced by the previous experiment
         validate - validate a model
@@ -1131,10 +1525,12 @@ def get_main_parser() -> ArgumentParser:
         choices=[
             "agg",
             "config",
+            "dot",
             "exportsample",
             "find",
             "lighten",
             "print",
+            "sbs",
             "stats",
             "unlighten",
             "validate",
@@ -1146,15 +1542,17 @@

 def main(argv: Optional[List[Any]] = None):
     fcts = dict(
+        agg=_cmd_agg,
+        config=_cmd_config,
+        dot=_cmd_dot,
+        exportsample=_cmd_export_sample,
+        find=_cmd_find,
         lighten=_cmd_lighten,
-        unlighten=_cmd_unlighten,
         print=_cmd_print,
-        find=_cmd_find,
-        config=_cmd_config,
-        validate=_cmd_validate,
+        sbs=_cmd_sbs,
         stats=_cmd_stats,
-        agg=_cmd_agg,
-        exportsample=_cmd_export_sample,
+        unlighten=_cmd_unlighten,
+        validate=_cmd_validate,
     )

     if argv is None:
@@ -1169,15 +1567,17 @@ def main(argv: Optional[List[Any]] = None):
         parser.parse_args(argv)
     else:
         parsers = dict(
+            agg=get_parser_agg,
+            config=get_parser_config,
+            dot=get_parser_dot,
+            exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
+            find=get_parser_find,
             lighten=get_parser_lighten,
-            unlighten=get_parser_unlighten,
             print=get_parser_print,
-            find=get_parser_find,
-            config=get_parser_config,
-            validate=get_parser_validate,
+            sbs=get_parser_sbs,
             stats=get_parser_stats,
-            agg=get_parser_agg,
-            exportsample=lambda: get_parser_validate("exportsample"),
+            unlighten=get_parser_unlighten,
+            validate=get_parser_validate,
         )
         cmd = argv[0]
         if cmd not in parsers:
onnx_diagnostic/export/api.py
CHANGED
@@ -1,5 +1,52 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import torch
+from .onnx_plug import EagerDirectReplacementWithOnnx
+
+
+def get_main_dispatcher(
+    use_control_flow_dispatcher: bool = False,
+    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+) -> Any:  # Dispatcher
+    """Creates a custom dispatcher for the custom exporter."""
+    from experimental_experiment.torch_interpreter import Dispatcher
+
+    if use_control_flow_dispatcher:
+        from .control_flow_onnx import create_global_dispatcher
+
+        control_flow_dispatcher = create_global_dispatcher()
+    else:
+        control_flow_dispatcher = None
+
+    class MainDispatcher(Dispatcher):
+        def __init__(self, previous_dispatcher=None):
+            super().__init__({})
+            self.previous_dispatcher = previous_dispatcher
+
+        @property
+        def supported(self):
+            if self.previous_dispatcher:
+                return set(self.registered_functions) | self.previous_dispatcher.supported
+            return set(self.registered_functions)
+
+        def find_function(self, name: Any):
+            if self.previous_dispatcher:
+                find = self.previous_dispatcher.find_function(name)
+                if find:
+                    return find
+            return Dispatcher.find_function(self, name)
+
+        def find_method(self, name: Any):
+            if self.previous_dispatcher:
+                find = self.previous_dispatcher.find_method(name)
+                if find:
+                    return find
+            return Dispatcher.find_method(self, name)
+
+    main_dispatcher = MainDispatcher(control_flow_dispatcher)
+    if onnx_plugs:
+        for plug in onnx_plugs:
+            main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
+    return main_dispatcher


 def to_onnx(
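A minimal sketch of the chained lookup above (assuming experimental_experiment is installed and plug is an EagerDirectReplacementWithOnnx instance, hypothetical here):

    dispatcher = get_main_dispatcher(
        use_control_flow_dispatcher=True,  # chains create_global_dispatcher()
        onnx_plugs=[plug],                 # registered under plug.target_name
    )
    # find_function queries the chained dispatcher first, then its own registry
    fn = dispatcher.find_function(plug.target_name)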
@@ -18,6 +65,8 @@ def to_onnx(
     save_ep: Optional[str] = None,
     optimize: bool = True,
     use_control_flow_dispatcher: bool = False,
+    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+    inline: bool = True,
 ) -> Any:
     """
     Common API for exporters. By default, the models are optimized to use the
@@ -40,7 +89,9 @@ def to_onnx(
     :param save_ep: saves the exported program
     :param optimize: optimizes the model
     :param use_control_flow_dispatcher: use the dispatcher created to support
-        custom loops (see :func:`onnx_diagnostic.export.
+        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
+    :param onnx_plugs: replaces some parts of the code with their onnx translation
+    :param inline: inline local functions
     :return: the output of the selected exporter, usually a structure including
         an onnx model

@@ -55,7 +106,16 @@ def to_onnx(
         exporter=exporter,
         filename=filename,
     )
+
+    Some examples using control flows are available in
+    :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
+    :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
     """
+    if exporter_kwargs and "inline" in exporter_kwargs:
+        assert (
+            inline == exporter_kwargs["inline"]
+        ), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
+        exporter_kwargs.pop("inline")
     if exporter == "custom":
         from experimental_experiment.torch_interpreter import (
             to_onnx as _to_onnx,
@@ -63,16 +123,16 @@ def to_onnx(
         )
         from experimental_experiment.xbuilder import OptimizationOptions

-        if use_control_flow_dispatcher:
-            from .control_flow import create_global_dispatcher
-
-            dispatcher = create_global_dispatcher()
-
         options = None
         if exporter_kwargs is not None:
             options = exporter_kwargs.pop("options", None)
         if options is None:
             options = OptimizationOptions(patterns="default+onnxruntime")
+        main_dispatcher = (
+            get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
+            if onnx_plugs or use_control_flow_dispatcher
+            else None
+        )

         return _to_onnx(
             mod,
@@ -88,15 +148,23 @@ def to_onnx(
             output_dynamic_shapes=output_dynamic_shapes,
             export_options=ExportOptions(save_ep=save_ep),
             options=options,
+            inline=inline,
+            dispatcher=main_dispatcher,
             **(exporter_kwargs or {}),
-            dispatcher=dispatcher if use_control_flow_dispatcher else None,
         )
+
     if exporter in ("dynamo", "onnx-dynamo"):
+        import os
+        from ..helpers import flatten_object
         import onnxscript.rewriter.ort_fusions as ort_fusions

         assert (
             not output_dynamic_shapes
         ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
+        custom_translation_table = {}
+        if onnx_plugs:
+            for plug in onnx_plugs:
+                custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
         epo = torch.onnx.export(
             mod,
             args=args or tuple(),
@@ -106,12 +174,47 @@ def to_onnx(
             opset_version=target_opset,
             dynamic_shapes=dynamic_shapes,
             dynamo=True,
+            verbose=verbose,
+            dump_exported_program=bool(save_ep),
+            artifacts_dir=os.path.dirname(filename) if filename else ".",
+            custom_translation_table=custom_translation_table,
             **(exporter_kwargs or {}),
         )
-        if optimize:
+        if not inline and optimize:
+            ort_fusions.optimize_for_ort(epo.model)
+
+        if onnx_plugs:
+            import onnx_ir as ir
+            import onnx_ir.passes.common as common_passes
+
+            opset = (
+                18
+                if target_opset is None
+                else (target_opset if isinstance(target_opset, int) else target_opset[""])
+            )
+
+            irfunctions = [
+                ir.from_proto(
+                    plug.get_function_proto(
+                        opset, *flatten_object((args, kwargs), drop_keys=True)
+                    )
+                )
+                for plug in onnx_plugs
+            ]
+            for func in irfunctions:
+                epo.model.functions[func.identifier()] = func
+            if inline:
+                common_passes.InlinePass()(epo.model)
+                common_passes.RemoveUnusedOpsetsPass()(epo.model)
+
+        if inline and optimize:
             ort_fusions.optimize_for_ort(epo.model)
         if filename:
             epo.save(filename, external_data=True)
+        if save_ep:
+            if isinstance(save_ep, tuple):
+                save_ep = save_ep[0]
+            torch.export.save(epo.exported_program, f"{save_ep}.pt2")
         return epo

     if exporter == "modelbuilder":
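A usage sketch for the extended API (not part of the diff), restricted to keyword arguments that appear in this hunk; "onnx-dynamo" selects the torch.onnx.export branch above, and the file names are placeholders:

    import torch
    from onnx_diagnostic.export.api import to_onnx

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x) + 1

    epo = to_onnx(
        Model(),
        args=(torch.randn(2, 4),),
        exporter="onnx-dynamo",
        filename="model.onnx",
        inline=True,         # inline plug functions, then optimize_for_ort
        save_ep="model.ep",  # exported program also saved as model.ep.pt2
    )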