onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +47 -10
- onnx_diagnostic/export/api.py +81 -50
- onnx_diagnostic/export/control_flow_research.py +10 -5
- onnx_diagnostic/export/onnx_plug.py +250 -61
- onnx_diagnostic/ext_test_case.py +99 -53
- onnx_diagnostic/helpers/dot_helper.py +37 -25
- onnx_diagnostic/helpers/helper.py +44 -38
- onnx_diagnostic/helpers/onnx_helper.py +441 -18
- onnx_diagnostic/helpers/ort_session.py +8 -8
- onnx_diagnostic/helpers/torch_helper.py +28 -2
- onnx_diagnostic/reference/ort_evaluator.py +6 -29
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +14 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +11 -5
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
|
@@ -3,8 +3,12 @@ from dataclasses import dataclass
|
|
|
3
3
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
4
4
|
import onnx
|
|
5
5
|
import torch
|
|
6
|
-
from ..helpers import max_diff
|
|
7
|
-
from ..helpers.torch_helper import
|
|
6
|
+
from ..helpers import max_diff, string_type
|
|
7
|
+
from ..helpers.torch_helper import (
|
|
8
|
+
torch_dtype_to_onnx_dtype,
|
|
9
|
+
onnx_dtype_to_torch_dtype,
|
|
10
|
+
int_device_to_torch_device,
|
|
11
|
+
)
|
|
8
12
|
from ..reference import OnnxruntimeEvaluator
|
|
9
13
|
|
|
10
14
|
TUPLE_TENSORS = Tuple[torch.Tensor, ...]
|
|
@@ -50,6 +54,10 @@ class EagerDirectReplacementWithOnnx:
|
|
|
50
54
|
only tensors must be counted
|
|
51
55
|
:param name: the name of the custom op, the function name if not specified
|
|
52
56
|
:param kwargs: constants parameters with their default values
|
|
57
|
+
:param version_selector: selects the version based on the arguments,
|
|
58
|
+
see below for an example, this allows the user to define different
|
|
59
|
+
onnx versions depending on the inputs
|
|
60
|
+
:param default_opset: opset to use by default
|
|
53
61
|
:param verbose: verbose level
|
|
54
62
|
|
|
55
63
|
Here is an example:
|
|
@@ -120,7 +128,61 @@ class EagerDirectReplacementWithOnnx:
|
|
|
120
128
|
|
|
121
129
|
print(pretty_onnx(onx))
|
|
122
130
|
|
|
123
|
-
|
|
131
|
+
We do the same with :func:`torch.onnx.export`:
|
|
132
|
+
|
|
133
|
+
.. runpython::
|
|
134
|
+
:showcode:
|
|
135
|
+
|
|
136
|
+
import onnx.helper as oh
|
|
137
|
+
import torch
|
|
138
|
+
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
|
|
139
|
+
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
|
|
140
|
+
from onnx_diagnostic.export.api import to_onnx
|
|
141
|
+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def demo_customsub(x, y):
|
|
145
|
+
return x - y
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def demo_customsub_shape(x, y):
|
|
149
|
+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def make_function_proto():
|
|
153
|
+
return oh.make_function(
|
|
154
|
+
"onnx_plug",
|
|
155
|
+
"demo_customsub",
|
|
156
|
+
["x", "y"],
|
|
157
|
+
["z"],
|
|
158
|
+
[oh.make_node("Sub", ["x", "y"], ["z"])],
|
|
159
|
+
opset_imports=[oh.make_opsetid("", 22)],
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class Model(torch.nn.Module):
|
|
164
|
+
def forward(self, x):
|
|
165
|
+
y = x.sum(axis=1, keepdim=True)
|
|
166
|
+
d = torch.ops.onnx_plug.demo_customsub(x, y)
|
|
167
|
+
return torch.abs(d)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
replacements = [
|
|
171
|
+
EagerDirectReplacementWithOnnx(
|
|
172
|
+
demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
|
|
173
|
+
)
|
|
174
|
+
]
|
|
175
|
+
|
|
176
|
+
x = torch.randn((3, 4), dtype=torch.float32)
|
|
177
|
+
model = Model()
|
|
178
|
+
ds = ({0: "d1", 1: "d2"},)
|
|
179
|
+
|
|
180
|
+
# The exported program shows a custom op.
|
|
181
|
+
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
|
|
182
|
+
print("ep")
|
|
183
|
+
|
|
184
|
+
# As the exporter knows how to replace this custom op.
|
|
185
|
+
# Let's export.
|
|
124
186
|
|
|
125
187
|
onx = to_onnx(
|
|
126
188
|
model,
|
|
@@ -133,27 +195,87 @@ class EagerDirectReplacementWithOnnx:
|
|
|
133
195
|
).model_proto
|
|
134
196
|
|
|
135
197
|
print(pretty_onnx(onx))
|
|
198
|
+
|
|
199
|
+
This shows how to define multiple versions depending on the device,
|
|
200
|
+
the type or the targeted onnx opset.
|
|
201
|
+
|
|
202
|
+
.. code-block:: python
|
|
203
|
+
|
|
204
|
+
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
|
|
205
|
+
first_tensor = next(a for a in args if a is not None)
|
|
206
|
+
dtype = first_tensor.dtype
|
|
207
|
+
itype = torch_dtype_to_onnx_dtype(dtype)
|
|
208
|
+
if dtype == torch.float32:
|
|
209
|
+
if opset >= 23:
|
|
210
|
+
return "LOOPA23", itype
|
|
211
|
+
return "LOOPMHA", itype
|
|
212
|
+
if dtype == torch.float16:
|
|
213
|
+
if first_tensor.is_cuda:
|
|
214
|
+
return "PACKED", itype
|
|
215
|
+
return "LOOPMHA", itype
|
|
216
|
+
raise AssertionError(
|
|
217
|
+
f"Unable to handle type {dtype} (itype={itype}) "
|
|
218
|
+
f"on device {first_tensor.device} with opset={opset}"
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
|
|
222
|
+
qwen_sdpa_attention,
|
|
223
|
+
lambda qs, *args, **kwargs: torch.empty(
|
|
224
|
+
(qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
|
|
225
|
+
dtype=qs.dtype,
|
|
226
|
+
device=qs.device,
|
|
227
|
+
),
|
|
228
|
+
{
|
|
229
|
+
("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
|
|
230
|
+
PackedAttention.to_function_proto()
|
|
231
|
+
),
|
|
232
|
+
("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
|
|
233
|
+
("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
|
|
234
|
+
onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
|
|
235
|
+
),
|
|
236
|
+
("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
|
|
237
|
+
LoopMHAAttention.to_function_proto()
|
|
238
|
+
),
|
|
239
|
+
("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
|
|
240
|
+
onnx.TensorProto.FLOAT16,
|
|
241
|
+
_add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
|
|
242
|
+
),
|
|
243
|
+
},
|
|
244
|
+
n_inputs=4,
|
|
245
|
+
n_outputs=1,
|
|
246
|
+
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
|
|
247
|
+
name="qwen_sdpa_attention_versatile",
|
|
248
|
+
version_selector=qwen_version_selector,
|
|
249
|
+
)
|
|
136
250
|
"""
|
|
137
251
|
|
|
138
252
|
def __init__(
|
|
139
253
|
self,
|
|
140
254
|
eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
|
|
141
255
|
shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
|
|
142
|
-
function_proto: onnx.FunctionProto,
|
|
256
|
+
function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
|
|
143
257
|
n_inputs: Optional[int] = None,
|
|
144
258
|
n_outputs: Optional[int] = None,
|
|
145
259
|
name: Optional[str] = None,
|
|
146
260
|
kwargs: Optional[Dict[str, Union[int, float]]] = None,
|
|
147
261
|
verbose: int = 0,
|
|
262
|
+
version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
|
|
263
|
+
default_opset: int = 22,
|
|
148
264
|
):
|
|
149
|
-
assert isinstance(
|
|
150
|
-
function_proto,
|
|
265
|
+
assert isinstance(function_proto, onnx.FunctionProto) or (
|
|
266
|
+
isinstance(function_proto, dict)
|
|
267
|
+
or all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
|
|
151
268
|
), f"Unexpected type {type(function_proto)} for function_proto"
|
|
152
269
|
assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
|
|
153
|
-
assert isinstance(n_outputs, int), f"not implemented yet when
|
|
270
|
+
assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
|
|
154
271
|
self.eager_fn = eager_fn
|
|
155
272
|
self.shape_fn = shape_fn
|
|
156
|
-
self.
|
|
273
|
+
self._function_proto = (
|
|
274
|
+
function_proto if isinstance(function_proto, onnx.FunctionProto) else None
|
|
275
|
+
)
|
|
276
|
+
self._function_proto_versioned = (
|
|
277
|
+
function_proto if isinstance(function_proto, dict) else {}
|
|
278
|
+
)
|
|
157
279
|
self.n_inputs = n_inputs
|
|
158
280
|
self.n_outputs = n_outputs
|
|
159
281
|
self.name = name or (
|
|
@@ -170,24 +292,73 @@ class EagerDirectReplacementWithOnnx:
|
|
|
170
292
|
)
|
|
171
293
|
sig = inspect.signature(self.eager_fn)
|
|
172
294
|
params = list(sig.parameters)
|
|
173
|
-
assert (
|
|
174
|
-
len(params) >= n_inputs
|
|
175
|
-
), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
|
|
176
|
-
assert n_inputs == len(function_proto.input), (
|
|
177
|
-
f"Input mismatch n_inputs={n_inputs} but "
|
|
178
|
-
f"function_proto.input={function_proto.input}"
|
|
179
|
-
)
|
|
180
|
-
assert n_outputs == len(function_proto.output), (
|
|
181
|
-
f"Output mismatch n_outputs={n_outputs} but "
|
|
182
|
-
f"function_proto.output={function_proto.output}"
|
|
183
|
-
)
|
|
184
|
-
assert (
|
|
185
|
-
function_proto.domain == self.domain
|
|
186
|
-
), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
|
|
187
295
|
self.args_name = [p for p in params if p not in self.kwargs]
|
|
188
296
|
self.kwargs_name = [p for p in params if p in self.kwargs]
|
|
189
297
|
self.verbose = verbose
|
|
190
298
|
self.custom_op = self._register()
|
|
299
|
+
self.version_selector = version_selector
|
|
300
|
+
self.default_opset = default_opset
|
|
301
|
+
self._check_protos(params)
|
|
302
|
+
|
|
303
|
+
def _check_protos(self, params):
    """
    Checks the function protos are consistent with the declared interface.

    :param params: names of the parameters of ``self.eager_fn``,
        there must be at least ``self.n_inputs`` of them

    The method raises :class:`AssertionError` if ``params`` is too short,
    if the single proto or any of the versioned protos does not have
    ``self.n_inputs`` inputs, ``self.n_outputs`` outputs or a domain equal
    to ``self.domain``, or if versioned protos are given without
    a ``version_selector``.
    """
    assert (
        len(params) >= self.n_inputs
    ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"

    # one proto: every check is skipped when no single proto was given,
    # the messages are only evaluated when the check fails (proto is not None then)
    assert self._function_proto is None or self.n_inputs == len(
        self._function_proto.input
    ), (
        f"Input mismatch n_inputs={self.n_inputs} but "
        f"function_proto.input={self._function_proto.input}"
    )
    assert self._function_proto is None or self.n_outputs == len(
        self._function_proto.output
    ), (
        f"Output mismatch n_outputs={self.n_outputs} but "
        f"function_proto.output={self._function_proto.output}"
    )
    assert self._function_proto is None or (
        self._function_proto.domain == self.domain
    ), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"

    # multiple protos: the checks are trivially true when the dictionary is empty
    assert all(
        self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
    ), f"Input mismatch n_inputs={self.n_inputs} but one version is wrong"
    assert all(
        self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
    ), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
    assert all(
        v.domain == self.domain for v in self._function_proto_versioned.values()
    ), f"Function domain must be {self.domain!r} but it is different in one version"
    assert (
        not self._function_proto_versioned or self.version_selector
    ), "version_selector is needed when multiple protos are given."
|
|
338
|
+
|
|
339
|
+
def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
    """
    Returns the function proto to use for an opset and a set of inputs.

    :param opset: onnx opset, forwarded to ``self.version_selector``
    :param args: inputs forwarded to ``self.version_selector``,
        at least one of them must not be None
    :return: ``self._function_proto`` when it is set, otherwise the proto
        stored in ``self._function_proto_versioned`` under the key returned
        by ``self.version_selector``

    An :class:`AssertionError` is raised if *opset* is not an integer,
    if every input is None, if the selector fails with a ValueError or
    an AttributeError, or if the returned key is unknown.
    """
    single = self._function_proto
    if single:
        # a unique proto does not depend on the opset or the inputs
        return single

    assert isinstance(
        opset, int
    ), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"

    not_none = [a for a in args if a is not None]
    assert len(not_none) > 0, f"Unexpected args={string_type(args, with_shape=True)}"

    selector = self.version_selector
    try:
        key = selector(opset, *args)  # type: ignore[misc]
    except (ValueError, AttributeError) as e:
        available = set(self._function_proto_versioned)
        raise AssertionError(
            f"Unable to select a version, fails to get a key, available="
            f"{available}, "
            f"args={string_type(args,with_shape=True)}"
        ) from e

    versions = self._function_proto_versioned
    assert key in versions, (
        f"Unable to select a version, key={key}, available="
        f"{set(versions)}, args={string_type(args,with_shape=True)}"
    )
    return versions[key]
|
|
191
362
|
|
|
192
363
|
@property
|
|
193
364
|
def domain(self) -> str:
|
|
@@ -219,6 +390,8 @@ class EagerDirectReplacementWithOnnx:
|
|
|
219
390
|
input_args.append(f"int {p}={val}")
|
|
220
391
|
elif isinstance(val, float):
|
|
221
392
|
input_args.append(f"float {p}={val}")
|
|
393
|
+
elif isinstance(val, str):
|
|
394
|
+
input_args.append(f"str {p}={val}")
|
|
222
395
|
else:
|
|
223
396
|
raise NotImplementedError(
|
|
224
397
|
f"kwargs {p!r} has a default value of unsupported type {type(val)}"
|
|
@@ -243,6 +416,7 @@ class EagerDirectReplacementWithOnnx:
|
|
|
243
416
|
*args,
|
|
244
417
|
engine: Optional[Callable] = None,
|
|
245
418
|
dump_onnx_model: Optional[str] = None,
|
|
419
|
+
opset: int = 22,
|
|
246
420
|
**kwargs,
|
|
247
421
|
) -> VerifyResult:
|
|
248
422
|
"""
|
|
@@ -257,6 +431,7 @@ class EagerDirectReplacementWithOnnx:
|
|
|
257
431
|
:class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
|
|
258
432
|
:param dump_onnx_model: to dump the onnx model used to verify
|
|
259
433
|
eager and onnx produce the same results
|
|
434
|
+
:param opset: onnx opset to use
|
|
260
435
|
:param kwargs: additional arguments to the function
|
|
261
436
|
:return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
|
|
262
437
|
"""
|
|
@@ -291,7 +466,7 @@ class EagerDirectReplacementWithOnnx:
|
|
|
291
466
|
assert engine is None, f"Not implemented yet with engine={engine!r}"
|
|
292
467
|
ags, kws = self._make_args_kwargs(*args, **kwargs)
|
|
293
468
|
sess = OnnxruntimeEvaluator(
|
|
294
|
-
self.
|
|
469
|
+
self.get_function_proto(opset, *args),
|
|
295
470
|
whole=True,
|
|
296
471
|
dump_onnx_model=dump_onnx_model,
|
|
297
472
|
function_kwargs=kws,
|
|
@@ -324,16 +499,25 @@ class EagerDirectReplacementWithOnnx:
|
|
|
324
499
|
*args,
|
|
325
500
|
**kwargs,
|
|
326
501
|
) -> Any:
|
|
327
|
-
if
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
502
|
+
has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
|
|
503
|
+
assert (
|
|
504
|
+
has_devices
|
|
505
|
+
), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
|
|
506
|
+
arg_device = has_devices[0]
|
|
507
|
+
fake_tensor = torch.empty(
|
|
508
|
+
tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]),
|
|
509
|
+
dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])),
|
|
510
|
+
device=int_device_to_torch_device(g.get_device(arg_device)),
|
|
511
|
+
)
|
|
512
|
+
function_proto = self.get_function_proto(g.main_opset, fake_tensor)
|
|
513
|
+
if not g.has_local_function(function_proto.name, domain=function_proto.domain):
|
|
514
|
+
g.add_function(function_proto)
|
|
331
515
|
ags, kws = self._make_args_kwargs(*args, **kwargs)
|
|
332
516
|
res = g.make_node(
|
|
333
|
-
|
|
517
|
+
function_proto.name,
|
|
334
518
|
ags,
|
|
335
519
|
outputs,
|
|
336
|
-
domain=
|
|
520
|
+
domain=function_proto.domain,
|
|
337
521
|
name=self.target_name,
|
|
338
522
|
**kws,
|
|
339
523
|
)
|
|
@@ -356,41 +540,46 @@ class EagerDirectReplacementWithOnnx:
|
|
|
356
540
|
"""
|
|
357
541
|
import onnxscript
|
|
358
542
|
|
|
359
|
-
onnx_plug_op = onnxscript.values.Opset(domain=self.
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
self.
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
543
|
+
onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
|
|
544
|
+
|
|
545
|
+
def get_proto(*args):
|
|
546
|
+
function_proto = self.get_function_proto(self.default_opset, *args)
|
|
547
|
+
schema = onnx_plug_op[function_proto.name]
|
|
548
|
+
if schema is None:
|
|
549
|
+
all_types = [
|
|
550
|
+
"tensor(float)",
|
|
551
|
+
"tensor(float16)",
|
|
552
|
+
"tensor(bfloat16)",
|
|
553
|
+
"tensor(double)",
|
|
554
|
+
"tensor(int64)",
|
|
555
|
+
"tensor(int32)",
|
|
556
|
+
]
|
|
557
|
+
type_constraints = []
|
|
558
|
+
for i in range(self.n_inputs):
|
|
559
|
+
type_constraints.append((f"T{i}", all_types, ""))
|
|
560
|
+
for i in range(self.n_outputs):
|
|
561
|
+
type_constraints.append((f"U{i}", all_types, ""))
|
|
562
|
+
schema = onnx.defs.OpSchema(
|
|
563
|
+
function_proto.name,
|
|
564
|
+
function_proto.domain,
|
|
565
|
+
1,
|
|
566
|
+
inputs=[
|
|
567
|
+
onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
|
|
568
|
+
for i in range(self.n_inputs)
|
|
569
|
+
],
|
|
570
|
+
outputs=[
|
|
571
|
+
onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
|
|
572
|
+
for i in range(self.n_outputs)
|
|
573
|
+
],
|
|
574
|
+
type_constraints=type_constraints,
|
|
575
|
+
)
|
|
576
|
+
onnx.defs.register_schema(schema)
|
|
577
|
+
op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
|
|
578
|
+
return op
|
|
391
579
|
|
|
392
580
|
def converter(*cargs, **ckwargs):
|
|
393
581
|
ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
|
|
582
|
+
op = get_proto(*cargs)
|
|
394
583
|
return op(*ags, n_outputs=self.n_outputs, **kws)
|
|
395
584
|
|
|
396
585
|
return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -610,6 +610,21 @@ def requires_onnxruntime(version: str, msg: str = "") -> Callable:
|
|
|
610
610
|
return lambda x: x
|
|
611
611
|
|
|
612
612
|
|
|
613
|
+
def has_onnxruntime(version: str, msg: str = "") -> bool:
    """
    Tells if :epkg:`onnxruntime` is recent enough.

    :param version: minimal version to have
    :param msg: not used to compute the result, kept so that existing
        calls passing a message continue to work
    :return: True if the installed version is greater than or equal to
        *version*, or if onnxruntime does not expose ``__version__``
        (considered as a development version), False otherwise
    """
    import packaging.version as pv
    import onnxruntime

    if not hasattr(onnxruntime, "__version__"):
        # development version
        return True

    return pv.Version(onnxruntime.__version__) >= pv.Version(version)
|
|
626
|
+
|
|
627
|
+
|
|
613
628
|
def has_onnxruntime_training(push_back_batch: bool = False):
|
|
614
629
|
"""Tells if onnxruntime_training is installed."""
|
|
615
630
|
try:
|
|
@@ -830,6 +845,13 @@ class ExtTestCase(unittest.TestCase):
|
|
|
830
845
|
f.write(proto.SerializeToString())
|
|
831
846
|
return fullname
|
|
832
847
|
|
|
848
|
+
def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str:
    """
    Dumps text in a file.

    :param name: file name, given to ``self.get_dump_file``
    :param text: text to write
    :param folder: forwarded to ``self.get_dump_file``
    :return: the name of the written file, as returned by ``self.get_dump_file``
    """
    fullname = self.get_dump_file(name, folder=folder)
    # explicit encoding, the default one depends on the platform
    # and may fail on characters outside of it
    with open(fullname, "w", encoding="utf-8") as f:
        f.write(text)
    return fullname
|
|
854
|
+
|
|
833
855
|
def assertExists(self, name):
|
|
834
856
|
"""Checks the existence of a file."""
|
|
835
857
|
if not os.path.exists(name):
|
|
@@ -1196,9 +1218,9 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1196
1218
|
def assert_onnx_disc(
|
|
1197
1219
|
self,
|
|
1198
1220
|
test_name: str,
|
|
1199
|
-
proto: "onnx.ModelProto", # noqa: F821
|
|
1221
|
+
proto: Union[str, "onnx.ModelProto"], # noqa: F821
|
|
1200
1222
|
model: "torch.nn.Module", # noqa: F821
|
|
1201
|
-
inputs: Union[Tuple[Any], Dict[str, Any]],
|
|
1223
|
+
inputs: Union[Tuple[Any], Dict[str, Any], List[Any]],
|
|
1202
1224
|
verbose: int = 0,
|
|
1203
1225
|
atol: float = 1e-5,
|
|
1204
1226
|
rtol: float = 1e-3,
|
|
@@ -1242,7 +1264,9 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1242
1264
|
name = f"{test_name}.onnx"
|
|
1243
1265
|
if verbose:
|
|
1244
1266
|
print(f"[{vname}] save the onnx model into {name!r}")
|
|
1267
|
+
model_file = None
|
|
1245
1268
|
if isinstance(proto, str):
|
|
1269
|
+
model_file = proto
|
|
1246
1270
|
name = proto
|
|
1247
1271
|
proto = onnx.load(name)
|
|
1248
1272
|
elif not self.unit_test_going():
|
|
@@ -1255,45 +1279,64 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1255
1279
|
if verbose:
|
|
1256
1280
|
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
|
|
1257
1281
|
|
|
1282
|
+
if not isinstance(inputs, list):
|
|
1283
|
+
inputs = [inputs]
|
|
1284
|
+
if expected is not None:
|
|
1285
|
+
expected = [expected]
|
|
1286
|
+
|
|
1287
|
+
gots = []
|
|
1258
1288
|
if use_ort:
|
|
1259
1289
|
assert isinstance(
|
|
1260
1290
|
proto, onnx.ModelProto
|
|
1261
1291
|
), f"Unexpected type {type(proto)} for proto"
|
|
1262
|
-
feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
|
|
1263
1292
|
import onnxruntime
|
|
1264
1293
|
|
|
1265
1294
|
options = onnxruntime.SessionOptions()
|
|
1266
1295
|
if ort_optimized_graph:
|
|
1267
1296
|
options.optimized_model_filepath = f"{name}.optort.onnx"
|
|
1297
|
+
if "log_severity_level" in kwargs:
|
|
1298
|
+
options.log_severity_level = kwargs["log_severity_level"]
|
|
1299
|
+
if "log_verbosity_level" in kwargs:
|
|
1300
|
+
options.log_verbosity_level = kwargs["log_verbosity_level"]
|
|
1268
1301
|
providers = kwargs.get("providers", ["CPUExecutionProvider"])
|
|
1269
1302
|
if verbose:
|
|
1270
1303
|
print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
|
|
1271
1304
|
sess = onnxruntime.InferenceSession(
|
|
1272
|
-
proto.SerializeToString(), options, providers=providers
|
|
1305
|
+
model_file or proto.SerializeToString(), options, providers=providers
|
|
1273
1306
|
)
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1307
|
+
for inp in inputs:
|
|
1308
|
+
feeds = make_feeds(proto, inp, use_numpy=True, copy=True)
|
|
1309
|
+
if verbose:
|
|
1310
|
+
print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
|
|
1311
|
+
got = sess.run(None, feeds)
|
|
1312
|
+
gots.append(got)
|
|
1277
1313
|
else:
|
|
1278
|
-
feeds = make_feeds(proto, inputs, copy=True)
|
|
1279
1314
|
if verbose:
|
|
1280
1315
|
print(f"[{vname}] create InferenceSessionForTorch")
|
|
1281
1316
|
sess = InferenceSessionForTorch(proto, **kwargs)
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1317
|
+
for inp in inputs:
|
|
1318
|
+
feeds = make_feeds(proto, inp, copy=True)
|
|
1319
|
+
if verbose:
|
|
1320
|
+
print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
|
|
1321
|
+
got = sess.run(None, feeds)
|
|
1322
|
+
gots.append(got)
|
|
1285
1323
|
if verbose:
|
|
1286
1324
|
print(f"[{vname}] compute expected values")
|
|
1287
1325
|
|
|
1288
1326
|
if expected is None:
|
|
1289
1327
|
if copy_inputs:
|
|
1290
|
-
expected =
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1328
|
+
expected = [
|
|
1329
|
+
(
|
|
1330
|
+
model(*copy.deepcopy(inp))
|
|
1331
|
+
if isinstance(inp, tuple)
|
|
1332
|
+
else model(**copy.deepcopy(inp))
|
|
1333
|
+
)
|
|
1334
|
+
for inp in inputs
|
|
1335
|
+
]
|
|
1295
1336
|
else:
|
|
1296
|
-
expected =
|
|
1337
|
+
expected = [
|
|
1338
|
+
model(*inp) if isinstance(inp, tuple) else model(**inp) for inp in inputs
|
|
1339
|
+
]
|
|
1297
1340
|
|
|
1298
1341
|
if verbose:
|
|
1299
1342
|
print(f"[{vname}] expected {string_type(expected, **kws)}")
|
|
@@ -1306,47 +1349,50 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1306
1349
|
import torch
|
|
1307
1350
|
|
|
1308
1351
|
ep = torch.export.load(ep)
|
|
1309
|
-
|
|
1352
|
+
|
|
1310
1353
|
ep_model = ep.module() # type: ignore[union-attr]
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1354
|
+
for expe, inp, got in zip(expected, inputs, gots):
|
|
1355
|
+
ep_inputs = copy.deepcopy(inp) if copy_inputs else inp
|
|
1356
|
+
ep_expected = (
|
|
1357
|
+
ep_model(*copy.deepcopy(ep_inputs))
|
|
1358
|
+
if isinstance(ep_inputs, tuple)
|
|
1359
|
+
else ep_model(**copy.deepcopy(ep_inputs))
|
|
1360
|
+
)
|
|
1361
|
+
if verbose:
|
|
1362
|
+
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
|
|
1363
|
+
ep_diff = max_diff(expe, ep_expected, hist=[0.1, 0.01])
|
|
1364
|
+
if verbose:
|
|
1365
|
+
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
|
|
1366
|
+
assert (
|
|
1367
|
+
isinstance(ep_diff["abs"], float)
|
|
1368
|
+
and isinstance(ep_diff["rel"], float)
|
|
1369
|
+
and not numpy.isnan(ep_diff["abs"])
|
|
1370
|
+
and ep_diff["abs"] <= atol
|
|
1371
|
+
and not numpy.isnan(ep_diff["rel"])
|
|
1372
|
+
and ep_diff["rel"] <= rtol
|
|
1373
|
+
), (
|
|
1374
|
+
f"discrepancies in {test_name!r} between the exported program "
|
|
1375
|
+
f"and the exported model diff={string_diff(ep_diff)}"
|
|
1376
|
+
)
|
|
1377
|
+
ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
|
|
1378
|
+
if verbose:
|
|
1379
|
+
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
|
|
1380
|
+
|
|
1381
|
+
for expe, got in zip(expected, gots):
|
|
1382
|
+
diff = max_diff(expe, got, flatten=True, hist=[0.1, 0.01])
|
|
1319
1383
|
if verbose:
|
|
1320
|
-
print(f"[{vname}]
|
|
1384
|
+
print(f"[{vname}] diff {string_diff(diff)}")
|
|
1321
1385
|
assert (
|
|
1322
|
-
isinstance(
|
|
1323
|
-
and isinstance(
|
|
1324
|
-
and not numpy.isnan(
|
|
1325
|
-
and
|
|
1326
|
-
and not numpy.isnan(
|
|
1327
|
-
and
|
|
1386
|
+
isinstance(diff["abs"], float)
|
|
1387
|
+
and isinstance(diff["rel"], float)
|
|
1388
|
+
and not numpy.isnan(diff["abs"])
|
|
1389
|
+
and diff["abs"] <= atol
|
|
1390
|
+
and not numpy.isnan(diff["rel"])
|
|
1391
|
+
and diff["rel"] <= rtol
|
|
1328
1392
|
), (
|
|
1329
|
-
f"discrepancies in {test_name!r} between the
|
|
1330
|
-
f"
|
|
1393
|
+
f"discrepancies in {test_name!r} between the model and "
|
|
1394
|
+
f"the onnx model diff={string_diff(diff)}"
|
|
1331
1395
|
)
|
|
1332
|
-
ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
|
|
1333
|
-
if verbose:
|
|
1334
|
-
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
|
|
1335
|
-
|
|
1336
|
-
diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
|
|
1337
|
-
if verbose:
|
|
1338
|
-
print(f"[{vname}] diff {string_diff(diff)}")
|
|
1339
|
-
assert (
|
|
1340
|
-
isinstance(diff["abs"], float)
|
|
1341
|
-
and isinstance(diff["rel"], float)
|
|
1342
|
-
and not numpy.isnan(diff["abs"])
|
|
1343
|
-
and diff["abs"] <= atol
|
|
1344
|
-
and not numpy.isnan(diff["rel"])
|
|
1345
|
-
and diff["rel"] <= rtol
|
|
1346
|
-
), (
|
|
1347
|
-
f"discrepancies in {test_name!r} between the model and "
|
|
1348
|
-
f"the onnx model diff={string_diff(diff)}"
|
|
1349
|
-
)
|
|
1350
1396
|
|
|
1351
1397
|
def _debug(self):
|
|
1352
1398
|
"Tells if DEBUG=1 is set up."
|