onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +91 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +3 -3
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +92 -23
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +90 -26
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +103 -1
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +103 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
4
|
+
import onnx
|
|
5
|
+
import torch
|
|
6
|
+
from ..helpers import max_diff
|
|
7
|
+
from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
|
|
8
|
+
from ..reference import OnnxruntimeEvaluator
|
|
9
|
+
|
|
10
|
+
# Type alias used in the annotations below for a variable-length tuple of tensors.
TUPLE_TENSORS = Tuple[torch.Tensor, ...]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def is_exporting() -> bool:
|
|
14
|
+
"""
|
|
15
|
+
Returns :func:`torch.compiler.is_exporting` or
|
|
16
|
+
:func:`torch.compiler.is_compiling`.
|
|
17
|
+
Changes ``_TEST_EXPORT`` to make it trigger.
|
|
18
|
+
"""
|
|
19
|
+
return torch.compiler.is_exporting() or torch.compiler.is_compiling()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class VerifyResult:
    """
    Container returned by :meth:`verify
    <onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx.verify>`,
    it keeps both sets of outputs and the discrepancies between them.
    """

    # tensors produced by the eager implementation
    eager_outputs: TUPLE_TENSORS
    # tensors produced by running the onnx function
    onnx_outputs: TUPLE_TENSORS
    # one dictionary of discrepancies per output
    diffs: Tuple[Dict[str, float], ...]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class EagerDirectReplacementWithOnnx:
    """
    Replaces a piece of code by another one written in ONNX
    at export time. The function inserts a custom operator
    and links it to the eager_fn.

    :param eager_fn: the code it replaces, it must be given in order to be able
        to execute the torch.fx.Graph the exporter produces
    :param shape_fn: the function produces dummy outputs with the shapes
        the exporter can use for the next operators in the graph
    :param function_proto: instances of ``onnx.FunctionProto``,
        its domain must be ``onnx_plug``
    :param n_inputs: number of inputs of the function, if not given,
        the class uses the number of inputs of ``function_proto``,
        only tensors must be counted
    :param n_outputs: number of outputs of the function, if not given,
        the class uses the number of outputs of ``function_proto``,
        only tensors must be counted
    :param name: the name of the custom op, the function name if not specified
    :param kwargs: constants parameters with their default values
    :param verbose: verbose level

    Here is an example:

    .. runpython::
        :showcode:

        import onnx.helper as oh
        import torch
        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
        from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
        from onnx_diagnostic.export.api import to_onnx
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


        def demo_customsub(x, y):
            return x - y


        def demo_customsub_shape(x, y):
            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


        def make_function_proto():
            return oh.make_function(
                "onnx_plug",
                "demo_customsub",
                ["x", "y"],
                ["z"],
                [oh.make_node("Sub", ["x", "y"], ["z"])],
                opset_imports=[oh.make_opsetid("", 22)],
            )


        class Model(torch.nn.Module):
            def forward(self, x):
                y = x.sum(axis=1, keepdim=True)
                d = torch.ops.onnx_plug.demo_customsub(x, y)
                return torch.abs(d)


        replacements = [
            EagerDirectReplacementWithOnnx(
                demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
            )
        ]

        x = torch.randn((3, 4), dtype=torch.float32)
        model = Model()
        ds = ({0: "d1", 1: "d2"},)

        # The exported program shows a custom op.
        ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
        print(ep)

        # The exporter knows how to replace this custom op.
        # Let's export.

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="custom",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))

        # And with :func:`torch.onnx.export`:

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="onnx-dynamo",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))
    """

    def __init__(
        self,
        eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        function_proto: onnx.FunctionProto,
        n_inputs: Optional[int] = None,
        n_outputs: Optional[int] = None,
        name: Optional[str] = None,
        kwargs: Optional[Dict[str, Union[int, float]]] = None,
        verbose: int = 0,
    ):
        assert isinstance(
            function_proto, onnx.FunctionProto
        ), f"Unexpected type {type(function_proto)} for function_proto"
        # Both counts must match the signature of function_proto (checked
        # below), so the function itself provides the only valid default.
        if n_inputs is None:
            n_inputs = len(function_proto.input)
        if n_outputs is None:
            n_outputs = len(function_proto.output)
        assert isinstance(n_inputs, int), f"Unexpected type {type(n_inputs)} for n_inputs"
        assert isinstance(n_outputs, int), f"Unexpected type {type(n_outputs)} for n_outputs"
        self.eager_fn = eager_fn
        self.shape_fn = shape_fn
        self.function_proto = function_proto
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        # Lambdas and local functions have names such as '<lambda>' or
        # 'f.<locals>.g' which cannot be used as an operator name.
        self.name = name or (
            eager_fn.__name__
            if "<" not in eager_fn.__name__
            else eager_fn.__qualname__.replace("<locals>", "L")
            .replace("<lambda>", "l")
            .replace(".", "_")
        )
        self.kwargs = kwargs or {}
        assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), (
            f"Only int or floats are allowed for kwargs={kwargs}, one of them "
            f"does not respect that constraint."
        )
        sig = inspect.signature(self.eager_fn)
        params = list(sig.parameters)
        assert (
            len(params) >= n_inputs
        ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
        assert n_inputs == len(function_proto.input), (
            f"Input mismatch n_inputs={n_inputs} but "
            f"function_proto.input={function_proto.input}"
        )
        assert n_outputs == len(function_proto.output), (
            f"Output mismatch n_outputs={n_outputs} but "
            f"function_proto.output={function_proto.output}"
        )
        assert (
            function_proto.domain == self.domain
        ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
        # Parameters listed in kwargs become constant attributes,
        # the other ones are tensor inputs.
        self.args_name = [p for p in params if p not in self.kwargs]
        self.kwargs_name = [p for p in params if p in self.kwargs]
        self.verbose = verbose
        self.custom_op = self._register()

    @property
    def domain(self) -> str:
        "Returns the onnx domain."
        return "onnx_plug"

    @property
    def target_name(self) -> str:
        "Returns the target name (see in the exported program)."
        return f"{self.domain}::{self.name}"

    @property
    def torch_op(self) -> Callable:
        "Returns ``torch.ops.onnx_plug.<name>``."
        return getattr(getattr(torch.ops, self.domain), self.name).default

    def __call__(self, *args, **kwargs):
        """
        Calls the registered custom op if the model is being exported,
        eager_fn otherwise. Keyword arguments are forwarded in both cases
        so that the exported graph uses the same constants as eager mode.
        """
        if is_exporting():
            return self.torch_op(*args, **kwargs)
        return self.eager_fn(*args, **kwargs)

    def _register(self) -> "torch.library.CustomOpDef":
        """Registers the custom op and returns its definition."""
        input_args = [f"Tensor {p}" for p in self.args_name]
        for p in self.kwargs_name:
            val = self.kwargs[p]
            if isinstance(val, int):
                input_args.append(f"int {p}={val}")
            elif isinstance(val, float):
                input_args.append(f"float {p}={val}")
            else:
                raise NotImplementedError(
                    f"kwargs {p!r} has a default value of unsupported type {type(val)}"
                )

        inputs = ", ".join(input_args)
        schema = f"({inputs}) -> Tensor"
        if self.n_outputs > 1:
            schema += "[]"
        if self.verbose:
            print(
                f"[EagerDirectReplacementWithOnnx._register] "
                f"'torch.ops.{self.domain}.{self.name}"
            )
            print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}")
        custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
        custom_def.register_kernel(None)(self.eager_fn)
        # shape_fn produces the dummy outputs used while tracing
        custom_def._abstract_fn = self.shape_fn
        return custom_def

    def verify(
        self,
        *args,
        engine: Optional[Callable] = None,
        dump_onnx_model: Optional[str] = None,
        **kwargs,
    ) -> VerifyResult:
        """
        Verifies that the eager mode is equivalent to the onnx function given
        as a replacement. This function evaluates `eager_fn`, checks that the shapes
        are equivalent to the ones given by `shape_fn`, and finally evaluates the
        onnx translation if the previous did not fail.

        :param args: function inputs
        :param engine: by default an instance of
            :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`,
            no other value is supported yet
        :param dump_onnx_model: to dump the onnx model used to verify
            eager and onnx produce the same results
        :param kwargs: additional arguments to the function
        :return: a :class:`VerifyResult` holding the eager outputs, the onnx
            outputs and one :func:`onnx_diagnostic.helpers.max_diff` per output
        """
        expected = self.eager_fn(*args, **kwargs)
        shapes = self.shape_fn(*args, **kwargs)
        if isinstance(expected, torch.Tensor):
            expected = (expected,)
            assert isinstance(shapes, torch.Tensor), (
                f"eager_fn={self.eager_fn} returns a Tensor but shape_fn={self.shape_fn} "
                f"returns a {type(shapes)}"
            )
            shapes = (shapes,)
        assert isinstance(expected, tuple) and isinstance(shapes, tuple), (
            f"eager_fn={self.eager_fn} returns a {type(expected)} "
            f"and shape_fn={self.shape_fn} returns a {type(shapes)}"
        )
        assert len(expected) == len(shapes), (
            f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn} "
            f"do not return the same number of tensors."
        )
        for i, (e, s) in enumerate(zip(expected, shapes)):
            assert e.dtype == s.dtype, (
                f"Type mismatch {e.dtype} != {s.dtype} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )
            assert e.shape == s.shape, (
                f"Shape mismatch {e.shape} != {s.shape} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )

        # Now the ONNX execution.
        assert engine is None, f"Not implemented yet with engine={engine!r}"
        ags, kws = self._make_args_kwargs(*args, **kwargs)
        sess = OnnxruntimeEvaluator(
            self.function_proto,
            whole=True,
            dump_onnx_model=dump_onnx_model,
            function_kwargs=kws,
        )
        feeds = dict(zip(sess.input_names, ags))
        got = sess.run(None, feeds)
        # zip below would silently ignore extra outputs
        assert len(got) == len(expected), (
            f"eager_fn={self.eager_fn} returns {len(expected)} tensors but "
            f"the onnx function returns {len(got)}"
        )
        diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
        return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs)  # type: ignore[arg-type]

    def _make_args_kwargs(self, *args, **kwargs):
        """
        Splits the arguments into tensor inputs and constant attributes.
        The first ``len(self.args_name)`` positional arguments are inputs,
        the following ones are mapped onto ``self.kwargs_name`` in order,
        then explicit keyword arguments override them.
        """
        ags = args[: len(self.args_name)]
        kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
        kws.update(kwargs)
        return ags, kws

    def custom_converter(
        self,
    ) -> Callable:
        """
        Returns a function which
        converts the custom op found in the fx graph into ONNX
        following the API of the custom exporter.
        The converter adds a custom op and registers the local function.
        """

        def converter(
            g: Any,  # GraphBuilder
            sts: Optional[Dict[str, Any]],
            outputs: List[str],
            *args,
            **kwargs,
        ) -> Any:
            # the local function is added only once per graph
            if not g.has_local_function(
                self.function_proto.name, domain=self.function_proto.domain
            ):
                g.add_function(self.function_proto)
            ags, kws = self._make_args_kwargs(*args, **kwargs)
            res = g.make_node(
                self.function_proto.name,
                ags,
                outputs,
                domain=self.function_proto.domain,
                name=self.target_name,
                **kws,
            )
            if not sts:
                # NOTE(review): shape_fn receives every positional argument here,
                # whatever the builder passes for them; confirm this is intended.
                new_shapes = self.shape_fn(*args)
                if not isinstance(new_shapes, tuple):
                    new_shapes = (new_shapes,)
                for sh, o in zip(new_shapes, outputs):
                    g.set_type(o, torch_dtype_to_onnx_dtype(sh.dtype))
                    g.set_shape(o, sh.shape)
            return res

        return converter

    def onnx_dynamo_converter(self) -> Callable:
        """
        Returns a function which
        converts the custom op found in the fx graph into ONNX
        following the API of :func:`torch.onnx.export`.
        """
        import onnxscript

        onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
        schema = onnx_plug_op[self.function_proto.name]
        if schema is None:
            # The schema is registered once, every input and output
            # gets its own type constraint.
            all_types = [
                "tensor(float)",
                "tensor(float16)",
                "tensor(bfloat16)",
                "tensor(double)",
                "tensor(int64)",
                "tensor(int32)",
            ]
            type_constraints = []
            for i in range(self.n_inputs):
                type_constraints.append((f"T{i}", all_types, ""))
            for i in range(self.n_outputs):
                type_constraints.append((f"U{i}", all_types, ""))
            schema = onnx.defs.OpSchema(
                self.function_proto.name,
                self.function_proto.domain,
                1,
                inputs=[
                    onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
                    for i in range(self.n_inputs)
                ],
                outputs=[
                    onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
                    for i in range(self.n_outputs)
                ],
                type_constraints=type_constraints,
            )
            onnx.defs.register_schema(schema)
        op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)

        def converter(*cargs, **ckwargs):
            ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
            return op(*ags, n_outputs=self.n_outputs, **kws)

        return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -9,6 +9,7 @@ import itertools
|
|
|
9
9
|
import logging
|
|
10
10
|
import os
|
|
11
11
|
import re
|
|
12
|
+
import shutil
|
|
12
13
|
import sys
|
|
13
14
|
import unittest
|
|
14
15
|
import warnings
|
|
@@ -63,7 +64,7 @@ def skipif_ci_apple(msg) -> Callable:
|
|
|
63
64
|
return lambda x: x
|
|
64
65
|
|
|
65
66
|
|
|
66
|
-
def unit_test_going():
|
|
67
|
+
def unit_test_going() -> bool:
|
|
67
68
|
"""
|
|
68
69
|
Enables a flag telling the script is running while testing it.
|
|
69
70
|
Avois unit tests to be very long.
|
|
@@ -147,7 +148,7 @@ def hide_stdout(f: Optional[Callable] = None) -> Callable:
|
|
|
147
148
|
|
|
148
149
|
def wrapper(fct):
|
|
149
150
|
def call_f(self):
|
|
150
|
-
if os.environ.get("UNHIDE", ""):
|
|
151
|
+
if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"):
|
|
151
152
|
fct(self)
|
|
152
153
|
return
|
|
153
154
|
st = StringIO()
|
|
@@ -742,8 +743,15 @@ class ExtTestCase(unittest.TestCase):
|
|
|
742
743
|
_warns: List[Tuple[str, int, Warning]] = []
|
|
743
744
|
_todos: List[Tuple[Callable, str]] = []
|
|
744
745
|
|
|
746
|
+
def unit_test_going(self) -> bool:
    """
    Returns the value of the module-level function ``unit_test_going``,
    a flag telling whether the script is running while being tested.
    It avoids unit tests to be very long.
    """
    return unit_test_going()
|
|
752
|
+
|
|
745
753
|
@property
|
|
746
|
-
def verbose(self):
|
|
754
|
+
def verbose(self) -> int:
|
|
747
755
|
"Returns the the value of environment variable ``VERBOSE``."
|
|
748
756
|
return int(os.environ.get("VERBOSE", "0"))
|
|
749
757
|
|
|
@@ -768,13 +776,13 @@ class ExtTestCase(unittest.TestCase):
|
|
|
768
776
|
cls._todos.append((f, msg))
|
|
769
777
|
|
|
770
778
|
@classmethod
|
|
771
|
-
def ort(cls):
|
|
779
|
+
def ort(cls) -> unittest.__class__:
|
|
772
780
|
import onnxruntime
|
|
773
781
|
|
|
774
782
|
return onnxruntime
|
|
775
783
|
|
|
776
784
|
@classmethod
|
|
777
|
-
def to_onnx(self, *args, **kwargs):
|
|
785
|
+
def to_onnx(self, *args, **kwargs) -> "ModelProto": # noqa: F821
|
|
778
786
|
from experimental_experiment.torch_interpreter import to_onnx
|
|
779
787
|
|
|
780
788
|
return to_onnx(*args, **kwargs)
|
|
@@ -806,12 +814,16 @@ class ExtTestCase(unittest.TestCase):
|
|
|
806
814
|
os.makedirs(folder)
|
|
807
815
|
return folder
|
|
808
816
|
|
|
809
|
-
def
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
817
|
+
def clean_dump(self, folder: str = "dump_test"):
    """
    Removes every file, symbolic link and subfolder inside ``folder``.
    The folder itself is kept. Nothing happens if it does not exist.

    :param folder: folder to clean
    """
    if not os.path.exists(folder):
        # nothing was dumped yet, nothing to clean
        return
    for item in os.listdir(folder):
        item_path = os.path.join(folder, item)
        # links are removed, never followed, even when they point to a folder
        if os.path.isfile(item_path) or os.path.islink(item_path):
            os.remove(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)
|
|
825
|
+
|
|
826
|
+
def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
|
|
815
827
|
"""Dumps an onnx file."""
|
|
816
828
|
fullname = self.get_dump_file(name, folder=folder)
|
|
817
829
|
with open(fullname, "wb") as f:
|
|
@@ -1094,10 +1106,15 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1094
1106
|
value = numpy.array(value).astype(expected.dtype)
|
|
1095
1107
|
self.assertEqualArray(expected, value, atol=atol, rtol=rtol)
|
|
1096
1108
|
|
|
1097
|
-
def check_ort(
|
|
1109
|
+
def check_ort(
|
|
1110
|
+
self, onx: "onnx.ModelProto" # noqa: F821
|
|
1111
|
+
) -> "onnxruntime.InferenceSession": # noqa: F821
|
|
1098
1112
|
from onnxruntime import InferenceSession
|
|
1099
1113
|
|
|
1100
|
-
return InferenceSession(
|
|
1114
|
+
return InferenceSession(
|
|
1115
|
+
onx if isinstance(onx, str) else onx.SerializeToString(),
|
|
1116
|
+
providers=["CPUExecutionProvider"],
|
|
1117
|
+
)
|
|
1101
1118
|
|
|
1102
1119
|
def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
|
|
1103
1120
|
"""In the name"""
|
|
@@ -1137,7 +1154,7 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1137
1154
|
if not full.endswith(suffix):
|
|
1138
1155
|
raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
|
|
1139
1156
|
|
|
1140
|
-
def capture(self, fct: Callable):
|
|
1157
|
+
def capture(self, fct: Callable) -> Tuple[Any, str, str]:
|
|
1141
1158
|
"""
|
|
1142
1159
|
Runs a function and capture standard output and error.
|
|
1143
1160
|
|
|
@@ -1189,6 +1206,7 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1189
1206
|
expected: Optional[Any] = None,
|
|
1190
1207
|
use_ort: bool = False,
|
|
1191
1208
|
ort_optimized_graph: bool = False,
|
|
1209
|
+
ep: Optional[Union["torch.export.ExportedProgram", str]] = None, # noqa: F821
|
|
1192
1210
|
**kwargs,
|
|
1193
1211
|
):
|
|
1194
1212
|
"""
|
|
@@ -1208,6 +1226,7 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1208
1226
|
:param copy_inputs: to copy the inputs
|
|
1209
1227
|
:param use_ort: use :class:`onnxruntime.InferenceSession`
|
|
1210
1228
|
:param ort_optimized_graph: dumps the optimized onnxruntime graph
|
|
1229
|
+
:param ep: exported program (or saved exported program)
|
|
1211
1230
|
:param kwargs: arguments sent to
|
|
1212
1231
|
:class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
|
|
1213
1232
|
"""
|
|
@@ -1226,15 +1245,16 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1226
1245
|
if isinstance(proto, str):
|
|
1227
1246
|
name = proto
|
|
1228
1247
|
proto = onnx.load(name)
|
|
1229
|
-
|
|
1248
|
+
elif not self.unit_test_going():
|
|
1230
1249
|
assert isinstance(
|
|
1231
1250
|
proto, onnx.ModelProto
|
|
1232
1251
|
), f"Unexpected type {type(proto)} for proto"
|
|
1233
1252
|
name = self.dump_onnx(name, proto)
|
|
1234
|
-
if verbose:
|
|
1253
|
+
if verbose and not self.unit_test_going():
|
|
1235
1254
|
print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
|
|
1236
1255
|
if verbose:
|
|
1237
1256
|
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
|
|
1257
|
+
|
|
1238
1258
|
if use_ort:
|
|
1239
1259
|
assert isinstance(
|
|
1240
1260
|
proto, onnx.ModelProto
|
|
@@ -1242,15 +1262,14 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1242
1262
|
feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
|
|
1243
1263
|
import onnxruntime
|
|
1244
1264
|
|
|
1245
|
-
if verbose:
|
|
1246
|
-
print(f"[{vname}] create onnxruntime.InferenceSession")
|
|
1247
1265
|
options = onnxruntime.SessionOptions()
|
|
1248
1266
|
if ort_optimized_graph:
|
|
1249
1267
|
options.optimized_model_filepath = f"{name}.optort.onnx"
|
|
1268
|
+
providers = kwargs.get("providers", ["CPUExecutionProvider"])
|
|
1269
|
+
if verbose:
|
|
1270
|
+
print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
|
|
1250
1271
|
sess = onnxruntime.InferenceSession(
|
|
1251
|
-
proto.SerializeToString(),
|
|
1252
|
-
options,
|
|
1253
|
-
providers=kwargs.get("providers", ["CPUExecutionProvider"]),
|
|
1272
|
+
proto.SerializeToString(), options, providers=providers
|
|
1254
1273
|
)
|
|
1255
1274
|
if verbose:
|
|
1256
1275
|
print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
|
|
@@ -1265,6 +1284,7 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1265
1284
|
got = sess.run(None, feeds)
|
|
1266
1285
|
if verbose:
|
|
1267
1286
|
print(f"[{vname}] compute expected values")
|
|
1287
|
+
|
|
1268
1288
|
if expected is None:
|
|
1269
1289
|
if copy_inputs:
|
|
1270
1290
|
expected = (
|
|
@@ -1274,10 +1294,46 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1274
1294
|
)
|
|
1275
1295
|
else:
|
|
1276
1296
|
expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
|
|
1297
|
+
|
|
1277
1298
|
if verbose:
|
|
1278
1299
|
print(f"[{vname}] expected {string_type(expected, **kws)}")
|
|
1279
1300
|
print(f"[{vname}] obtained {string_type(got, **kws)}")
|
|
1280
|
-
|
|
1301
|
+
|
|
1302
|
+
if ep:
|
|
1303
|
+
if isinstance(ep, str):
|
|
1304
|
+
if verbose:
|
|
1305
|
+
print(f"[{vname}] load exported program {ep!r}")
|
|
1306
|
+
import torch
|
|
1307
|
+
|
|
1308
|
+
ep = torch.export.load(ep)
|
|
1309
|
+
ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
|
|
1310
|
+
ep_model = ep.module() # type: ignore[union-attr]
|
|
1311
|
+
ep_expected = (
|
|
1312
|
+
ep_model(*copy.deepcopy(ep_inputs))
|
|
1313
|
+
if isinstance(ep_inputs, tuple)
|
|
1314
|
+
else ep_model(**copy.deepcopy(ep_inputs))
|
|
1315
|
+
)
|
|
1316
|
+
if verbose:
|
|
1317
|
+
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
|
|
1318
|
+
ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01])
|
|
1319
|
+
if verbose:
|
|
1320
|
+
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
|
|
1321
|
+
assert (
|
|
1322
|
+
isinstance(ep_diff["abs"], float)
|
|
1323
|
+
and isinstance(ep_diff["rel"], float)
|
|
1324
|
+
and not numpy.isnan(ep_diff["abs"])
|
|
1325
|
+
and ep_diff["abs"] <= atol
|
|
1326
|
+
and not numpy.isnan(ep_diff["rel"])
|
|
1327
|
+
and ep_diff["rel"] <= rtol
|
|
1328
|
+
), (
|
|
1329
|
+
f"discrepancies in {test_name!r} between the exported program "
|
|
1330
|
+
f"and the exported model diff={string_diff(ep_diff)}"
|
|
1331
|
+
)
|
|
1332
|
+
ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
|
|
1333
|
+
if verbose:
|
|
1334
|
+
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
|
|
1335
|
+
|
|
1336
|
+
diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
|
|
1281
1337
|
if verbose:
|
|
1282
1338
|
print(f"[{vname}] diff {string_diff(diff)}")
|
|
1283
1339
|
assert (
|
|
@@ -1287,7 +1343,10 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1287
1343
|
and diff["abs"] <= atol
|
|
1288
1344
|
and not numpy.isnan(diff["rel"])
|
|
1289
1345
|
and diff["rel"] <= rtol
|
|
1290
|
-
),
|
|
1346
|
+
), (
|
|
1347
|
+
f"discrepancies in {test_name!r} between the model and "
|
|
1348
|
+
f"the onnx model diff={string_diff(diff)}"
|
|
1349
|
+
)
|
|
1291
1350
|
|
|
1292
1351
|
def _debug(self):
|
|
1293
1352
|
"Tells if DEBUG=1 is set up."
|
|
@@ -1298,6 +1357,16 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1298
1357
|
|
|
1299
1358
|
return string_type(*args, **kwargs)
|
|
1300
1359
|
|
|
1360
|
+
def max_diff(self, *args, **kwargs):
    """
    Shortcut to :func:`onnx_diagnostic.helpers.max_diff`,
    imported on first use, every argument is forwarded unchanged.
    """
    from .helpers import max_diff as _helpers_max_diff

    return _helpers_max_diff(*args, **kwargs)
|
|
1364
|
+
|
|
1365
|
+
def use_dyn_not_str(self, *args, **kwargs):
    """
    Calls :func:`onnx_diagnostic.torch_export_patches.patch_inputs.use_dyn_not_str`
    with the same positional and keyword arguments.
    """
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

    # keyword arguments must be forwarded with ``**``, ``*kwargs`` would
    # pass the dictionary keys as extra positional arguments
    return use_dyn_not_str(*args, **kwargs)
|
|
1369
|
+
|
|
1301
1370
|
def subloop(self, *args, verbose: int = 0):
|
|
1302
1371
|
"Loops over elements and calls :meth:`unittests.TestCase.subTest`."
|
|
1303
1372
|
if len(args) == 1:
|
|
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
|
|
|
80
80
|
start = 0
|
|
81
81
|
end = 0
|
|
82
82
|
subtrees = []
|
|
83
|
-
for subspec in spec.children_specs:
|
|
83
|
+
for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
|
|
84
84
|
end += subspec.num_leaves
|
|
85
85
|
value = subspec.unflatten(flat[start:end])
|
|
86
86
|
value = flatten_unflatten_for_dynamic_shapes(
|