onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +47 -10
- onnx_diagnostic/export/api.py +81 -50
- onnx_diagnostic/export/control_flow_research.py +10 -5
- onnx_diagnostic/export/onnx_plug.py +250 -61
- onnx_diagnostic/ext_test_case.py +99 -53
- onnx_diagnostic/helpers/dot_helper.py +37 -25
- onnx_diagnostic/helpers/helper.py +44 -38
- onnx_diagnostic/helpers/onnx_helper.py +441 -18
- onnx_diagnostic/helpers/ort_session.py +8 -8
- onnx_diagnostic/helpers/torch_helper.py +28 -2
- onnx_diagnostic/reference/ort_evaluator.py +6 -29
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +14 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +11 -5
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -198,15 +198,19 @@ def get_parser_print() -> ArgumentParser:
|
|
|
198
198
|
)
|
|
199
199
|
parser.add_argument(
|
|
200
200
|
"fmt",
|
|
201
|
-
choices=["pretty", "raw", "
|
|
201
|
+
choices=["dot", "pretty", "printer", "raw", "shape", "text"],
|
|
202
202
|
default="pretty",
|
|
203
203
|
help=textwrap.dedent(
|
|
204
204
|
"""
|
|
205
205
|
Prints out a model on the standard output.
|
|
206
|
-
|
|
207
|
-
|
|
206
|
+
|
|
207
|
+
dot - converts the graph into dot
|
|
208
208
|
pretty - an improved rendering
|
|
209
|
+
printer - onnx.printer.to_text(...)
|
|
210
|
+
raw - just prints the model with print(...)
|
|
211
|
+
shape - prints every node node with input and output shapes
|
|
209
212
|
text - uses GraphRendering
|
|
213
|
+
|
|
210
214
|
""".strip(
|
|
211
215
|
"\n"
|
|
212
216
|
)
|
|
@@ -232,6 +236,14 @@ def _cmd_print(argv: List[Any]):
|
|
|
232
236
|
from .helpers.graph_helper import GraphRendering
|
|
233
237
|
|
|
234
238
|
print(GraphRendering(onx).text_rendering())
|
|
239
|
+
elif args.fmt == "shape":
|
|
240
|
+
from experimental_experiment.xbuilder import GraphBuilder
|
|
241
|
+
|
|
242
|
+
print(GraphBuilder(onx).pretty_text())
|
|
243
|
+
elif args.fmt == "dot":
|
|
244
|
+
from .helpers.dot_helper import to_dot
|
|
245
|
+
|
|
246
|
+
print(to_dot(onx))
|
|
235
247
|
else:
|
|
236
248
|
raise ValueError(f"Unexpected value fmt={args.fmt!r}")
|
|
237
249
|
|
|
@@ -517,12 +529,12 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
517
529
|
nargs="*",
|
|
518
530
|
help=textwrap.dedent(
|
|
519
531
|
"""
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
532
|
+
Applies patches before exporting, it can be a boolean
|
|
533
|
+
to enable to disable the patches or be more finetuned
|
|
534
|
+
(default is True). It is possible to disable patch for torch
|
|
535
|
+
by adding:
|
|
536
|
+
--patch "patch_sympy=False" --patch "patch_torch=False"
|
|
537
|
+
""".strip(
|
|
526
538
|
"\n"
|
|
527
539
|
)
|
|
528
540
|
),
|
|
@@ -1286,7 +1298,13 @@ def get_parser_sbs() -> ArgumentParser:
|
|
|
1286
1298
|
"--first",
|
|
1287
1299
|
action=BooleanOptionalAction,
|
|
1288
1300
|
default=False,
|
|
1289
|
-
help="First runs the whole model.",
|
|
1301
|
+
help="First runs the whole model (default is False).",
|
|
1302
|
+
)
|
|
1303
|
+
parser.add_argument(
|
|
1304
|
+
"--sbs",
|
|
1305
|
+
action=BooleanOptionalAction,
|
|
1306
|
+
default=True,
|
|
1307
|
+
help="Runs the side-by-side (default is True).",
|
|
1290
1308
|
)
|
|
1291
1309
|
parser.add_argument(
|
|
1292
1310
|
"-2",
|
|
@@ -1342,6 +1360,20 @@ def get_parser_sbs() -> ArgumentParser:
|
|
|
1342
1360
|
default="replay",
|
|
1343
1361
|
help="If the replay is triggered, this defines the folder where everything is dumped.",
|
|
1344
1362
|
)
|
|
1363
|
+
parser.add_argument(
|
|
1364
|
+
"-p",
|
|
1365
|
+
"--replay-prefix-model",
|
|
1366
|
+
action=BooleanOptionalAction,
|
|
1367
|
+
default=False,
|
|
1368
|
+
help=textwrap.dedent(
|
|
1369
|
+
"""
|
|
1370
|
+
There are two ways to recompute an intermediate output, the first one is to "
|
|
1371
|
+
produce the minimal model between torch and onnx.
|
|
1372
|
+
The second one is to dump onnx models from the inputs
|
|
1373
|
+
to the considered intermediate results. This enables the second one.
|
|
1374
|
+
"""
|
|
1375
|
+
),
|
|
1376
|
+
)
|
|
1345
1377
|
|
|
1346
1378
|
return parser
|
|
1347
1379
|
|
|
@@ -1417,6 +1449,10 @@ def _cmd_sbs(argv: List[Any]):
|
|
|
1417
1449
|
print("-- done")
|
|
1418
1450
|
del sess
|
|
1419
1451
|
|
|
1452
|
+
if not args.sbs:
|
|
1453
|
+
print("-- done")
|
|
1454
|
+
return
|
|
1455
|
+
|
|
1420
1456
|
print(f"-- load onnx {args.onnx!r}")
|
|
1421
1457
|
begin = time.perf_counter()
|
|
1422
1458
|
onx = onnx.load(args.onnx)
|
|
@@ -1431,6 +1467,7 @@ def _cmd_sbs(argv: List[Any]):
|
|
|
1431
1467
|
set(args.replay_op_types.split(",")) if args.replay_op_types else None
|
|
1432
1468
|
),
|
|
1433
1469
|
dump_folder=args.replay_folder,
|
|
1470
|
+
dump_prefix_model=args.replay_prefix_model,
|
|
1434
1471
|
)
|
|
1435
1472
|
|
|
1436
1473
|
print("-- starts side-by-side")
|
onnx_diagnostic/export/api.py
CHANGED
|
@@ -3,6 +3,52 @@ import torch
|
|
|
3
3
|
from .onnx_plug import EagerDirectReplacementWithOnnx
|
|
4
4
|
|
|
5
5
|
|
|
6
|
+
def get_main_dispatcher(
|
|
7
|
+
use_control_flow_dispatcher: bool = False,
|
|
8
|
+
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
|
|
9
|
+
) -> Any: # Dispatcher
|
|
10
|
+
"""Creates a custom dispatcher for the custom exporter."""
|
|
11
|
+
from experimental_experiment.torch_interpreter import Dispatcher
|
|
12
|
+
|
|
13
|
+
if use_control_flow_dispatcher:
|
|
14
|
+
from .control_flow_onnx import create_global_dispatcher
|
|
15
|
+
|
|
16
|
+
control_flow_dispatcher = create_global_dispatcher()
|
|
17
|
+
else:
|
|
18
|
+
control_flow_dispatcher = None
|
|
19
|
+
|
|
20
|
+
class MainDispatcher(Dispatcher):
|
|
21
|
+
def __init__(self, previous_dispatcher=None):
|
|
22
|
+
super().__init__({})
|
|
23
|
+
self.previous_dispatcher = previous_dispatcher
|
|
24
|
+
|
|
25
|
+
@property
|
|
26
|
+
def supported(self):
|
|
27
|
+
if self.previous_dispatcher:
|
|
28
|
+
return set(self.registered_functions) | self.previous_dispatcher.supported
|
|
29
|
+
return set(self.registered_functions)
|
|
30
|
+
|
|
31
|
+
def find_function(self, name: Any):
|
|
32
|
+
if self.previous_dispatcher:
|
|
33
|
+
find = self.previous_dispatcher.find_function(name)
|
|
34
|
+
if find:
|
|
35
|
+
return find
|
|
36
|
+
return Dispatcher.find_function(self, name)
|
|
37
|
+
|
|
38
|
+
def find_method(self, name: Any):
|
|
39
|
+
if self.previous_dispatcher:
|
|
40
|
+
find = self.previous_dispatcher.find_method(name)
|
|
41
|
+
if find:
|
|
42
|
+
return find
|
|
43
|
+
return Dispatcher.find_method(self, name)
|
|
44
|
+
|
|
45
|
+
main_dispatcher = MainDispatcher(control_flow_dispatcher)
|
|
46
|
+
if onnx_plugs:
|
|
47
|
+
for plug in onnx_plugs:
|
|
48
|
+
main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
|
|
49
|
+
return main_dispatcher
|
|
50
|
+
|
|
51
|
+
|
|
6
52
|
def to_onnx(
|
|
7
53
|
mod: Union["torch.nn.Module", "torch.fx.GraphModule"], # noqa: F821
|
|
8
54
|
args: Optional[Sequence["torch.Tensor"]] = None, # noqa: F821
|
|
@@ -18,6 +64,7 @@ def to_onnx(
|
|
|
18
64
|
exporter_kwargs: Optional[Dict[str, Any]] = None,
|
|
19
65
|
save_ep: Optional[str] = None,
|
|
20
66
|
optimize: bool = True,
|
|
67
|
+
optimizer_for_ort: bool = True,
|
|
21
68
|
use_control_flow_dispatcher: bool = False,
|
|
22
69
|
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
|
|
23
70
|
inline: bool = True,
|
|
@@ -42,6 +89,7 @@ def to_onnx(
|
|
|
42
89
|
:param exporter_kwargs: additional parameters sent to the exporter
|
|
43
90
|
:param save_ep: saves the exported program
|
|
44
91
|
:param optimize: optimizes the model
|
|
92
|
+
:param optimizer_for_ort: optimizes the model for onnxruntime
|
|
45
93
|
:param use_control_flow_dispatcher: use the dispatcher created to supported
|
|
46
94
|
custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
|
|
47
95
|
:param onnx_plugs: the code was modified to replace some parts with onnx translation
|
|
@@ -80,53 +128,15 @@ def to_onnx(
|
|
|
80
128
|
options = None
|
|
81
129
|
if exporter_kwargs is not None:
|
|
82
130
|
options = exporter_kwargs.pop("options", None)
|
|
83
|
-
if options is None:
|
|
84
|
-
options = OptimizationOptions(
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
else:
|
|
93
|
-
control_flow_dispatcher = None
|
|
94
|
-
|
|
95
|
-
class MainDispatcher(Dispatcher):
|
|
96
|
-
def __init__(self, previous_dispatcher=None):
|
|
97
|
-
super().__init__({})
|
|
98
|
-
self.previous_dispatcher = previous_dispatcher
|
|
99
|
-
|
|
100
|
-
@property
|
|
101
|
-
def supported(self):
|
|
102
|
-
if self.previous_dispatcher:
|
|
103
|
-
return (
|
|
104
|
-
set(self.registered_functions) | self.previous_dispatcher.supported
|
|
105
|
-
)
|
|
106
|
-
return set(self.registered_functions)
|
|
107
|
-
|
|
108
|
-
def find_function(self, name: Any):
|
|
109
|
-
if self.previous_dispatcher:
|
|
110
|
-
find = self.previous_dispatcher.find_function(name)
|
|
111
|
-
if find:
|
|
112
|
-
return find
|
|
113
|
-
return Dispatcher.find_function(self, name)
|
|
114
|
-
|
|
115
|
-
def find_method(self, name: Any):
|
|
116
|
-
if self.previous_dispatcher:
|
|
117
|
-
find = self.previous_dispatcher.find_method(name)
|
|
118
|
-
if find:
|
|
119
|
-
return find
|
|
120
|
-
return Dispatcher.find_method(self, name)
|
|
121
|
-
|
|
122
|
-
main_dispatcher = MainDispatcher(control_flow_dispatcher)
|
|
123
|
-
if onnx_plugs:
|
|
124
|
-
for plug in onnx_plugs:
|
|
125
|
-
main_dispatcher.registered_functions[plug.target_name] = (
|
|
126
|
-
plug.custom_converter()
|
|
127
|
-
)
|
|
128
|
-
else:
|
|
129
|
-
main_dispatcher = None
|
|
131
|
+
if options is None and optimize:
|
|
132
|
+
options = OptimizationOptions(
|
|
133
|
+
patterns="default+onnxruntime" if optimizer_for_ort else "default"
|
|
134
|
+
)
|
|
135
|
+
main_dispatcher = (
|
|
136
|
+
get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
|
|
137
|
+
if onnx_plugs or use_control_flow_dispatcher
|
|
138
|
+
else None
|
|
139
|
+
)
|
|
130
140
|
|
|
131
141
|
return _to_onnx(
|
|
132
142
|
mod,
|
|
@@ -149,11 +159,15 @@ def to_onnx(
|
|
|
149
159
|
|
|
150
160
|
if exporter in ("dynamo", "onnx-dynamo"):
|
|
151
161
|
import os
|
|
162
|
+
from ..helpers import flatten_object
|
|
152
163
|
import onnxscript.rewriter.ort_fusions as ort_fusions
|
|
153
164
|
|
|
154
165
|
assert (
|
|
155
166
|
not output_dynamic_shapes
|
|
156
167
|
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
|
|
168
|
+
assert (
|
|
169
|
+
optimize
|
|
170
|
+
), f"torch.onnx.export always optimizes the model but optimize={optimize}"
|
|
157
171
|
custom_translation_table = {}
|
|
158
172
|
if onnx_plugs:
|
|
159
173
|
for plug in onnx_plugs:
|
|
@@ -173,21 +187,34 @@ def to_onnx(
|
|
|
173
187
|
custom_translation_table=custom_translation_table,
|
|
174
188
|
**(exporter_kwargs or {}),
|
|
175
189
|
)
|
|
176
|
-
if not inline and optimize:
|
|
190
|
+
if not inline and optimize and optimizer_for_ort:
|
|
177
191
|
ort_fusions.optimize_for_ort(epo.model)
|
|
178
192
|
|
|
179
193
|
if onnx_plugs:
|
|
180
194
|
import onnx_ir as ir
|
|
181
195
|
import onnx_ir.passes.common as common_passes
|
|
182
196
|
|
|
183
|
-
|
|
197
|
+
opset = (
|
|
198
|
+
18
|
|
199
|
+
if target_opset is None
|
|
200
|
+
else (target_opset if isinstance(target_opset, int) else target_opset[""])
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
irfunctions = [
|
|
204
|
+
ir.from_proto(
|
|
205
|
+
plug.get_function_proto(
|
|
206
|
+
opset, *flatten_object((args, kwargs), drop_keys=True)
|
|
207
|
+
)
|
|
208
|
+
)
|
|
209
|
+
for plug in onnx_plugs
|
|
210
|
+
]
|
|
184
211
|
for func in irfunctions:
|
|
185
212
|
epo.model.functions[func.identifier()] = func
|
|
186
213
|
if inline:
|
|
187
214
|
common_passes.InlinePass()(epo.model)
|
|
188
215
|
common_passes.RemoveUnusedOpsetsPass()(epo.model)
|
|
189
216
|
|
|
190
|
-
if inline and optimize:
|
|
217
|
+
if inline and optimize and optimizer_for_ort:
|
|
191
218
|
ort_fusions.optimize_for_ort(epo.model)
|
|
192
219
|
if filename:
|
|
193
220
|
epo.save(filename, external_data=True)
|
|
@@ -212,6 +239,10 @@ def to_onnx(
|
|
|
212
239
|
f"Only a specified set of inputs is supported for exporter={exporter!r}, "
|
|
213
240
|
f"but it is {list(kwargs)}" # type: ignore[arg-type]
|
|
214
241
|
)
|
|
242
|
+
assert optimizer_for_ort and optimize, (
|
|
243
|
+
f"ModelBuilder only produces model optimized for onnxruntime but "
|
|
244
|
+
f"optimizer_for_ort={optimizer_for_ort} and optimize={optimize}"
|
|
245
|
+
)
|
|
215
246
|
flat_inputs = flatten_object(kwargs, drop_keys=True)
|
|
216
247
|
first = flat_inputs[0]
|
|
217
248
|
first_float = [
|
|
@@ -92,10 +92,11 @@ def simple_loop_for(
|
|
|
92
92
|
|
|
93
93
|
from torch._higher_order_ops.utils import setup_compilation_env
|
|
94
94
|
|
|
95
|
-
with setup_compilation_env() as
|
|
96
|
-
return
|
|
97
|
-
|
|
98
|
-
|
|
95
|
+
with setup_compilation_env() as _backend:
|
|
96
|
+
return _loop_for_op_wrapper(n_iter, body_fn, *operands)
|
|
97
|
+
# return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
|
|
98
|
+
# n_iter, body_fn, operands
|
|
99
|
+
# )
|
|
99
100
|
|
|
100
101
|
|
|
101
102
|
def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
|
|
@@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands):
|
|
|
127
128
|
), f"Dense implementation operands must be a list of tensors and ints {operands}"
|
|
128
129
|
mode = _get_current_dispatch_mode()
|
|
129
130
|
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
|
|
130
|
-
return _loop_for_onnx_fn(body_fn, n_iter, None,
|
|
131
|
+
return _loop_for_onnx_fn(body_fn, n_iter, None, operands)
|
|
131
132
|
|
|
132
133
|
|
|
133
134
|
@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
|
|
134
135
|
def inner(mode, n_iter, body_fn, operands):
|
|
135
136
|
return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
|
|
140
|
+
simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
|