onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +2 -2
- onnx_diagnostic/_command_lines_parser.py +39 -1
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/export/dynamic_shapes.py +14 -5
- onnx_diagnostic/ext_test_case.py +15 -1
- onnx_diagnostic/helpers/args_helper.py +1 -1
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +30 -5
- onnx_diagnostic/helpers/model_builder_helper.py +349 -0
- onnx_diagnostic/helpers/rt_helper.py +69 -1
- onnx_diagnostic/helpers/torch_helper.py +2 -0
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/torch_evaluator.py +518 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
- onnx_diagnostic/tasks/__init__.py +22 -1
- onnx_diagnostic/tasks/image_classification.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +3 -3
- onnx_diagnostic/torch_export_patches/eval/__init__.py +690 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +883 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +148 -28
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +117 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/test_helper.py +225 -22
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,690 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import io
|
|
3
|
+
import itertools
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
6
|
+
import numpy as np
|
|
7
|
+
import onnx
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def discover():
    """
    Discovers all model cases used to evaluate an exporter.

    Every object defined in module ``model_cases`` which has a ``forward``
    attribute is considered as a model case. Each case must also define the
    class attributes ``_inputs`` and ``_dynamic``.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.eval import discover

        pprint.pprint(discover())

    :return: dictionary ``{case name: case class}``
    """
    from . import model_cases

    found = {}
    for candidate in model_cases.__dict__.values():
        # skip module constants and anything which is not a model
        if candidate is None or isinstance(candidate, str) or not hasattr(candidate, "forward"):
            continue
        case_name = candidate.__name__
        assert case_name not in found, f"Case {case_name!r} is duplicated."
        assert hasattr(
            candidate, "_inputs"
        ), f"Attribute '_inputs' is missing from class {candidate}"
        assert hasattr(
            candidate, "_dynamic"
        ), f"Attribute '_dynamic' is missing from class {candidate}"
        found[case_name] = candidate
    return found
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def evaluation(
    exporters: Tuple[str, ...] = (
        "export-strict",
        "export-nostrict",
        "export-nostrict-decall",
    ),
    dynamic: Tuple[bool, ...] = (False, True),
    cases: Optional[Union[str, Dict[str, type]]] = None,
    verbose: int = 0,
    quiet: bool = True,
) -> List[Dict[str, Any]]:
    """
    Evaluates exporter for a list of cases.

    Every combination of (case, dynamic, exporter) is run with :func:`run_exporter`.

    :param exporters: exporters to evaluate, a single name is accepted
    :param dynamic: evaluate static shape and dynamic shapes, a single boolean is accepted
    :param cases: model cases to evaluate, None for all cases returned by
        :func:`discover`, ``"three"`` for the first three, a case name,
        a list of names (a name containing ``*`` or ``?`` is used as a regular
        expression matched against every case name) or a dictionary
        ``{name: class}``
    :param verbose: verbosity, if > 0, progress is displayed with :epkg:`tqdm`
        if installed, every iteration is printed otherwise
    :param quiet: catch exception
    :return: results, list of dictionaries
    :raises AssertionError: if no case is selected
    """
    if isinstance(exporters, str):
        exporters = (exporters,)
    if isinstance(dynamic, (bool, int)):
        dynamic = (dynamic,)

    if cases is None:
        cases = discover()
    elif cases in ("three", ["three"]):
        all_cases = discover()
        cases = dict(list(all_cases.items())[:3])
    elif isinstance(cases, str):
        cases = (cases,)

    if isinstance(cases, (list, tuple)):
        all_cases = discover()
        selected = set()
        for c in cases:
            if "*" in c or "?" in c:
                # regex
                reg = re.compile(c)
                selected.update(k for k in all_cases if reg.match(k))
            else:
                selected.add(c)
        cases = {k: v for k, v in all_cases.items() if k in selected}

    sorted_cases = sorted(cases.items())
    loop = list(itertools.product(sorted_cases, dynamic, exporters))
    # checked before wrapping the list, a generator has no len()
    assert len(loop) > 0, f"No case to test for cases={cases!r}."

    iterator = loop
    if verbose:
        try:
            import tqdm

            iterator = tqdm.tqdm(loop)
        except ImportError:
            # the list is bound as a default value so that the generator
            # never iterates over itself
            def _loop(items=loop):
                for item in items:
                    print(f"[evaluation] {item}")
                    yield item

            iterator = _loop()

    obs = []
    for (name, cls_model), dyn, exporter in iterator:
        res = run_exporter(exporter, cls_model, dyn, quiet=quiet, verbose=max(0, verbose - 1))
        res.update(dict(name=name, dynamic=int(dyn), exporter=exporter))
        obs.append(res)
    return obs
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821
|
|
108
|
+
"""
|
|
109
|
+
Flatten inputs.
|
|
110
|
+
"""
|
|
111
|
+
if x is None:
|
|
112
|
+
return x
|
|
113
|
+
import torch
|
|
114
|
+
|
|
115
|
+
if isinstance(x, (list, tuple)):
|
|
116
|
+
res = []
|
|
117
|
+
for i in x:
|
|
118
|
+
if i is None or isinstance(
|
|
119
|
+
i,
|
|
120
|
+
(
|
|
121
|
+
torch.Tensor,
|
|
122
|
+
torch.SymInt,
|
|
123
|
+
torch.SymFloat,
|
|
124
|
+
int,
|
|
125
|
+
float,
|
|
126
|
+
),
|
|
127
|
+
):
|
|
128
|
+
res.append(i)
|
|
129
|
+
else:
|
|
130
|
+
res.extend(_flatten_inputs(i))
|
|
131
|
+
return tuple(res) if isinstance(x, tuple) else res
|
|
132
|
+
raise AssertionError(f"Unexpected type {type(x)} for x")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _to_numpy(x):
|
|
136
|
+
if hasattr(x, "numpy"):
|
|
137
|
+
return x.numpy()
|
|
138
|
+
if isinstance(x, int):
|
|
139
|
+
# onnxruntime does not like scalar
|
|
140
|
+
return np.array([x], dtype=np.int64)
|
|
141
|
+
if isinstance(x, float):
|
|
142
|
+
# onnxruntime does not like scalar
|
|
143
|
+
return np.array([x], dtype=np.float32)
|
|
144
|
+
if isinstance(x, list):
|
|
145
|
+
return [_to_numpy(_) for _ in x]
|
|
146
|
+
if isinstance(x, tuple):
|
|
147
|
+
return tuple(_to_numpy(_) for _ in x)
|
|
148
|
+
raise TypeError(f"Unable to convert type {type(x)}, x={x} into numpy")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _make_feeds(names, args):
    """
    Builds the onnxruntime feeds from the model input names and the
    positional arguments.

    If there are more names than arguments, the arguments are flattened
    first (nested lists/tuples become several inputs).

    :param names: onnx input names
    :param args: positional arguments given to the model
    :return: dictionary ``{name: numpy array}``
    :raises RuntimeError: if there are fewer names than arguments
    """
    n_names, n_args = len(names), len(args)
    if n_names < n_args:
        from ...helpers import string_type

        raise RuntimeError(
            f"Unable to handle names={names!r} and args={string_type(args, limit=20)}"
        )
    values = args if n_names == n_args else _flatten_inputs(args)
    return dict(zip(names, map(_to_numpy, values)))
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _clone(x):
|
|
165
|
+
if hasattr(x, "clone"):
|
|
166
|
+
return x.clone()
|
|
167
|
+
if isinstance(x, (int, float)):
|
|
168
|
+
return x
|
|
169
|
+
if isinstance(x, list):
|
|
170
|
+
return [_clone(_) for _ in x]
|
|
171
|
+
if isinstance(x, tuple):
|
|
172
|
+
return tuple(_clone(_) for _ in x)
|
|
173
|
+
raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy")
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _make_exporter_export(
|
|
177
|
+
exporter: str,
|
|
178
|
+
model: "torch.nn.Module", # noqa: F821
|
|
179
|
+
inputs: Tuple[Any, ...],
|
|
180
|
+
dynamic_shapes: Optional[Any] = None,
|
|
181
|
+
verbose: int = 0,
|
|
182
|
+
quiet: bool = True,
|
|
183
|
+
) -> Union[Dict, Callable]:
|
|
184
|
+
import torch
|
|
185
|
+
|
|
186
|
+
if exporter == "export-strict":
|
|
187
|
+
try:
|
|
188
|
+
if verbose >= 2:
|
|
189
|
+
exported = torch.export.export(
|
|
190
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
191
|
+
)
|
|
192
|
+
else:
|
|
193
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
194
|
+
io.StringIO()
|
|
195
|
+
):
|
|
196
|
+
exported = torch.export.export(
|
|
197
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
198
|
+
)
|
|
199
|
+
except Exception as e:
|
|
200
|
+
if not quiet:
|
|
201
|
+
raise
|
|
202
|
+
return dict(error=str(e), success=0, error_step="export")
|
|
203
|
+
if verbose >= 9:
|
|
204
|
+
print("-- graph")
|
|
205
|
+
print(exported.graph)
|
|
206
|
+
return exported.module()
|
|
207
|
+
if exporter in ("export-strict-dec", "export-strict-decall"):
|
|
208
|
+
try:
|
|
209
|
+
if verbose >= 2:
|
|
210
|
+
exported = torch.export.export(
|
|
211
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
212
|
+
)
|
|
213
|
+
if verbose >= 9:
|
|
214
|
+
print("-- graph before decomposition")
|
|
215
|
+
print(exported.graph)
|
|
216
|
+
exported = (
|
|
217
|
+
exported.run_decompositions()
|
|
218
|
+
if "decall" in exporter
|
|
219
|
+
else exported.run_decompositions({})
|
|
220
|
+
)
|
|
221
|
+
else:
|
|
222
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
223
|
+
io.StringIO()
|
|
224
|
+
):
|
|
225
|
+
exported = torch.export.export(
|
|
226
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
227
|
+
)
|
|
228
|
+
if verbose >= 9:
|
|
229
|
+
print("-- graph before decomposition")
|
|
230
|
+
print(exported.graph)
|
|
231
|
+
exported = (
|
|
232
|
+
exported.run_decompositions()
|
|
233
|
+
if "decall" in exporter
|
|
234
|
+
else exported.run_decompositions({})
|
|
235
|
+
)
|
|
236
|
+
except Exception as e:
|
|
237
|
+
if not quiet:
|
|
238
|
+
raise
|
|
239
|
+
return dict(error=str(e), success=0, error_step="export")
|
|
240
|
+
if verbose >= 9:
|
|
241
|
+
print("-- graph after decomposition")
|
|
242
|
+
print(exported.graph)
|
|
243
|
+
return exported.module()
|
|
244
|
+
if exporter == "export-nostrict":
|
|
245
|
+
try:
|
|
246
|
+
if verbose >= 2:
|
|
247
|
+
exported = torch.export.export(
|
|
248
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
249
|
+
)
|
|
250
|
+
else:
|
|
251
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
252
|
+
io.StringIO()
|
|
253
|
+
):
|
|
254
|
+
exported = torch.export.export(
|
|
255
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
256
|
+
)
|
|
257
|
+
except Exception as e:
|
|
258
|
+
if not quiet:
|
|
259
|
+
raise
|
|
260
|
+
return dict(error=str(e), success=0, error_step="export")
|
|
261
|
+
if verbose >= 9:
|
|
262
|
+
print("-- graph")
|
|
263
|
+
print(exported.graph)
|
|
264
|
+
return exported.module()
|
|
265
|
+
if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
|
|
266
|
+
try:
|
|
267
|
+
if verbose >= 2:
|
|
268
|
+
exported = torch.export.export(
|
|
269
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
270
|
+
)
|
|
271
|
+
if verbose >= 9:
|
|
272
|
+
print("-- graph before decomposition")
|
|
273
|
+
print(exported.graph)
|
|
274
|
+
exported = (
|
|
275
|
+
exported.run_decompositions()
|
|
276
|
+
if "decall" in exporter
|
|
277
|
+
else exported.run_decompositions({})
|
|
278
|
+
)
|
|
279
|
+
else:
|
|
280
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
281
|
+
io.StringIO()
|
|
282
|
+
):
|
|
283
|
+
exported = torch.export.export(
|
|
284
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
285
|
+
)
|
|
286
|
+
if verbose >= 9:
|
|
287
|
+
print("-- graph before decomposition")
|
|
288
|
+
print(exported.graph)
|
|
289
|
+
exported = (
|
|
290
|
+
exported.run_decompositions()
|
|
291
|
+
if "decall" in exporter
|
|
292
|
+
else exported.run_decompositions({})
|
|
293
|
+
)
|
|
294
|
+
except Exception as e:
|
|
295
|
+
if not quiet:
|
|
296
|
+
raise
|
|
297
|
+
return dict(error=str(e), success=0, error_step="export")
|
|
298
|
+
if verbose >= 9:
|
|
299
|
+
print("-- graph after decomposition")
|
|
300
|
+
print(exported.graph)
|
|
301
|
+
return exported.module()
|
|
302
|
+
if exporter == "export-tracing":
|
|
303
|
+
from experimental_experiment.torch_interpreter.tracing import CustomTracer
|
|
304
|
+
|
|
305
|
+
try:
|
|
306
|
+
if verbose >= 2:
|
|
307
|
+
graph = CustomTracer().trace(model)
|
|
308
|
+
mod = torch.fx.GraphModule(model, graph)
|
|
309
|
+
else:
|
|
310
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
311
|
+
io.StringIO()
|
|
312
|
+
):
|
|
313
|
+
graph = CustomTracer().trace(model)
|
|
314
|
+
mod = torch.fx.GraphModule(model, graph)
|
|
315
|
+
except Exception as e:
|
|
316
|
+
if not quiet:
|
|
317
|
+
raise
|
|
318
|
+
return dict(error=str(e), success=0, error_step="export")
|
|
319
|
+
if verbose >= 9:
|
|
320
|
+
print("-- graph")
|
|
321
|
+
print(graph)
|
|
322
|
+
return mod
|
|
323
|
+
raise AssertionError(f"Unexpected exporter={exporter!r}")
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _make_exporter_onnx(
    exporter: str,
    model: "torch.nn.Module",  # noqa: F821
    inputs: Tuple[Any, ...],
    dynamic_shapes: Optional[Any] = None,
    verbose: int = 0,
    quiet: bool = True,
) -> Union[Dict, Tuple[onnx.ModelProto, Any]]:
    """
    Exports a model into ONNX.

    Supported exporters:

    * ``custom[-nostrict][-fallback][-tracing][-jit][-dec|-decall]``:
      ``to_onnx`` from :epkg:`experimental-experiment`, the suffixes
      are translated into ``ExportOptions``
    * ``dynamo``: :func:`torch.onnx.export` with ``dynamo=True``
    * ``dynamo-ir``: same followed by ``optimize()``

    :param exporter: exporter name
    :param model: model to export
    :param inputs: positional arguments given to the model
    :param dynamic_shapes: dynamic shapes or None for static shapes
    :param verbose: verbosity, below 2, stdout and stderr are silenced
        during the export, from 2, :func:`torch.onnx.export` produces a report
    :param quiet: if True, an export failure is returned as a dictionary
        instead of being raised
    :return: ``(onnx model, builder)`` (builder is None for ``dynamo*``)
        or a dictionary describing the failure
    :raises AssertionError: if the exporter is unknown
    :raises RuntimeError: if the export fails and *quiet* is False
    """
    from ...helpers import string_type

    if exporter.startswith("custom"):
        from experimental_experiment.torch_interpreter import to_onnx, ExportOptions

        opts = dict(
            strict="-nostrict" not in exporter,
            fallback="-fallback" in exporter,
            tracing="-tracing" in exporter,
            jit="-jit" in exporter,
        )
        if "-dec" in exporter:
            opts["decomposition_table"] = "all" if "-decall" in exporter else "default"

        def _export():
            return to_onnx(
                model,
                inputs,
                dynamic_shapes=dynamic_shapes,
                export_options=ExportOptions(**opts),
                return_builder=True,
            )

    elif exporter in ("dynamo", "dynamo-ir"):
        import torch

        def _export():
            kwargs = dict(dynamic_shapes=dynamic_shapes, dynamo=True)
            if verbose >= 2:
                # the report is only requested when the output is not silenced
                kwargs["report"] = True
            program = torch.onnx.export(model, inputs, **kwargs)
            if exporter == "dynamo-ir":
                program.optimize()
            return program.model_proto, None

    else:
        raise AssertionError(f"Unexpected exporter={exporter!r}")

    try:
        with contextlib.ExitStack() as stack:
            if verbose < 2:
                stack.enter_context(contextlib.redirect_stdout(io.StringIO()))
                stack.enter_context(contextlib.redirect_stderr(io.StringIO()))
            onx, builder = _export()
    except Exception as e:
        if not quiet:
            # all inputs are reported, inputs[0] would fail on an empty tuple
            # and hide the export error
            raise RuntimeError(
                f"Unable to convert model={model.__class__.__name__}, "
                f"input={string_type(inputs, with_shape=True)}, "
                f"dynamic_shapes={dynamic_shapes}, "
                f"exporter={exporter!r}"
            ) from e
        return dict(error=str(e), success=0, error_step="export")
    return onx, builder
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def run_exporter(
    exporter: str,
    cls_model: type,
    dynamic: bool = False,
    quiet: bool = False,
    verbose: int = 0,
) -> Dict[str, Any]:
    """
    Runs an exporter and returns whether it fails or not.

    The model created from *cls_model* is exported with *exporter*. The exported
    module (or the ONNX model run with onnxruntime) is executed on the first
    set of inputs and compared with the eager model. If *dynamic* is True and
    the class defines several sets of inputs, every one of them is checked.
    A discrepancy is accepted if the maximum absolute difference is below 0.1.

    :param exporter: exporter, names starting with ``export-`` are handled by
        :func:`_make_exporter_export`, the others by :func:`_make_exporter_onnx`
    :param cls_model: model class to create, it must define ``_inputs``
        (a tuple of positional arguments or a list of such tuples) and
        ``_dynamic`` (the dynamic shapes) when *dynamic* is True
    :param dynamic: use dynamic shape or not
    :param quiet: if True, failures are reported in the returned dictionary
        instead of raising an exception
    :param verbose: verbosity
    :return: results, a dictionary with at least key ``success`` (0 or 1),
        keys ``error`` and ``error_step`` describe a failure
    """
    from onnx_diagnostic.helpers import max_diff, string_type
    from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

    assert hasattr(
        cls_model, "_inputs"
    ), f"Attribute '_inputs' is missing from class {cls_model}"

    model = cls_model()
    inputs = cls_model._inputs
    if isinstance(inputs, tuple):
        # a single set of inputs
        inputs = [inputs]
    if dynamic:
        assert hasattr(
            cls_model, "_dynamic"
        ), f"Attribute '_dynamic' is missing from class {cls_model}"
        dynamic_shapes = cls_model._dynamic
    else:
        dynamic_shapes = None

    base = dict(inputs=inputs, model=model, dynamic_shapes=dynamic_shapes)

    def _failure(error: str, error_step: str) -> Dict[str, Any]:
        # every failure detected here carries the same context,
        # base is read when the failure happens
        res: Dict[str, Any] = dict(error=error, success=0, error_step=error_step)
        res.update(base)
        return res

    if verbose > 0:
        print(
            f"[run_exporter] exporter={exporter}, model={cls_model.__name__}, "
            f"dynamic={dynamic}, inputs={string_type(inputs, with_shape=True)}"
        )

    builder = None
    onx = None

    if exporter.startswith("export-"):
        mod = _make_exporter_export(
            exporter,
            model,
            inputs[0],
            dynamic_shapes=dynamic_shapes,
            verbose=verbose,
            quiet=quiet,
        )
        if isinstance(mod, dict):
            # something went wrong
            return mod
    else:
        res = _make_exporter_onnx(
            exporter,
            model,
            inputs[0],
            dynamic_shapes=dynamic_shapes,
            verbose=verbose,
            quiet=quiet,
        )
        if isinstance(res, dict):
            # something went wrong
            return res

        onx, builder = res
        if verbose >= 9:
            print("[run_exporter] onnx model")
            print(
                builder.pretty_text(add_fx_graph=True)
                if builder is not None
                else pretty_onnx(onx)
            )
        if verbose >= 2:
            onnx.save(onx, f"evaluation-{model.__class__.__name__}-{dynamic}-{exporter}.onnx")

        names = [i.name for i in onx.graph.input]
        flats = _flatten_inputs(inputs[0]) if len(names) > len(inputs[0]) else inputs[0]

        assert quiet or len(names) == len(flats), (
            f"Input mismatch, inputs[0]={string_type(inputs[0])} "
            f"inputs but names={names!r}, "
            f"model={cls_model.__name__}, export={exporter!r}"
        )
        if len(names) != len(flats):
            return _failure(
                f"Input mismatch, inputs[0]={string_type(inputs[0])} "
                f"but names={names!r}, model={cls_model.__name__}, export={exporter!r}",
                "inputs",
            )

        import onnxruntime

        try:
            sess = onnxruntime.InferenceSession(
                onx.SerializeToString(), providers=["CPUExecutionProvider"]
            )
        except Exception as e:
            if not quiet:
                raise
            return _failure(str(e), "ort-init")

        mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args))  # noqa: E731

    # we need to clone for models modifying the inputs
    try:
        expected = model(*_clone(inputs[0]))
    except Exception as e:
        if not quiet:
            raise RuntimeError(
                f"eager mode failed=\n{string_type(inputs[0], with_shape=True)} "
                f"\nmodel=\n{type(model)}"
            ) from e
        return _failure(str(e), "eager")
    try:
        got = mod(*inputs[0])
    except Exception as e:
        if not quiet:
            # onx is None for export-* exporters, pretty_onnx cannot be called then
            raise RuntimeError(
                f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} "
                f"\nmodel=\n{pretty_onnx(onx) if onx is not None else type(model)}"
            ) from e
        return _failure(str(e), "run.0")

    base["expected"] = expected
    base["obtained"] = got

    try:
        disc = max_diff(expected, got)
    except Exception as e:
        if not quiet:
            raise
        return _failure(str(e), "discrepancy")

    if verbose >= 5 and np.isinf(disc["abs"]):
        print("[run_exporter] comparison issues with")
        print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
        print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
        print(f"-- got={string_type(got, with_shape=True, limit=20)}")
    elif verbose >= 9:
        print("[run_exporter] inputs and outputs")
        print(
            f"-- inputs="
            f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
        )
        print(
            f"-- expected="
            f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
        )
        print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
    disc.pop("n", None)
    disc.pop("sum", None)
    # written as `abs < 0.1` so that a NaN discrepancy is a failure,
    # `abs >= 0.1` is False for NaN
    ok = bool(disc["abs"] < 0.1)
    disc.update(
        dict(
            success=1 if ok else 0,
            model_cls=model.__class__,
            exported=mod,  # type: ignore[dict-item]
            onnx=onx,  # type: ignore[dict-item]
        )
    )
    if not ok:
        disc["error"] = "diff.0"
        disc["error_step"] = "diff.0"
        if verbose >= 9:
            max_diff(expected, got, verbose=verbose)

    if dynamic and onx is not None:
        ds = []
        for i in onx.graph.input:
            if i.type.HasField("tensor_type"):
                for di, dim in enumerate(i.type.tensor_type.shape.dim):
                    if dim.dim_param:
                        ds.append((i.name, di, dim.dim_param))
        if verbose >= 2:
            print(f"[run_exporter] dynamic dimension={ds}")
        if not ds:
            return _failure("no dynamic shape", "dynamic")

    if dynamic and len(inputs) > 1:
        for index, i in enumerate(inputs):
            try:
                expected = model(*_clone(i))
            except Exception as e:
                if not quiet:
                    raise
                return _failure(str(e), f"eager.{index}")
            try:
                got = mod(*i)
            except Exception as e:
                if not quiet:
                    raise RuntimeError(
                        f"onnxruntime failed,\n-- feeds=\n{string_type(i, with_shape=True)} "
                        f"exporter={exporter!r}, dynamic_shapes={dynamic_shapes}"
                        f"\n-- model=\n{pretty_onnx(onx) if onx is not None else type(model)}"
                    ) from e
                return _failure(str(e), f"run.{index}")

            try:
                d = max_diff(expected, got)
            except Exception as e:
                if not quiet:
                    raise
                return _failure(str(e), f"discrepancy.{index}")

            if verbose >= 5 and np.isinf(d["abs"]):
                print(f"[run_exporter] comparison issues iteration {index}")
                print(f"-- inputs={string_type(i, with_shape=True)}")
                print(f"-- expected={string_type(expected, with_shape=True)}")
                print(f"-- got={string_type(got, with_shape=True)}")
            elif verbose >= 9:
                print(f"[run_exporter] inputs and outputs iteration {index}")
                print(f"-- inputs={string_type(i, with_shape=True, with_min_max=True)}")
                print(
                    f"-- expected={string_type(expected, with_shape=True, with_min_max=True)}"
                )
                print(f"-- got={string_type(got, with_shape=True, with_min_max=True)}")
            d.pop("n", None)
            d.pop("sum", None)
            # NaN must be a failure as well
            if not (d["abs"] < 0.1):
                d["error"] = f"diff.{index}"
                d["error_step"] = f"diff.{index}"
                d["success"] = 0
            disc.update(d)

    disc.update(base)
    return disc
|