onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +118 -5
- onnx_diagnostic/export/control_flow.py +214 -0
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +118 -25
- onnx_diagnostic/helpers/cache_helper.py +218 -204
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +92 -26
- onnx_diagnostic/helpers/log_helper.py +26 -4
- onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +115 -16
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/rt_helper.py +547 -0
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +108 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
4
|
+
import onnx
|
|
5
|
+
import torch
|
|
6
|
+
from ..helpers import max_diff
|
|
7
|
+
from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
|
|
8
|
+
from ..reference import OnnxruntimeEvaluator
|
|
9
|
+
|
|
10
|
+
# Type alias: a tuple of tensors of any length, used to annotate the inputs
# and outputs of the eager functions replaced by an ONNX function.
TUPLE_TENSORS = Tuple[torch.Tensor, ...]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def is_exporting() -> bool:
|
|
14
|
+
"""
|
|
15
|
+
Returns :func:`torch.compiler.is_exporting` or
|
|
16
|
+
:func:`torch.compiler.is_compiling`.
|
|
17
|
+
Changes ``_TEST_EXPORT`` to make it trigger.
|
|
18
|
+
"""
|
|
19
|
+
return torch.compiler.is_exporting() or torch.compiler.is_compiling()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass()
class VerifyResult:
    """
    Result of :meth:`verify
    <onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx.verify>`.

    It stores the outputs of the eager implementation, the outputs
    of the ONNX implementation and the discrepancies between both.
    """

    # outputs produced by eager_fn, wrapped into a tuple
    eager_outputs: TUPLE_TENSORS
    # outputs produced by the runtime executing the onnx function
    onnx_outputs: TUPLE_TENSORS
    # one dictionary per output, as returned by max_diff
    diffs: Tuple[Dict[str, float], ...]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class EagerDirectReplacementWithOnnx:
    """
    Replaces a piece of code by another one written in ONNX
    at export time. The function inserts a custom operator
    and links it to the eager_fn.

    :param eager_fn: the code it replaces, it must be given in order to be able
        to execute the torch.fx.Graph the exporter produces
    :param shape_fn: the function produces dummy outputs with the shapes
        the exporter can use for the next operators in the graph
    :param function_proto: instances of ``onnx.FunctionProto``,
        its domain must be ``onnx_plug``
    :param n_inputs: number of inputs of the function, if not given,
        it is the number of inputs of ``function_proto``,
        only tensors must be counted
    :param n_outputs: number of outputs of the function, if not given,
        it is the number of outputs of ``function_proto``,
        only tensors must be counted
    :param name: the name of the custom op, the function name if not specified
    :param kwargs: constants parameters with their default values
    :param verbose: verbose level

    Here is an example:

    .. runpython::
        :showcode:

        import onnx.helper as oh
        import torch
        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
        from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
        from onnx_diagnostic.export.api import to_onnx
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


        def demo_customsub(x, y):
            return x - y


        def demo_customsub_shape(x, y):
            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


        def make_function_proto():
            return oh.make_function(
                "onnx_plug",
                "demo_customsub",
                ["x", "y"],
                ["z"],
                [oh.make_node("Sub", ["x", "y"], ["z"])],
                opset_imports=[oh.make_opsetid("", 22)],
            )


        class Model(torch.nn.Module):
            def forward(self, x):
                y = x.sum(axis=1, keepdim=True)
                d = torch.ops.onnx_plug.demo_customsub(x, y)
                return torch.abs(d)


        replacements = [
            EagerDirectReplacementWithOnnx(
                demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
            )
        ]

        x = torch.randn((3, 4), dtype=torch.float32)
        model = Model()
        ds = ({0: "d1", 1: "d2"},)

        # The exported program shows a custom op.
        ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
        print(ep)

        # The exporter knows how to replace this custom op.
        # Let's export.

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="custom",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))

        # And with :func:`torch.onnx.export`:

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="onnx-dynamo",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))
    """

    def __init__(
        self,
        eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        function_proto: onnx.FunctionProto,
        n_inputs: Optional[int] = None,
        n_outputs: Optional[int] = None,
        name: Optional[str] = None,
        kwargs: Optional[Dict[str, Union[int, float]]] = None,
        verbose: int = 0,
    ):
        assert isinstance(
            function_proto, onnx.FunctionProto
        ), f"Unexpected type {type(function_proto)} for function_proto"
        # The function proto is the only source of truth for the number of tensor
        # inputs and outputs (both are checked against it below), so it is used
        # when the caller does not specify them.
        if n_inputs is None:
            n_inputs = len(function_proto.input)
        if n_outputs is None:
            n_outputs = len(function_proto.output)
        assert isinstance(n_inputs, int), f"n_inputs must be an integer, got n_inputs={n_inputs}"
        assert isinstance(
            n_outputs, int
        ), f"n_outputs must be an integer, got n_outputs={n_outputs}"
        self.eager_fn = eager_fn
        self.shape_fn = shape_fn
        self.function_proto = function_proto
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        # lambdas and local functions have names such as '<lambda>' which are not
        # valid operator names, the qualified name is sanitized in that case
        self.name = name or (
            eager_fn.__name__
            if "<" not in eager_fn.__name__
            else eager_fn.__qualname__.replace("<locals>", "L")
            .replace("<lambda>", "l")
            .replace(".", "_")
        )
        self.kwargs = kwargs or {}
        assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), (
            f"Only int or floats are allowed for kwargs={kwargs}, one of them "
            f"does not respect that constraint."
        )
        sig = inspect.signature(self.eager_fn)
        params = list(sig.parameters)
        assert (
            len(params) >= n_inputs
        ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
        assert n_inputs == len(function_proto.input), (
            f"Input mismatch n_inputs={n_inputs} but "
            f"function_proto.input={function_proto.input}"
        )
        assert n_outputs == len(function_proto.output), (
            f"Output mismatch n_outputs={n_outputs} but "
            f"function_proto.output={function_proto.output}"
        )
        assert (
            function_proto.domain == self.domain
        ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
        # parameters of eager_fn are split into tensors (args) and constants (kwargs)
        self.args_name = [p for p in params if p not in self.kwargs]
        self.kwargs_name = [p for p in params if p in self.kwargs]
        self.verbose = verbose
        self.custom_op = self._register()

    @property
    def domain(self) -> str:
        "Returns the onnx domain."
        return "onnx_plug"

    @property
    def target_name(self) -> str:
        "Returns the target name (see in the exported program)."
        return f"{self.domain}::{self.name}"

    @property
    def torch_op(self) -> Callable:
        "Returns ``torch.ops.onnx_plug.<name>.default``."
        return getattr(getattr(torch.ops, self.domain), self.name).default

    def __call__(self, *args, **kwargs):
        """
        Calls the registered custom op if the model is being exported,
        calls eager_fn otherwise. Keyword arguments are forwarded in both cases.
        """
        if is_exporting():
            # The constant parameters are part of the custom op schema,
            # they must not be dropped while exporting.
            return self.torch_op(*args, **kwargs)
        return self.eager_fn(*args, **kwargs)

    def _register(self) -> "torch.library.CustomOpDef":
        """
        Registers the custom op and returns its definition.

        :return: the ``torch.library.CustomOpDef`` created for eager_fn
        """
        input_args = [f"Tensor {p}" for p in self.args_name]
        for p in self.kwargs_name:
            val = self.kwargs[p]
            if isinstance(val, int):
                input_args.append(f"int {p}={val}")
            elif isinstance(val, float):
                input_args.append(f"float {p}={val}")
            else:
                raise NotImplementedError(
                    f"kwargs {p!r} has a default value of unsupported type {type(val)}"
                )

        inputs = ", ".join(input_args)
        schema = f"({inputs}) -> Tensor"
        if self.n_outputs > 1:
            schema += "[]"
        if self.verbose:
            print(
                f"[EagerDirectReplacementWithOnnx._register] "
                f"torch.ops.{self.domain}.{self.name}"
            )
            print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}")
        custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
        custom_def.register_kernel(None)(self.eager_fn)
        # shape_fn becomes the abstract (fake) implementation of the op
        custom_def._abstract_fn = self.shape_fn
        return custom_def

    def verify(
        self,
        *args,
        engine: Optional[Callable] = None,
        dump_onnx_model: Optional[str] = None,
        **kwargs,
    ) -> VerifyResult:
        """
        Verifies that the eager mode is equivalent to the onnx function given
        as a replacement. This function evaluates `eager_fn`, checks that the shapes
        are equivalent to the ones given by `shape_fn`, and finally evaluates the
        onnx translation if the previous did not fail.

        :param args: function inputs
        :param engine: by default an instance of
            :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`,
            no other value is supported yet
        :param dump_onnx_model: to dump the onnx model used to verify
            eager and onnx produce the same results
        :param kwargs: additional arguments given to eager_fn and shape_fn
        :return: a :class:`VerifyResult` holding both outputs and the discrepancies
            computed by :func:`onnx_diagnostic.helpers.max_diff` for every output
        """
        expected = self.eager_fn(*args, **kwargs)
        shapes = self.shape_fn(*args, **kwargs)
        if isinstance(expected, torch.Tensor):
            expected = (expected,)
            assert isinstance(shapes, torch.Tensor), (
                f"eager_fn={self.eager_fn} returns a Tensor but shape_fn={self.shape_fn} "
                f"returns a {type(shapes)}"
            )
            shapes = (shapes,)
        assert isinstance(expected, tuple) and isinstance(shapes, tuple), (
            f"eager_fn={self.eager_fn} returns a {type(expected)} "
            f"and shape_fn={self.shape_fn} returns a {type(shapes)}"
        )
        assert len(expected) == len(shapes), (
            f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn} "
            f"do not return the same number of tensors."
        )
        for i, (e, s) in enumerate(zip(expected, shapes)):
            assert e.dtype == s.dtype, (
                f"Type mismatch {e.dtype} != {s.dtype} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )
            assert e.shape == s.shape, (
                f"Shape mismatch {e.shape} != {s.shape} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )

        # Now the ONNX execution.
        assert engine is None, f"Not implemented yet with engine={engine!r}"
        ags, kws = self._make_args_kwargs(*args, **kwargs)
        sess = OnnxruntimeEvaluator(
            self.function_proto,
            whole=True,
            dump_onnx_model=dump_onnx_model,
            function_kwargs=kws,
        )
        feeds = dict(zip(sess.input_names, ags))
        got = sess.run(None, feeds)
        diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
        return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs)  # type: ignore[arg-type]

    def _make_args_kwargs(self, *args, **kwargs):
        """
        Splits the positional arguments into tensors (the first
        ``len(self.args_name)`` ones) and constants (the following ones,
        named after ``self.kwargs_name``), then merges the given kwargs
        into the constants.

        :return: tuple (tensor arguments, dictionary of constants)
        """
        ags = args[: len(self.args_name)]
        kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
        kws.update(kwargs)
        return ags, kws

    def custom_converter(
        self,
    ) -> Callable:
        """
        Returns a function which
        converts a custom op found in the fx graph into ONNX
        following the API of the custom exporter.
        The converter adds a custom op and registers the local function.
        """

        def converter(
            g: Any,  # GraphBuilder
            sts: Optional[Dict[str, Any]],
            outputs: List[str],
            *args,
            **kwargs,
        ) -> Any:
            # the local function is added only once to the graph
            if not g.has_local_function(
                self.function_proto.name, domain=self.function_proto.domain
            ):
                g.add_function(self.function_proto)
            ags, kws = self._make_args_kwargs(*args, **kwargs)
            res = g.make_node(
                self.function_proto.name,
                ags,
                outputs,
                domain=self.function_proto.domain,
                name=self.target_name,
                **kws,
            )
            if not sts:
                # types and shapes are taken from shape_fn when sts is empty
                new_shapes = self.shape_fn(*args)
                if not isinstance(new_shapes, tuple):
                    new_shapes = (new_shapes,)
                for sh, o in zip(new_shapes, outputs):
                    g.set_type(o, torch_dtype_to_onnx_dtype(sh.dtype))
                    g.set_shape(o, sh.shape)
            return res

        return converter

    def onnx_dynamo_converter(self) -> Callable:
        """
        Returns a function which
        converts a custom op found in the fx graph into ONNX
        following the API of :func:`torch.onnx.export`.
        """
        import onnxscript

        onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
        schema = onnx_plug_op[self.function_proto.name]
        if schema is None:
            # the schema is registered once with permissive type constraints
            all_types = [
                "tensor(float)",
                "tensor(float16)",
                "tensor(bfloat16)",
                "tensor(double)",
                "tensor(int64)",
                "tensor(int32)",
            ]
            type_constraints = []
            for i in range(self.n_inputs):
                type_constraints.append((f"T{i}", all_types, ""))
            for i in range(self.n_outputs):
                type_constraints.append((f"U{i}", all_types, ""))
            schema = onnx.defs.OpSchema(
                self.function_proto.name,
                self.function_proto.domain,
                1,
                inputs=[
                    onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
                    for i in range(self.n_inputs)
                ],
                outputs=[
                    onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
                    for i in range(self.n_outputs)
                ],
                type_constraints=type_constraints,
            )
            onnx.defs.register_schema(schema)
        op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)

        def converter(*cargs, **ckwargs):
            ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
            return op(*ags, n_outputs=self.n_outputs, **kws)

        return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -9,6 +9,7 @@ import itertools
|
|
|
9
9
|
import logging
|
|
10
10
|
import os
|
|
11
11
|
import re
|
|
12
|
+
import shutil
|
|
12
13
|
import sys
|
|
13
14
|
import unittest
|
|
14
15
|
import warnings
|
|
@@ -63,7 +64,7 @@ def skipif_ci_apple(msg) -> Callable:
|
|
|
63
64
|
return lambda x: x
|
|
64
65
|
|
|
65
66
|
|
|
66
|
-
def unit_test_going():
|
|
67
|
+
def unit_test_going() -> bool:
|
|
67
68
|
"""
|
|
68
69
|
Enables a flag telling the script is running while testing it.
|
|
69
70
|
Avois unit tests to be very long.
|
|
@@ -147,7 +148,7 @@ def hide_stdout(f: Optional[Callable] = None) -> Callable:
|
|
|
147
148
|
|
|
148
149
|
def wrapper(fct):
|
|
149
150
|
def call_f(self):
|
|
150
|
-
if os.environ.get("UNHIDE", ""):
|
|
151
|
+
if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"):
|
|
151
152
|
fct(self)
|
|
152
153
|
return
|
|
153
154
|
st = StringIO()
|
|
@@ -742,8 +743,15 @@ class ExtTestCase(unittest.TestCase):
|
|
|
742
743
|
_warns: List[Tuple[str, int, Warning]] = []
|
|
743
744
|
_todos: List[Tuple[Callable, str]] = []
|
|
744
745
|
|
|
746
|
+
def unit_test_going(self) -> bool:
    """
    Tells whether the flag marking a unit test run is enabled,
    the flag is used to keep unit tests short.
    The call is forwarded to the module-level function with the same name.
    """
    flag = unit_test_going()
    return flag
|
|
752
|
+
|
|
745
753
|
@property
|
|
746
|
-
def verbose(self):
|
|
754
|
+
def verbose(self) -> int:
|
|
747
755
|
"Returns the the value of environment variable ``VERBOSE``."
|
|
748
756
|
return int(os.environ.get("VERBOSE", "0"))
|
|
749
757
|
|
|
@@ -768,13 +776,13 @@ class ExtTestCase(unittest.TestCase):
|
|
|
768
776
|
cls._todos.append((f, msg))
|
|
769
777
|
|
|
770
778
|
@classmethod
|
|
771
|
-
def ort(cls):
|
|
779
|
+
def ort(cls) -> unittest.__class__:
|
|
772
780
|
import onnxruntime
|
|
773
781
|
|
|
774
782
|
return onnxruntime
|
|
775
783
|
|
|
776
784
|
@classmethod
|
|
777
|
-
def to_onnx(self, *args, **kwargs):
|
|
785
|
+
def to_onnx(self, *args, **kwargs) -> "ModelProto": # noqa: F821
|
|
778
786
|
from experimental_experiment.torch_interpreter import to_onnx
|
|
779
787
|
|
|
780
788
|
return to_onnx(*args, **kwargs)
|
|
@@ -806,12 +814,16 @@ class ExtTestCase(unittest.TestCase):
|
|
|
806
814
|
os.makedirs(folder)
|
|
807
815
|
return folder
|
|
808
816
|
|
|
809
|
-
def
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
817
|
+
def clean_dump(self, folder: str = "dump_test"):
    """
    Removes every file, link and sub-folder stored in ``folder``,
    the folder itself is kept. Nothing happens if the folder does not exist.

    :param folder: folder to clean
    """
    if not os.path.exists(folder):
        # nothing was dumped yet, there is nothing to clean
        return
    for item in os.listdir(folder):
        item_path = os.path.join(folder, item)
        # links are removed themselves, never the folder they point to
        if os.path.isfile(item_path) or os.path.islink(item_path):
            os.remove(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)
|
|
825
|
+
|
|
826
|
+
def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
|
|
815
827
|
"""Dumps an onnx file."""
|
|
816
828
|
fullname = self.get_dump_file(name, folder=folder)
|
|
817
829
|
with open(fullname, "wb") as f:
|
|
@@ -1094,10 +1106,15 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1094
1106
|
value = numpy.array(value).astype(expected.dtype)
|
|
1095
1107
|
self.assertEqualArray(expected, value, atol=atol, rtol=rtol)
|
|
1096
1108
|
|
|
1097
|
-
def check_ort(
|
|
1109
|
+
def check_ort(
|
|
1110
|
+
self, onx: "onnx.ModelProto" # noqa: F821
|
|
1111
|
+
) -> "onnxruntime.InferenceSession": # noqa: F821
|
|
1098
1112
|
from onnxruntime import InferenceSession
|
|
1099
1113
|
|
|
1100
|
-
return InferenceSession(
|
|
1114
|
+
return InferenceSession(
|
|
1115
|
+
onx if isinstance(onx, str) else onx.SerializeToString(),
|
|
1116
|
+
providers=["CPUExecutionProvider"],
|
|
1117
|
+
)
|
|
1101
1118
|
|
|
1102
1119
|
def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
|
|
1103
1120
|
"""In the name"""
|
|
@@ -1137,7 +1154,7 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1137
1154
|
if not full.endswith(suffix):
|
|
1138
1155
|
raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
|
|
1139
1156
|
|
|
1140
|
-
def capture(self, fct: Callable):
|
|
1157
|
+
def capture(self, fct: Callable) -> Tuple[Any, str, str]:
|
|
1141
1158
|
"""
|
|
1142
1159
|
Runs a function and capture standard output and error.
|
|
1143
1160
|
|
|
@@ -1188,6 +1205,8 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1188
1205
|
copy_inputs: bool = True,
|
|
1189
1206
|
expected: Optional[Any] = None,
|
|
1190
1207
|
use_ort: bool = False,
|
|
1208
|
+
ort_optimized_graph: bool = False,
|
|
1209
|
+
ep: Optional[Union["torch.export.ExportedProgram", str]] = None, # noqa: F821
|
|
1191
1210
|
**kwargs,
|
|
1192
1211
|
):
|
|
1193
1212
|
"""
|
|
@@ -1206,6 +1225,8 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1206
1225
|
:param expected: expected values
|
|
1207
1226
|
:param copy_inputs: to copy the inputs
|
|
1208
1227
|
:param use_ort: use :class:`onnxruntime.InferenceSession`
|
|
1228
|
+
:param ort_optimized_graph: dumps the optimized onnxruntime graph
|
|
1229
|
+
:param ep: exported program (or saved exported program)
|
|
1209
1230
|
:param kwargs: arguments sent to
|
|
1210
1231
|
:class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
|
|
1211
1232
|
"""
|
|
@@ -1214,33 +1235,56 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1214
1235
|
from .helpers.ort_session import InferenceSessionForTorch
|
|
1215
1236
|
|
|
1216
1237
|
kws = dict(with_shape=True, with_min_max=verbose > 1)
|
|
1217
|
-
|
|
1218
|
-
vname = test_name or "assert_onnx_disc"
|
|
1238
|
+
vname = test_name or "assert_onnx_disc"
|
|
1219
1239
|
if test_name:
|
|
1240
|
+
import onnx
|
|
1241
|
+
|
|
1220
1242
|
name = f"{test_name}.onnx"
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1243
|
+
if verbose:
|
|
1244
|
+
print(f"[{vname}] save the onnx model into {name!r}")
|
|
1245
|
+
if isinstance(proto, str):
|
|
1246
|
+
name = proto
|
|
1247
|
+
proto = onnx.load(name)
|
|
1248
|
+
elif not self.unit_test_going():
|
|
1249
|
+
assert isinstance(
|
|
1250
|
+
proto, onnx.ModelProto
|
|
1251
|
+
), f"Unexpected type {type(proto)} for proto"
|
|
1252
|
+
name = self.dump_onnx(name, proto)
|
|
1253
|
+
if verbose and not self.unit_test_going():
|
|
1254
|
+
print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
|
|
1224
1255
|
if verbose:
|
|
1225
1256
|
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
|
|
1257
|
+
|
|
1226
1258
|
if use_ort:
|
|
1259
|
+
assert isinstance(
|
|
1260
|
+
proto, onnx.ModelProto
|
|
1261
|
+
), f"Unexpected type {type(proto)} for proto"
|
|
1227
1262
|
feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
|
|
1228
|
-
if verbose:
|
|
1229
|
-
print(f"[{vname}] feeds {string_type(feeds, **kws)}")
|
|
1230
1263
|
import onnxruntime
|
|
1231
1264
|
|
|
1265
|
+
options = onnxruntime.SessionOptions()
|
|
1266
|
+
if ort_optimized_graph:
|
|
1267
|
+
options.optimized_model_filepath = f"{name}.optort.onnx"
|
|
1268
|
+
providers = kwargs.get("providers", ["CPUExecutionProvider"])
|
|
1269
|
+
if verbose:
|
|
1270
|
+
print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
|
|
1232
1271
|
sess = onnxruntime.InferenceSession(
|
|
1233
|
-
proto.SerializeToString(), providers=
|
|
1272
|
+
proto.SerializeToString(), options, providers=providers
|
|
1234
1273
|
)
|
|
1274
|
+
if verbose:
|
|
1275
|
+
print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
|
|
1235
1276
|
got = sess.run(None, feeds)
|
|
1236
1277
|
else:
|
|
1237
1278
|
feeds = make_feeds(proto, inputs, copy=True)
|
|
1238
1279
|
if verbose:
|
|
1239
|
-
print(f"[{vname}]
|
|
1280
|
+
print(f"[{vname}] create InferenceSessionForTorch")
|
|
1240
1281
|
sess = InferenceSessionForTorch(proto, **kwargs)
|
|
1282
|
+
if verbose:
|
|
1283
|
+
print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
|
|
1241
1284
|
got = sess.run(None, feeds)
|
|
1242
1285
|
if verbose:
|
|
1243
1286
|
print(f"[{vname}] compute expected values")
|
|
1287
|
+
|
|
1244
1288
|
if expected is None:
|
|
1245
1289
|
if copy_inputs:
|
|
1246
1290
|
expected = (
|
|
@@ -1250,10 +1294,46 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1250
1294
|
)
|
|
1251
1295
|
else:
|
|
1252
1296
|
expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
|
|
1297
|
+
|
|
1253
1298
|
if verbose:
|
|
1254
1299
|
print(f"[{vname}] expected {string_type(expected, **kws)}")
|
|
1255
1300
|
print(f"[{vname}] obtained {string_type(got, **kws)}")
|
|
1256
|
-
|
|
1301
|
+
|
|
1302
|
+
if ep:
|
|
1303
|
+
if isinstance(ep, str):
|
|
1304
|
+
if verbose:
|
|
1305
|
+
print(f"[{vname}] load exported program {ep!r}")
|
|
1306
|
+
import torch
|
|
1307
|
+
|
|
1308
|
+
ep = torch.export.load(ep)
|
|
1309
|
+
ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
|
|
1310
|
+
ep_model = ep.module() # type: ignore[union-attr]
|
|
1311
|
+
ep_expected = (
|
|
1312
|
+
ep_model(*copy.deepcopy(ep_inputs))
|
|
1313
|
+
if isinstance(ep_inputs, tuple)
|
|
1314
|
+
else ep_model(**copy.deepcopy(ep_inputs))
|
|
1315
|
+
)
|
|
1316
|
+
if verbose:
|
|
1317
|
+
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
|
|
1318
|
+
ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01])
|
|
1319
|
+
if verbose:
|
|
1320
|
+
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
|
|
1321
|
+
assert (
|
|
1322
|
+
isinstance(ep_diff["abs"], float)
|
|
1323
|
+
and isinstance(ep_diff["rel"], float)
|
|
1324
|
+
and not numpy.isnan(ep_diff["abs"])
|
|
1325
|
+
and ep_diff["abs"] <= atol
|
|
1326
|
+
and not numpy.isnan(ep_diff["rel"])
|
|
1327
|
+
and ep_diff["rel"] <= rtol
|
|
1328
|
+
), (
|
|
1329
|
+
f"discrepancies in {test_name!r} between the exported program "
|
|
1330
|
+
f"and the exported model diff={string_diff(ep_diff)}"
|
|
1331
|
+
)
|
|
1332
|
+
ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
|
|
1333
|
+
if verbose:
|
|
1334
|
+
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
|
|
1335
|
+
|
|
1336
|
+
diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
|
|
1257
1337
|
if verbose:
|
|
1258
1338
|
print(f"[{vname}] diff {string_diff(diff)}")
|
|
1259
1339
|
assert (
|
|
@@ -1263,7 +1343,10 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1263
1343
|
and diff["abs"] <= atol
|
|
1264
1344
|
and not numpy.isnan(diff["rel"])
|
|
1265
1345
|
and diff["rel"] <= rtol
|
|
1266
|
-
),
|
|
1346
|
+
), (
|
|
1347
|
+
f"discrepancies in {test_name!r} between the model and "
|
|
1348
|
+
f"the onnx model diff={string_diff(diff)}"
|
|
1349
|
+
)
|
|
1267
1350
|
|
|
1268
1351
|
def _debug(self):
|
|
1269
1352
|
"Tells if DEBUG=1 is set up."
|
|
@@ -1274,6 +1357,16 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1274
1357
|
|
|
1275
1358
|
return string_type(*args, **kwargs)
|
|
1276
1359
|
|
|
1360
|
+
def max_diff(self, *args, **kwargs):
    """
    Shortcut to :func:`onnx_diagnostic.helpers.max_diff`,
    positional and keyword arguments are forwarded unchanged.
    """
    from .helpers import max_diff as _helper_max_diff

    return _helper_max_diff(*args, **kwargs)
|
|
1364
|
+
|
|
1365
|
+
def use_dyn_not_str(self, *args, **kwargs):
    """
    Shortcut to
    :func:`onnx_diagnostic.torch_export_patches.patch_inputs.use_dyn_not_str`,
    positional and keyword arguments are forwarded unchanged.
    """
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    # keyword arguments must be unpacked with ``**``: ``*kwargs`` would pass
    # the dictionary keys as extra positional arguments
    return use_dyn_not_str(*args, **kwargs)
|
|
1369
|
+
|
|
1277
1370
|
def subloop(self, *args, verbose: int = 0):
|
|
1278
1371
|
"Loops over elements and calls :meth:`unittests.TestCase.subTest`."
|
|
1279
1372
|
if len(args) == 1:
|