onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +26 -1
- onnx_diagnostic/export/api.py +66 -46
- onnx_diagnostic/export/control_flow_research.py +10 -5
- onnx_diagnostic/export/onnx_plug.py +195 -60
- onnx_diagnostic/ext_test_case.py +99 -53
- onnx_diagnostic/helpers/dot_helper.py +37 -25
- onnx_diagnostic/helpers/helper.py +18 -11
- onnx_diagnostic/helpers/onnx_helper.py +441 -18
- onnx_diagnostic/helpers/ort_session.py +8 -8
- onnx_diagnostic/helpers/torch_helper.py +28 -2
- onnx_diagnostic/reference/ort_evaluator.py +6 -29
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +14 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +11 -5
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +25 -25
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -1286,7 +1286,13 @@ def get_parser_sbs() -> ArgumentParser:
|
|
|
1286
1286
|
"--first",
|
|
1287
1287
|
action=BooleanOptionalAction,
|
|
1288
1288
|
default=False,
|
|
1289
|
-
help="First runs the whole model.",
|
|
1289
|
+
help="First runs the whole model (default is False).",
|
|
1290
|
+
)
|
|
1291
|
+
parser.add_argument(
|
|
1292
|
+
"--sbs",
|
|
1293
|
+
action=BooleanOptionalAction,
|
|
1294
|
+
default=True,
|
|
1295
|
+
help="Runs the side-by-side (default is True).",
|
|
1290
1296
|
)
|
|
1291
1297
|
parser.add_argument(
|
|
1292
1298
|
"-2",
|
|
@@ -1342,6 +1348,20 @@ def get_parser_sbs() -> ArgumentParser:
|
|
|
1342
1348
|
default="replay",
|
|
1343
1349
|
help="If the replay is triggered, this defines the folder where everything is dumped.",
|
|
1344
1350
|
)
|
|
1351
|
+
parser.add_argument(
|
|
1352
|
+
"-p",
|
|
1353
|
+
"--replay-prefix-model",
|
|
1354
|
+
action=BooleanOptionalAction,
|
|
1355
|
+
default=False,
|
|
1356
|
+
help=textwrap.dedent(
|
|
1357
|
+
"""
|
|
1358
|
+
There are two ways to recompute an intermediate output, the first one is to "
|
|
1359
|
+
produce the minimal model between torch and onnx.
|
|
1360
|
+
The second one is to dump onnx models from the inputs
|
|
1361
|
+
to the considered intermediate results. This enables the second one.
|
|
1362
|
+
"""
|
|
1363
|
+
),
|
|
1364
|
+
)
|
|
1345
1365
|
|
|
1346
1366
|
return parser
|
|
1347
1367
|
|
|
@@ -1417,6 +1437,10 @@ def _cmd_sbs(argv: List[Any]):
|
|
|
1417
1437
|
print("-- done")
|
|
1418
1438
|
del sess
|
|
1419
1439
|
|
|
1440
|
+
if not args.sbs:
|
|
1441
|
+
print("-- done")
|
|
1442
|
+
return
|
|
1443
|
+
|
|
1420
1444
|
print(f"-- load onnx {args.onnx!r}")
|
|
1421
1445
|
begin = time.perf_counter()
|
|
1422
1446
|
onx = onnx.load(args.onnx)
|
|
@@ -1431,6 +1455,7 @@ def _cmd_sbs(argv: List[Any]):
|
|
|
1431
1455
|
set(args.replay_op_types.split(",")) if args.replay_op_types else None
|
|
1432
1456
|
),
|
|
1433
1457
|
dump_folder=args.replay_folder,
|
|
1458
|
+
dump_prefix_model=args.replay_prefix_model,
|
|
1434
1459
|
)
|
|
1435
1460
|
|
|
1436
1461
|
print("-- starts side-by-side")
|
onnx_diagnostic/export/api.py
CHANGED
|
@@ -3,6 +3,52 @@ import torch
|
|
|
3
3
|
from .onnx_plug import EagerDirectReplacementWithOnnx
|
|
4
4
|
|
|
5
5
|
|
|
6
|
+
def get_main_dispatcher(
|
|
7
|
+
use_control_flow_dispatcher: bool = False,
|
|
8
|
+
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
|
|
9
|
+
) -> Any: # Dispatcher
|
|
10
|
+
"""Creates a custom dispatcher for the custom exporter."""
|
|
11
|
+
from experimental_experiment.torch_interpreter import Dispatcher
|
|
12
|
+
|
|
13
|
+
if use_control_flow_dispatcher:
|
|
14
|
+
from .control_flow_onnx import create_global_dispatcher
|
|
15
|
+
|
|
16
|
+
control_flow_dispatcher = create_global_dispatcher()
|
|
17
|
+
else:
|
|
18
|
+
control_flow_dispatcher = None
|
|
19
|
+
|
|
20
|
+
class MainDispatcher(Dispatcher):
|
|
21
|
+
def __init__(self, previous_dispatcher=None):
|
|
22
|
+
super().__init__({})
|
|
23
|
+
self.previous_dispatcher = previous_dispatcher
|
|
24
|
+
|
|
25
|
+
@property
|
|
26
|
+
def supported(self):
|
|
27
|
+
if self.previous_dispatcher:
|
|
28
|
+
return set(self.registered_functions) | self.previous_dispatcher.supported
|
|
29
|
+
return set(self.registered_functions)
|
|
30
|
+
|
|
31
|
+
def find_function(self, name: Any):
|
|
32
|
+
if self.previous_dispatcher:
|
|
33
|
+
find = self.previous_dispatcher.find_function(name)
|
|
34
|
+
if find:
|
|
35
|
+
return find
|
|
36
|
+
return Dispatcher.find_function(self, name)
|
|
37
|
+
|
|
38
|
+
def find_method(self, name: Any):
|
|
39
|
+
if self.previous_dispatcher:
|
|
40
|
+
find = self.previous_dispatcher.find_method(name)
|
|
41
|
+
if find:
|
|
42
|
+
return find
|
|
43
|
+
return Dispatcher.find_method(self, name)
|
|
44
|
+
|
|
45
|
+
main_dispatcher = MainDispatcher(control_flow_dispatcher)
|
|
46
|
+
if onnx_plugs:
|
|
47
|
+
for plug in onnx_plugs:
|
|
48
|
+
main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
|
|
49
|
+
return main_dispatcher
|
|
50
|
+
|
|
51
|
+
|
|
6
52
|
def to_onnx(
|
|
7
53
|
mod: Union["torch.nn.Module", "torch.fx.GraphModule"], # noqa: F821
|
|
8
54
|
args: Optional[Sequence["torch.Tensor"]] = None, # noqa: F821
|
|
@@ -82,51 +128,11 @@ def to_onnx(
|
|
|
82
128
|
options = exporter_kwargs.pop("options", None)
|
|
83
129
|
if options is None:
|
|
84
130
|
options = OptimizationOptions(patterns="default+onnxruntime")
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
control_flow_dispatcher = create_global_dispatcher()
|
|
92
|
-
else:
|
|
93
|
-
control_flow_dispatcher = None
|
|
94
|
-
|
|
95
|
-
class MainDispatcher(Dispatcher):
|
|
96
|
-
def __init__(self, previous_dispatcher=None):
|
|
97
|
-
super().__init__({})
|
|
98
|
-
self.previous_dispatcher = previous_dispatcher
|
|
99
|
-
|
|
100
|
-
@property
|
|
101
|
-
def supported(self):
|
|
102
|
-
if self.previous_dispatcher:
|
|
103
|
-
return (
|
|
104
|
-
set(self.registered_functions) | self.previous_dispatcher.supported
|
|
105
|
-
)
|
|
106
|
-
return set(self.registered_functions)
|
|
107
|
-
|
|
108
|
-
def find_function(self, name: Any):
|
|
109
|
-
if self.previous_dispatcher:
|
|
110
|
-
find = self.previous_dispatcher.find_function(name)
|
|
111
|
-
if find:
|
|
112
|
-
return find
|
|
113
|
-
return Dispatcher.find_function(self, name)
|
|
114
|
-
|
|
115
|
-
def find_method(self, name: Any):
|
|
116
|
-
if self.previous_dispatcher:
|
|
117
|
-
find = self.previous_dispatcher.find_method(name)
|
|
118
|
-
if find:
|
|
119
|
-
return find
|
|
120
|
-
return Dispatcher.find_method(self, name)
|
|
121
|
-
|
|
122
|
-
main_dispatcher = MainDispatcher(control_flow_dispatcher)
|
|
123
|
-
if onnx_plugs:
|
|
124
|
-
for plug in onnx_plugs:
|
|
125
|
-
main_dispatcher.registered_functions[plug.target_name] = (
|
|
126
|
-
plug.custom_converter()
|
|
127
|
-
)
|
|
128
|
-
else:
|
|
129
|
-
main_dispatcher = None
|
|
131
|
+
main_dispatcher = (
|
|
132
|
+
get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
|
|
133
|
+
if onnx_plugs or use_control_flow_dispatcher
|
|
134
|
+
else None
|
|
135
|
+
)
|
|
130
136
|
|
|
131
137
|
return _to_onnx(
|
|
132
138
|
mod,
|
|
@@ -149,6 +155,7 @@ def to_onnx(
|
|
|
149
155
|
|
|
150
156
|
if exporter in ("dynamo", "onnx-dynamo"):
|
|
151
157
|
import os
|
|
158
|
+
from ..helpers import flatten_object
|
|
152
159
|
import onnxscript.rewriter.ort_fusions as ort_fusions
|
|
153
160
|
|
|
154
161
|
assert (
|
|
@@ -180,7 +187,20 @@ def to_onnx(
|
|
|
180
187
|
import onnx_ir as ir
|
|
181
188
|
import onnx_ir.passes.common as common_passes
|
|
182
189
|
|
|
183
|
-
|
|
190
|
+
opset = (
|
|
191
|
+
18
|
|
192
|
+
if target_opset is None
|
|
193
|
+
else (target_opset if isinstance(target_opset, int) else target_opset[""])
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
irfunctions = [
|
|
197
|
+
ir.from_proto(
|
|
198
|
+
plug.get_function_proto(
|
|
199
|
+
opset, *flatten_object((args, kwargs), drop_keys=True)
|
|
200
|
+
)
|
|
201
|
+
)
|
|
202
|
+
for plug in onnx_plugs
|
|
203
|
+
]
|
|
184
204
|
for func in irfunctions:
|
|
185
205
|
epo.model.functions[func.identifier()] = func
|
|
186
206
|
if inline:
|
|
@@ -92,10 +92,11 @@ def simple_loop_for(
|
|
|
92
92
|
|
|
93
93
|
from torch._higher_order_ops.utils import setup_compilation_env
|
|
94
94
|
|
|
95
|
-
with setup_compilation_env() as
|
|
96
|
-
return
|
|
97
|
-
|
|
98
|
-
|
|
95
|
+
with setup_compilation_env() as _backend:
|
|
96
|
+
return _loop_for_op_wrapper(n_iter, body_fn, *operands)
|
|
97
|
+
# return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
|
|
98
|
+
# n_iter, body_fn, operands
|
|
99
|
+
# )
|
|
99
100
|
|
|
100
101
|
|
|
101
102
|
def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
|
|
@@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands):
|
|
|
127
128
|
), f"Dense implementation operands must be a list of tensors and ints {operands}"
|
|
128
129
|
mode = _get_current_dispatch_mode()
|
|
129
130
|
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
|
|
130
|
-
return _loop_for_onnx_fn(body_fn, n_iter, None,
|
|
131
|
+
return _loop_for_onnx_fn(body_fn, n_iter, None, operands)
|
|
131
132
|
|
|
132
133
|
|
|
133
134
|
@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
|
|
134
135
|
def inner(mode, n_iter, body_fn, operands):
|
|
135
136
|
return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
|
|
140
|
+
simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
|
|
@@ -3,8 +3,12 @@ from dataclasses import dataclass
|
|
|
3
3
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
4
4
|
import onnx
|
|
5
5
|
import torch
|
|
6
|
-
from ..helpers import max_diff
|
|
7
|
-
from ..helpers.torch_helper import
|
|
6
|
+
from ..helpers import max_diff, string_type
|
|
7
|
+
from ..helpers.torch_helper import (
|
|
8
|
+
torch_dtype_to_onnx_dtype,
|
|
9
|
+
onnx_dtype_to_torch_dtype,
|
|
10
|
+
int_device_to_torch_device,
|
|
11
|
+
)
|
|
8
12
|
from ..reference import OnnxruntimeEvaluator
|
|
9
13
|
|
|
10
14
|
TUPLE_TENSORS = Tuple[torch.Tensor, ...]
|
|
@@ -50,6 +54,10 @@ class EagerDirectReplacementWithOnnx:
|
|
|
50
54
|
only tensors must be counted
|
|
51
55
|
:param name: the name of the custom op, the function name if not specified
|
|
52
56
|
:param kwargs: constants parameters with their default values
|
|
57
|
+
:param version_selector: selects the version based on the arguments,
|
|
58
|
+
see below for an example, this allows the user to define different
|
|
59
|
+
onnx version depending on the inputs
|
|
60
|
+
:param default_opset: opset to use by default
|
|
53
61
|
:param verbose: verbose level
|
|
54
62
|
|
|
55
63
|
Here is an example:
|
|
@@ -133,27 +141,87 @@ class EagerDirectReplacementWithOnnx:
|
|
|
133
141
|
).model_proto
|
|
134
142
|
|
|
135
143
|
print(pretty_onnx(onx))
|
|
144
|
+
|
|
145
|
+
This shows how to define multiple versions depending on the device,
|
|
146
|
+
the type or the targeted onnx opset.
|
|
147
|
+
|
|
148
|
+
.. code-block:: python
|
|
149
|
+
|
|
150
|
+
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
|
|
151
|
+
first_tensor = next(a for a in args if a is not None)
|
|
152
|
+
dtype = first_tensor.dtype
|
|
153
|
+
itype = torch_dtype_to_onnx_dtype(dtype)
|
|
154
|
+
if dtype == torch.float32:
|
|
155
|
+
if opset >= 24:
|
|
156
|
+
return "LOOPA24", itype
|
|
157
|
+
return "LOOPMHA", itype
|
|
158
|
+
if dtype == torch.float16:
|
|
159
|
+
if first_tensor.is_cuda:
|
|
160
|
+
return "PACKED", itype
|
|
161
|
+
return "LOOPMHA", itype
|
|
162
|
+
raise AssertionError(
|
|
163
|
+
f"Unable to handle type {torch.dtype} (itype={itype}) "
|
|
164
|
+
f"on device {torch.device} with opset={opset}"
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
|
|
168
|
+
qwen_sdpa_attention,
|
|
169
|
+
lambda qs, *args, **kwargs: torch.empty(
|
|
170
|
+
(qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
|
|
171
|
+
dtype=qs.dtype,
|
|
172
|
+
device=qs.device,
|
|
173
|
+
),
|
|
174
|
+
{
|
|
175
|
+
("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
|
|
176
|
+
PackedAttention.to_function_proto()
|
|
177
|
+
),
|
|
178
|
+
("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
|
|
179
|
+
("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
|
|
180
|
+
onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
|
|
181
|
+
),
|
|
182
|
+
("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
|
|
183
|
+
LoopMHAAttention.to_function_proto()
|
|
184
|
+
),
|
|
185
|
+
("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
|
|
186
|
+
onnx.TensorProto.FLOAT16,
|
|
187
|
+
_add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
|
|
188
|
+
),
|
|
189
|
+
},
|
|
190
|
+
n_inputs=4,
|
|
191
|
+
n_outputs=1,
|
|
192
|
+
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
|
|
193
|
+
name="qwen_sdpa_attention_versatile",
|
|
194
|
+
version_selector=qwen_version_selector,
|
|
195
|
+
)
|
|
136
196
|
"""
|
|
137
197
|
|
|
138
198
|
def __init__(
|
|
139
199
|
self,
|
|
140
200
|
eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
|
|
141
201
|
shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
|
|
142
|
-
function_proto: onnx.FunctionProto,
|
|
202
|
+
function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
|
|
143
203
|
n_inputs: Optional[int] = None,
|
|
144
204
|
n_outputs: Optional[int] = None,
|
|
145
205
|
name: Optional[str] = None,
|
|
146
206
|
kwargs: Optional[Dict[str, Union[int, float]]] = None,
|
|
147
207
|
verbose: int = 0,
|
|
208
|
+
version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
|
|
209
|
+
default_opset: int = 22,
|
|
148
210
|
):
|
|
149
|
-
assert isinstance(
|
|
150
|
-
function_proto,
|
|
211
|
+
assert isinstance(function_proto, onnx.FunctionProto) or (
|
|
212
|
+
isinstance(function_proto, dict)
|
|
213
|
+
or all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
|
|
151
214
|
), f"Unexpected type {type(function_proto)} for function_proto"
|
|
152
215
|
assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
|
|
153
|
-
assert isinstance(n_outputs, int), f"not implemented yet when
|
|
216
|
+
assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
|
|
154
217
|
self.eager_fn = eager_fn
|
|
155
218
|
self.shape_fn = shape_fn
|
|
156
|
-
self.
|
|
219
|
+
self._function_proto = (
|
|
220
|
+
function_proto if isinstance(function_proto, onnx.FunctionProto) else None
|
|
221
|
+
)
|
|
222
|
+
self._function_proto_versioned = (
|
|
223
|
+
function_proto if isinstance(function_proto, dict) else {}
|
|
224
|
+
)
|
|
157
225
|
self.n_inputs = n_inputs
|
|
158
226
|
self.n_outputs = n_outputs
|
|
159
227
|
self.name = name or (
|
|
@@ -170,24 +238,73 @@ class EagerDirectReplacementWithOnnx:
|
|
|
170
238
|
)
|
|
171
239
|
sig = inspect.signature(self.eager_fn)
|
|
172
240
|
params = list(sig.parameters)
|
|
173
|
-
assert (
|
|
174
|
-
len(params) >= n_inputs
|
|
175
|
-
), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
|
|
176
|
-
assert n_inputs == len(function_proto.input), (
|
|
177
|
-
f"Input mismatch n_inputs={n_inputs} but "
|
|
178
|
-
f"function_proto.input={function_proto.input}"
|
|
179
|
-
)
|
|
180
|
-
assert n_outputs == len(function_proto.output), (
|
|
181
|
-
f"Output mismatch n_outputs={n_outputs} but "
|
|
182
|
-
f"function_proto.output={function_proto.output}"
|
|
183
|
-
)
|
|
184
|
-
assert (
|
|
185
|
-
function_proto.domain == self.domain
|
|
186
|
-
), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
|
|
187
241
|
self.args_name = [p for p in params if p not in self.kwargs]
|
|
188
242
|
self.kwargs_name = [p for p in params if p in self.kwargs]
|
|
189
243
|
self.verbose = verbose
|
|
190
244
|
self.custom_op = self._register()
|
|
245
|
+
self.version_selector = version_selector
|
|
246
|
+
self.default_opset = default_opset
|
|
247
|
+
self._check_protos(params)
|
|
248
|
+
|
|
249
|
+
def _check_protos(self, params):
|
|
250
|
+
assert (
|
|
251
|
+
len(params) >= self.n_inputs
|
|
252
|
+
), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"
|
|
253
|
+
|
|
254
|
+
# one proto
|
|
255
|
+
assert self._function_proto is None or self.n_inputs == len(
|
|
256
|
+
self._function_proto.input
|
|
257
|
+
), (
|
|
258
|
+
f"Input mismatch n_inputs={self.n_inputs} but "
|
|
259
|
+
f"function_proto.input={self._function_proto.input}"
|
|
260
|
+
)
|
|
261
|
+
assert self._function_proto is None or self.n_outputs == len(
|
|
262
|
+
self._function_proto.output
|
|
263
|
+
), (
|
|
264
|
+
f"Output mismatch n_outputs={self.n_outputs} but "
|
|
265
|
+
f"function_proto.output={self._function_proto.output}"
|
|
266
|
+
)
|
|
267
|
+
assert self._function_proto is None or (
|
|
268
|
+
self._function_proto.domain == self.domain
|
|
269
|
+
), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"
|
|
270
|
+
|
|
271
|
+
# multiple protos
|
|
272
|
+
assert all(
|
|
273
|
+
self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
|
|
274
|
+
), f"Output mismatch n_inputs={self.n_inputs} but one version is wrong"
|
|
275
|
+
assert all(
|
|
276
|
+
self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
|
|
277
|
+
), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
|
|
278
|
+
assert all(
|
|
279
|
+
v.domain == self.domain for v in self._function_proto_versioned.values()
|
|
280
|
+
), f"Function domain must be {self.domain!r} but it is different in one version"
|
|
281
|
+
assert (
|
|
282
|
+
not self._function_proto_versioned or self.version_selector
|
|
283
|
+
), "version_selector is needed when multiple protos are given."
|
|
284
|
+
|
|
285
|
+
def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
|
|
286
|
+
"""Returns the correct version based on the inputs."""
|
|
287
|
+
if self._function_proto:
|
|
288
|
+
return self._function_proto
|
|
289
|
+
assert isinstance(
|
|
290
|
+
opset, int
|
|
291
|
+
), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"
|
|
292
|
+
assert any(
|
|
293
|
+
a is not None for a in args
|
|
294
|
+
), f"Unexpected args={string_type(args, with_shape=True)}"
|
|
295
|
+
try:
|
|
296
|
+
key = self.version_selector(opset, *args) # type: ignore[misc]
|
|
297
|
+
except (ValueError, AttributeError) as e:
|
|
298
|
+
raise AssertionError(
|
|
299
|
+
f"Unable to select a version, fails to get a key, available="
|
|
300
|
+
f"{set(self._function_proto_versioned)}, "
|
|
301
|
+
f"args={string_type(args,with_shape=True)}"
|
|
302
|
+
) from e
|
|
303
|
+
assert key in self._function_proto_versioned, (
|
|
304
|
+
f"Unable to select a version, key={key}, available="
|
|
305
|
+
f"{set(self._function_proto_versioned)}, args={string_type(args,with_shape=True)}"
|
|
306
|
+
)
|
|
307
|
+
return self._function_proto_versioned[key]
|
|
191
308
|
|
|
192
309
|
@property
|
|
193
310
|
def domain(self) -> str:
|
|
@@ -219,6 +336,8 @@ class EagerDirectReplacementWithOnnx:
|
|
|
219
336
|
input_args.append(f"int {p}={val}")
|
|
220
337
|
elif isinstance(val, float):
|
|
221
338
|
input_args.append(f"float {p}={val}")
|
|
339
|
+
elif isinstance(val, str):
|
|
340
|
+
input_args.append(f"str {p}={val}")
|
|
222
341
|
else:
|
|
223
342
|
raise NotImplementedError(
|
|
224
343
|
f"kwargs {p!r} has a default value of unsupported type {type(val)}"
|
|
@@ -243,6 +362,7 @@ class EagerDirectReplacementWithOnnx:
|
|
|
243
362
|
*args,
|
|
244
363
|
engine: Optional[Callable] = None,
|
|
245
364
|
dump_onnx_model: Optional[str] = None,
|
|
365
|
+
opset: int = 22,
|
|
246
366
|
**kwargs,
|
|
247
367
|
) -> VerifyResult:
|
|
248
368
|
"""
|
|
@@ -257,6 +377,7 @@ class EagerDirectReplacementWithOnnx:
|
|
|
257
377
|
:class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
|
|
258
378
|
:param dump_onnx_model: to dump the onnx model used to verify
|
|
259
379
|
eager and onnx produce the same results
|
|
380
|
+
:param opset: onnx opset to use
|
|
260
381
|
:param kwargs: additional arguments to the function
|
|
261
382
|
:return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
|
|
262
383
|
"""
|
|
@@ -291,7 +412,7 @@ class EagerDirectReplacementWithOnnx:
|
|
|
291
412
|
assert engine is None, f"Not implemented yet with engine={engine!r}"
|
|
292
413
|
ags, kws = self._make_args_kwargs(*args, **kwargs)
|
|
293
414
|
sess = OnnxruntimeEvaluator(
|
|
294
|
-
self.
|
|
415
|
+
self.get_function_proto(opset, *args),
|
|
295
416
|
whole=True,
|
|
296
417
|
dump_onnx_model=dump_onnx_model,
|
|
297
418
|
function_kwargs=kws,
|
|
@@ -324,16 +445,25 @@ class EagerDirectReplacementWithOnnx:
|
|
|
324
445
|
*args,
|
|
325
446
|
**kwargs,
|
|
326
447
|
) -> Any:
|
|
327
|
-
if
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
448
|
+
has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
|
|
449
|
+
assert (
|
|
450
|
+
has_devices
|
|
451
|
+
), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
|
|
452
|
+
arg_device = has_devices[0]
|
|
453
|
+
fake_tensor = torch.empty(
|
|
454
|
+
tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]),
|
|
455
|
+
dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])),
|
|
456
|
+
device=int_device_to_torch_device(g.get_device(arg_device)),
|
|
457
|
+
)
|
|
458
|
+
function_proto = self.get_function_proto(g.main_opset, fake_tensor)
|
|
459
|
+
if not g.has_local_function(function_proto.name, domain=function_proto.domain):
|
|
460
|
+
g.add_function(function_proto)
|
|
331
461
|
ags, kws = self._make_args_kwargs(*args, **kwargs)
|
|
332
462
|
res = g.make_node(
|
|
333
|
-
|
|
463
|
+
function_proto.name,
|
|
334
464
|
ags,
|
|
335
465
|
outputs,
|
|
336
|
-
domain=
|
|
466
|
+
domain=function_proto.domain,
|
|
337
467
|
name=self.target_name,
|
|
338
468
|
**kws,
|
|
339
469
|
)
|
|
@@ -356,41 +486,46 @@ class EagerDirectReplacementWithOnnx:
|
|
|
356
486
|
"""
|
|
357
487
|
import onnxscript
|
|
358
488
|
|
|
359
|
-
onnx_plug_op = onnxscript.values.Opset(domain=self.
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
self.
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
489
|
+
onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
|
|
490
|
+
|
|
491
|
+
def get_proto(*args):
|
|
492
|
+
function_proto = self.get_function_proto(self.default_opset, *args)
|
|
493
|
+
schema = onnx_plug_op[function_proto.name]
|
|
494
|
+
if schema is None:
|
|
495
|
+
all_types = [
|
|
496
|
+
"tensor(float)",
|
|
497
|
+
"tensor(float16)",
|
|
498
|
+
"tensor(bfloat16)",
|
|
499
|
+
"tensor(double)",
|
|
500
|
+
"tensor(int64)",
|
|
501
|
+
"tensor(int32)",
|
|
502
|
+
]
|
|
503
|
+
type_constraints = []
|
|
504
|
+
for i in range(self.n_inputs):
|
|
505
|
+
type_constraints.append((f"T{i}", all_types, ""))
|
|
506
|
+
for i in range(self.n_outputs):
|
|
507
|
+
type_constraints.append((f"U{i}", all_types, ""))
|
|
508
|
+
schema = onnx.defs.OpSchema(
|
|
509
|
+
function_proto.name,
|
|
510
|
+
function_proto.domain,
|
|
511
|
+
1,
|
|
512
|
+
inputs=[
|
|
513
|
+
onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
|
|
514
|
+
for i in range(self.n_inputs)
|
|
515
|
+
],
|
|
516
|
+
outputs=[
|
|
517
|
+
onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
|
|
518
|
+
for i in range(self.n_outputs)
|
|
519
|
+
],
|
|
520
|
+
type_constraints=type_constraints,
|
|
521
|
+
)
|
|
522
|
+
onnx.defs.register_schema(schema)
|
|
523
|
+
op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
|
|
524
|
+
return op
|
|
391
525
|
|
|
392
526
|
def converter(*cargs, **ckwargs):
|
|
393
527
|
ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
|
|
528
|
+
op = get_proto(*cargs)
|
|
394
529
|
return op(*ags, n_outputs=self.n_outputs, **kws)
|
|
395
530
|
|
|
396
531
|
return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
|