onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,531 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
4
|
+
import onnx
|
|
5
|
+
import torch
|
|
6
|
+
from ..helpers import max_diff, string_type
|
|
7
|
+
from ..helpers.torch_helper import (
|
|
8
|
+
torch_dtype_to_onnx_dtype,
|
|
9
|
+
onnx_dtype_to_torch_dtype,
|
|
10
|
+
int_device_to_torch_device,
|
|
11
|
+
)
|
|
12
|
+
from ..reference import OnnxruntimeEvaluator
|
|
13
|
+
|
|
14
|
+
# Type alias for a variable-length tuple of tensors, used in the signatures of
# eager_fn / shape_fn and for the outputs stored in VerifyResult.
TUPLE_TENSORS = Tuple[torch.Tensor, ...]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def is_exporting() -> bool:
|
|
18
|
+
"""
|
|
19
|
+
Returns :func:`torch.compiler.is_exporting` or
|
|
20
|
+
:func:`torch.compiler.is_compiling`.
|
|
21
|
+
Changes ``_TEST_EXPORT`` to make it trigger.
|
|
22
|
+
"""
|
|
23
|
+
return torch.compiler.is_exporting() or torch.compiler.is_compiling()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class VerifyResult:
    """
    Container returned by :meth:`verify
    <onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx.verify>`.
    It holds the outputs computed in eager mode, the outputs computed
    with the onnx function and one discrepancy summary per output.
    """

    # outputs returned by eager_fn
    eager_outputs: TUPLE_TENSORS
    # outputs returned by the onnx runtime for the onnx function
    onnx_outputs: TUPLE_TENSORS
    # one dictionary per output, as produced by max_diff
    diffs: Tuple[Dict[str, float], ...]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class EagerDirectReplacementWithOnnx:
    """
    Replaces a piece of code by another one written in ONNX
    at export time. The function inserts a custom operator
    and links it to the eager_fn.

    :param eager_fn: the code it replaces, it must be given in order to be able
        to execute the torch.fx.Graph the exporter produces
    :param shape_fn: the function produces dummy outputs with the shapes
        the exporter can use for the next operators in the graph
    :param function_proto: instance of ``onnx.FunctionProto``,
        its domain must be ``onnx_plug``, or a dictionary
        ``{key: onnx.FunctionProto}`` when several versions are available,
        ``version_selector`` is then required
    :param n_inputs: number of inputs of the function, if not given,
        the class infers it from eager_fn signature (every parameter
        not listed in ``kwargs``), only tensors must be counted
    :param n_outputs: same for the number of outputs,
        only tensors must be counted, if not given, it is the number
        of outputs of the function proto (or of the first version)
    :param name: the name of the custom op, the function name if not specified
    :param kwargs: constant parameters with their default values
        (int or float, booleans are declared as ``bool`` in the schema)
    :param verbose: verbose level
    :param version_selector: selects the version based on the arguments,
        see below for an example, this allows the user to define different
        onnx versions depending on the inputs
    :param default_opset: opset to use by default

    Here is an example:

    .. runpython::
        :showcode:

        import onnx.helper as oh
        import torch
        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
        from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
        from onnx_diagnostic.export.api import to_onnx
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


        def demo_customsub(x, y):
            return x - y


        def demo_customsub_shape(x, y):
            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


        def make_function_proto():
            return oh.make_function(
                "onnx_plug",
                "demo_customsub",
                ["x", "y"],
                ["z"],
                [oh.make_node("Sub", ["x", "y"], ["z"])],
                opset_imports=[oh.make_opsetid("", 22)],
            )


        class Model(torch.nn.Module):
            def forward(self, x):
                y = x.sum(axis=1, keepdim=True)
                d = torch.ops.onnx_plug.demo_customsub(x, y)
                return torch.abs(d)


        replacements = [
            EagerDirectReplacementWithOnnx(
                demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
            )
        ]

        x = torch.randn((3, 4), dtype=torch.float32)
        model = Model()
        ds = ({0: "d1", 1: "d2"},)

        # The exported program shows a custom op.
        ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
        print(ep)

        # The exporter knows how to replace this custom op.
        # Let's export.

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="custom",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))

        # And with :func:`torch.onnx.export`:

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="onnx-dynamo",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))

    This shows how to define multiple versions depending on the device,
    the type or the targeted onnx opset.

    .. code-block:: python

        def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, int]:
            first_tensor = next(a for a in args if a is not None)
            dtype = first_tensor.dtype
            itype = torch_dtype_to_onnx_dtype(dtype)
            if dtype == torch.float32:
                if opset >= 24:
                    return "LOOPA24", itype
                return "LOOPMHA", itype
            if dtype == torch.float16:
                if first_tensor.is_cuda:
                    return "PACKED", itype
                return "LOOPMHA", itype
            raise AssertionError(
                f"Unable to handle type {dtype} (itype={itype}) "
                f"on device {first_tensor.device} with opset={opset}"
            )

        qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
            qwen_sdpa_attention,
            lambda qs, *args, **kwargs: torch.empty(
                (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
                dtype=qs.dtype,
                device=qs.device,
            ),
            {
                ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
                    PackedAttention.to_function_proto()
                ),
                ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
                ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
                    onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
                ),
                ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
                    LoopMHAAttention.to_function_proto()
                ),
                ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
                    onnx.TensorProto.FLOAT16,
                    _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
                ),
            },
            n_inputs=4,
            n_outputs=1,
            kwargs=dict(scaling=0.11180339887498948, num_heads=16),
            name="qwen_sdpa_attention_versatile",
            version_selector=qwen_version_selector,
        )
    """

    def __init__(
        self,
        eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
        n_inputs: Optional[int] = None,
        n_outputs: Optional[int] = None,
        name: Optional[str] = None,
        kwargs: Optional[Dict[str, Union[int, float]]] = None,
        verbose: int = 0,
        version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
        default_opset: int = 22,
    ):
        # Either one proto or a dictionary of protos. ``and`` is required:
        # with ``or``, a dictionary holding anything else was accepted and any
        # other type failed with an AttributeError instead of this assertion.
        assert isinstance(function_proto, onnx.FunctionProto) or (
            isinstance(function_proto, dict)
            and all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
        ), f"Unexpected type {type(function_proto)} for function_proto"
        self.eager_fn = eager_fn
        self.shape_fn = shape_fn
        self._function_proto = (
            function_proto if isinstance(function_proto, onnx.FunctionProto) else None
        )
        self._function_proto_versioned = (
            function_proto if isinstance(function_proto, dict) else {}
        )
        # Nested functions and lambdas have names containing '<' such as
        # ``f.<locals>.<lambda>`` which cannot be used as an op name.
        self.name = name or (
            eager_fn.__name__
            if "<" not in eager_fn.__name__
            else eager_fn.__qualname__.replace("<locals>", "L")
            .replace("<lambda>", "l")
            .replace(".", "_")
        )
        self.kwargs = kwargs or {}
        assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), (
            f"Only int or floats are allowed for kwargs={kwargs}, one of them "
            f"does not respect that constraint."
        )
        params = list(inspect.signature(self.eager_fn).parameters)
        # Every parameter not declared in kwargs becomes a tensor input.
        self.args_name = [p for p in params if p not in self.kwargs]
        self.kwargs_name = [p for p in params if p in self.kwargs]
        # Infers the counts when not given, as documented.
        self.n_inputs = len(self.args_name) if n_inputs is None else n_inputs
        self.n_outputs = self._infer_n_outputs() if n_outputs is None else n_outputs
        assert isinstance(
            self.n_inputs, int
        ), f"n_inputs must be an integer not {self.n_inputs!r}"
        assert isinstance(
            self.n_outputs, int
        ), f"n_outputs must be an integer not {self.n_outputs!r}"
        self.verbose = verbose
        self.version_selector = version_selector
        self.default_opset = default_opset
        # Validates before registering so that an invalid replacement
        # does not leave a custom op registered in torch.
        self._check_protos(params)
        self.custom_op = self._register()

    def _infer_n_outputs(self) -> int:
        """
        Returns the number of outputs of the function proto,
        or of the first version when several versions are given.
        """
        if self._function_proto is not None:
            return len(self._function_proto.output)
        assert (
            self._function_proto_versioned
        ), "n_outputs cannot be inferred without any function_proto"
        return len(next(iter(self._function_proto_versioned.values())).output)

    def _check_protos(self, params):
        """
        Checks the function protos are consistent with the number of inputs,
        outputs, the expected domain and that a version selector is given
        when several versions are available.

        :param params: parameter names of eager_fn
        """
        assert (
            len(params) >= self.n_inputs
        ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"

        # one proto
        assert self._function_proto is None or self.n_inputs == len(
            self._function_proto.input
        ), (
            f"Input mismatch n_inputs={self.n_inputs} but "
            f"function_proto.input={self._function_proto.input}"
        )
        assert self._function_proto is None or self.n_outputs == len(
            self._function_proto.output
        ), (
            f"Output mismatch n_outputs={self.n_outputs} but "
            f"function_proto.output={self._function_proto.output}"
        )
        assert self._function_proto is None or (
            self._function_proto.domain == self.domain
        ), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"

        # multiple protos
        assert all(
            self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
        ), f"Input mismatch n_inputs={self.n_inputs} but one version is wrong"
        assert all(
            self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
        ), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
        assert all(
            v.domain == self.domain for v in self._function_proto_versioned.values()
        ), f"Function domain must be {self.domain!r} but it is different in one version"
        assert (
            not self._function_proto_versioned or self.version_selector
        ), "version_selector is needed when multiple protos are given."

    def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
        """
        Returns the correct version based on the inputs.

        :param opset: main onnx opset, given to ``version_selector``
        :param args: inputs, given to ``version_selector``
        :return: the only function proto or the one selected
            by ``version_selector``
        """
        if self._function_proto is not None:
            return self._function_proto
        assert isinstance(
            opset, int
        ), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"
        assert any(
            a is not None for a in args
        ), f"Unexpected args={string_type(args, with_shape=True)}"
        try:
            key = self.version_selector(opset, *args)  # type: ignore[misc]
        except (ValueError, AttributeError) as e:
            raise AssertionError(
                f"Unable to select a version, fails to get a key, available="
                f"{set(self._function_proto_versioned)}, "
                f"args={string_type(args,with_shape=True)}"
            ) from e
        assert key in self._function_proto_versioned, (
            f"Unable to select a version, key={key}, available="
            f"{set(self._function_proto_versioned)}, args={string_type(args,with_shape=True)}"
        )
        return self._function_proto_versioned[key]

    @property
    def domain(self) -> str:
        "Returns the onnx domain."
        return "onnx_plug"

    @property
    def target_name(self) -> str:
        "Returns the target name (see in the exported program)."
        return f"{self.domain}::{self.name}"

    @property
    def torch_op(self) -> Callable:
        "Returns ``torch.ops.onnx_plug.<name>``."
        return getattr(getattr(torch.ops, self.domain), self.name).default

    def __call__(self, *args, **kwargs):
        """
        Calls the registered custom op if the model is being exported,
        eager_fn otherwise.
        """
        if is_exporting():
            # Keyword arguments must be forwarded, otherwise the custom op
            # silently falls back to the default values of the schema.
            return self.torch_op(*args, **kwargs)
        return self.eager_fn(*args, **kwargs)

    def _register(self):
        """
        Registers the custom op ``torch.ops.<domain>.<name>``: eager_fn is the
        kernel, shape_fn the abstract implementation.

        :return: the :class:`torch.library.CustomOpDef`
        """
        input_args = [f"Tensor {p}" for p in self.args_name]
        for p in self.kwargs_name:
            val = self.kwargs[p]
            # bool must be checked before int, it is a subclass of int and
            # ``int p=True`` is not a valid schema.
            if isinstance(val, bool):
                input_args.append(f"bool {p}={val}")
            elif isinstance(val, int):
                input_args.append(f"int {p}={val}")
            elif isinstance(val, float):
                input_args.append(f"float {p}={val}")
            elif isinstance(val, str):
                input_args.append(f'str {p}="{val}"')
            else:
                raise NotImplementedError(
                    f"kwargs {p!r} has a default value of unsupported type {type(val)}"
                )

        inputs = ", ".join(input_args)
        schema = f"({inputs}) -> Tensor"
        if self.n_outputs > 1:
            schema += "[]"
        if self.verbose:
            print(
                f"[EagerDirectReplacementWithOnnx._register] "
                f"'torch.ops.{self.domain}.{self.name}'"
            )
            print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}")
        custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
        custom_def.register_kernel(None)(self.eager_fn)
        custom_def._abstract_fn = self.shape_fn
        return custom_def

    def verify(
        self,
        *args,
        engine: Optional[Callable] = None,
        dump_onnx_model: Optional[str] = None,
        opset: int = 22,
        **kwargs,
    ) -> VerifyResult:
        """
        Verifies that the eager mode is equivalent to the onnx function given
        as a replacement. This function evaluates `eager_fn`, checks that the shapes
        are equivalent to the ones given by `shape_fn`, and finally evaluates the
        onnx translation if the previous did not fail.

        :param args: function inputs
        :param engine: by default an instance of
            :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`,
            no other engine is supported yet
        :param dump_onnx_model: to dump the onnx model used to verify
            eager and onnx produce the same results
        :param opset: onnx opset to use
        :param kwargs: additional arguments for eager_fn
        :return: a :class:`VerifyResult` with the eager outputs, the onnx outputs
            and the outputs of :func:`onnx_diagnostic.helpers.max_diff` for each output
        """
        expected = self.eager_fn(*args, **kwargs)
        shapes = self.shape_fn(*args, **kwargs)
        if isinstance(expected, torch.Tensor):
            expected = (expected,)
            assert isinstance(shapes, torch.Tensor), (
                f"eager_fn={self.eager_fn} returns a Tensor but shape_fn={self.shape_fn} "
                f"returns a {type(shapes)}"
            )
            shapes = (shapes,)
        assert isinstance(expected, tuple) and isinstance(shapes, tuple), (
            f"eager_fn={self.eager_fn} returns a {type(expected)} "
            f"and shape_fn={self.shape_fn} returns a {type(shapes)}"
        )
        # Both functions must return the same, non-zero, number of tensors,
        # otherwise zip below would silently ignore the extra outputs.
        assert expected and len(expected) == len(shapes), (
            f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn} "
            f"do not return the same number of tensors "
            f"({len(expected)} != {len(shapes)})."
        )
        for i, (e, s) in enumerate(zip(expected, shapes)):
            assert e.dtype == s.dtype, (
                f"Type mismatch {e.dtype} != {s.dtype} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )
            assert e.shape == s.shape, (
                f"Shape mismatch {e.shape} != {s.shape} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )

        # Now the ONNX execution.
        assert engine is None, f"Not implemented yet with engine={engine!r}"
        ags, kws = self._make_args_kwargs(*args, **kwargs)
        sess = OnnxruntimeEvaluator(
            self.get_function_proto(opset, *args),
            whole=True,
            dump_onnx_model=dump_onnx_model,
            function_kwargs=kws,
        )
        feeds = dict(zip(sess.input_names, ags))
        got = sess.run(None, feeds)
        diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
        return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs)  # type: ignore[arg-type]

    def _make_args_kwargs(self, *args, **kwargs):
        """
        Splits the positional arguments into tensor inputs (the first
        ``len(self.args_name)``) and constant parameters, the remaining
        positional arguments are named after ``self.kwargs_name`` and
        merged with ``kwargs``.
        """
        ags = args[: len(self.args_name)]
        kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
        kws.update(kwargs)
        return ags, kws

    @staticmethod
    def _make_fake_tensor(g: Any, name: str, device: Any) -> torch.Tensor:
        """
        Creates an empty tensor with the type and shape of result ``name``
        in graph builder ``g``, every non-integer dimension is replaced by 2.
        """
        return torch.empty(
            tuple((d if isinstance(d, int) else 2) for d in g.get_shape(name)),
            dtype=onnx_dtype_to_torch_dtype(g.get_type(name)),
            device=device,
        )

    def custom_converter(
        self,
    ) -> Callable:
        """
        Returns a function which
        converts a custom op found in the fx graph into ONNX
        following the API of the custom exporter.
        The converter adds a custom op and registers the local function.
        """

        def converter(
            g: Any,  # GraphBuilder
            sts: Optional[Dict[str, Any]],
            outputs: List[str],
            *args,
            **kwargs,
        ) -> Any:
            has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
            assert (
                has_devices
            ), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
            device = int_device_to_torch_device(g.get_device(has_devices[0]))
            fake_tensor = self._make_fake_tensor(g, args[0], device)
            function_proto = self.get_function_proto(g.main_opset, fake_tensor)
            if not g.has_local_function(function_proto.name, domain=function_proto.domain):
                g.add_function(function_proto)
            ags, kws = self._make_args_kwargs(*args, **kwargs)
            res = g.make_node(
                function_proto.name,
                ags,
                outputs,
                domain=function_proto.domain,
                name=self.target_name,
                **kws,
            )
            if not sts:
                # args are result names, shape_fn expects tensors:
                # every named input is replaced by a fake tensor.
                tensor_names = [a for a in args if isinstance(a, str) and a]
                fake_args = [
                    self._make_fake_tensor(g, a, device) if isinstance(a, str) and a else a
                    for a in args
                ]
                new_shapes = self.shape_fn(*fake_args, **kwargs)
                if not isinstance(new_shapes, tuple):
                    new_shapes = (new_shapes,)
                # The fake tensors replace dynamic dimensions by 2, the output
                # shapes are only meaningful when every input shape is static.
                all_static = all(
                    all(isinstance(d, int) for d in g.get_shape(a)) for a in tensor_names
                )
                for sh, o in zip(new_shapes, outputs):
                    g.set_type(o, torch_dtype_to_onnx_dtype(sh.dtype))
                    if all_static:
                        g.set_shape(o, sh.shape)
            return res

        return converter

    def onnx_dynamo_converter(self) -> Callable:
        """
        Returns a function which
        converts a custom op found in the fx graph into ONNX
        following the API of :func:`torch.onnx.export`.
        """
        import onnxscript

        onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)

        def get_proto(*args):
            # Selects the function proto and registers an onnx schema
            # for it the first time it is seen.
            function_proto = self.get_function_proto(self.default_opset, *args)
            schema = onnx_plug_op[function_proto.name]
            if schema is None:
                all_types = [
                    "tensor(float)",
                    "tensor(float16)",
                    "tensor(bfloat16)",
                    "tensor(double)",
                    "tensor(int64)",
                    "tensor(int32)",
                ]
                type_constraints = []
                for i in range(self.n_inputs):
                    type_constraints.append((f"T{i}", all_types, ""))
                for i in range(self.n_outputs):
                    type_constraints.append((f"U{i}", all_types, ""))
                schema = onnx.defs.OpSchema(
                    function_proto.name,
                    function_proto.domain,
                    1,
                    inputs=[
                        onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
                        for i in range(self.n_inputs)
                    ],
                    outputs=[
                        onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
                        for i in range(self.n_outputs)
                    ],
                    type_constraints=type_constraints,
                )
                onnx.defs.register_schema(schema)
            op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
            return op

        def converter(*cargs, **ckwargs):
            ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
            op = get_proto(*cargs)
            return op(*ags, n_outputs=self.n_outputs, **kws)

        return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
|