onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1290 @@
|
|
|
1
|
+
"""
|
|
2
|
+
The module contains the main class ``ExtTestCase`` which adds
|
|
3
|
+
specific functionalities to this project.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import copy
|
|
7
|
+
import glob
|
|
8
|
+
import itertools
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import re
|
|
12
|
+
import sys
|
|
13
|
+
import unittest
|
|
14
|
+
import warnings
|
|
15
|
+
from contextlib import redirect_stderr, redirect_stdout
|
|
16
|
+
from io import StringIO
|
|
17
|
+
from timeit import Timer
|
|
18
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
|
19
|
+
import numpy
|
|
20
|
+
from numpy.testing import assert_allclose
|
|
21
|
+
|
|
22
|
+
# Values treated as "enabled" when reading a flag (``requires_zoo`` checks the
# ``ZOO`` environment variable against it). Environment values are strings,
# so only the string entries can match an ``os.environ`` lookup.
BOOLEAN_VALUES = (1, "1", True, "True", "true", "TRUE")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_azure() -> bool:
    """
    Tells if the job is running on Azure DevOps, detected through the
    environment variable ``AZURE_HTTP_USER_AGENT``.

    :return: True if the variable is set to anything but ``"undefined"``
    """
    agent = os.environ.get("AZURE_HTTP_USER_AGENT", "undefined")
    return agent != "undefined"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def is_windows() -> bool:
    """Tells if the interpreter runs on Windows (``sys.platform == "win32"``)."""
    current = sys.platform
    return current == "win32"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def is_apple() -> bool:
    """Tells if the interpreter runs on macOS (``sys.platform == "darwin"``)."""
    current = sys.platform
    return current == "darwin"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def is_linux() -> bool:
    """Tells if the interpreter runs on Linux (``sys.platform == "linux"``)."""
    current = sys.platform
    return current == "linux"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def skipif_ci_windows(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.

    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator on Azure/Windows, the identity otherwise
    """
    if not (is_windows() and is_azure()):
        return lambda x: x
    return unittest.skip(f"Test does not work on azure pipeline (Windows). {msg}")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def skipif_ci_linux(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Linux`.

    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator on Azure/Linux, the identity otherwise
    """
    if not (is_linux() and is_azure()):
        return lambda x: x
    return unittest.skip(f"Takes too long (Linux). {msg}")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def skipif_ci_apple(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS (Apple).

    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator on Azure/macOS, the identity otherwise
    """
    if is_apple() and is_azure():
        msg = f"Test does not work on azure pipeline (Apple). {msg}"
        return unittest.skip(msg)
    return lambda x: x
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def unit_test_going() -> bool:
    """
    Tells if the environment variable ``UNITTEST_GOING`` is set to ``1``,
    meaning the script is running while being tested.
    Scripts may check it to avoid very long computations during unit tests.

    :return: True if ``UNITTEST_GOING`` equals 1, False if it is unset or any
        other integer
    :raises ValueError: if ``UNITTEST_GOING`` is not an integer
    """
    going = int(os.environ.get("UNITTEST_GOING", 0))
    return going == 1
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def ignore_warnings(warns: Optional[Union[type, str, Sequence[Union[type, str]]]]) -> Callable:
    """
    Catches warnings.

    Returns a decorator for test methods ``fct(self)``: the decorated method
    runs with the given warning categories ignored.

    :param warns: warning category or list/tuple of categories to ignore,
        the string ``"TracerWarning"`` is replaced by ``torch.jit.TracerWarning``
    :return: decorator
    :raises AssertionError: when the decorator is applied while *warns* is None
    """
    if warns is None:
        # The None check must happen before the value is wrapped into a tuple,
        # otherwise ``(None,)`` would slip through and fail later inside the
        # warnings machinery.
        categories = None
    else:
        if not isinstance(warns, (tuple, list)):
            warns = (warns,)
        resolved = []
        for w in warns:
            if w == "TracerWarning":
                # torch is only imported when this category is requested
                from torch.jit import TracerWarning

                resolved.append(TracerWarning)
            else:
                resolved.append(w)
        categories = tuple(resolved)

    def wrapper(fct):
        if categories is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")

        def call_f(self):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", categories)
                return fct(self)

        try:  # noqa: SIM105
            call_f.__name__ = fct.__name__
        except AttributeError:
            pass
        return call_f

    return wrapper
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def ignore_errors(errors: Union[Exception, Tuple[Exception]]) -> Callable:
    """
    Turns expected exceptions into skipped tests, useful when an error is
    expected to happen sometimes.

    :param errors: exception class (or tuple of classes) which, when raised by
        the decorated method, skips the test instead of failing it
    :return: decorator for methods ``fct(self)``
    :raises AssertionError: when the decorator is applied while *errors* is None
    """

    def decorator(test_method):
        if errors is None:
            raise AssertionError(f"errors cannot be None for '{test_method}'.")

        def run(self):
            try:
                outcome = test_method(self)
            except errors as exc:
                reason = f"expecting error {exc.__class__.__name__}: {exc}"
                raise unittest.SkipTest(reason)  # noqa: B904
            return outcome

        if hasattr(test_method, "__name__"):
            run.__name__ = test_method.__name__
        return run

    return decorator
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def hide_stdout(f: Optional[Callable] = None) -> Callable:
    """
    Hides the standard output of a test method and ignores
    ``UserWarning`` and ``DeprecationWarning`` while it runs.
    Setting ``UNHIDE`` to a non-empty value before running the unit test
    disables both: the method then runs as is and *f* is not called.
    An ``AssertionError`` mentioning ``"torch is not recent enough, file"``
    turns into a skipped test.

    :param f: optional callback receiving the captured standard output
        (a string) once the method has completed
    :return: decorator for methods ``fct(self)``, the decorated method returns None
    """

    def decorator(test_method):
        def run(self):
            if os.environ.get("UNHIDE", ""):
                test_method(self)
                return
            buffer = StringIO()
            with redirect_stdout(buffer), warnings.catch_warnings():
                warnings.simplefilter("ignore", (UserWarning, DeprecationWarning))
                try:
                    test_method(self)
                except AssertionError as exc:
                    if "torch is not recent enough, file" not in str(exc):
                        raise
                    raise unittest.SkipTest(str(exc))  # noqa: B904
            if f is not None:
                f(buffer.getvalue())
            return None

        if hasattr(test_method, "__name__"):
            run.__name__ = test_method.__name__
        return run

    return decorator
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def long_test(msg: str = "") -> Callable:
    """
    Skips a long unit test unless the environment variable ``LONGTEST`` is set
    to a value other than ``"0"``, ``"False"`` or ``"false"``.

    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    # environment values are always strings
    if os.environ.get("LONGTEST", "0") in ("0", "False", "false"):
        msg = f"Skipped (set LONGTEST=1 to run it). {msg}"
        return unittest.skip(msg)
    return lambda x: x
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def never_test(msg: str = "") -> Callable:
    """
    Skips a unit test unless the environment variable ``NEVERTEST`` is set
    to a value other than ``"0"``, ``"False"`` or ``"false"``.

    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    # environment values are always strings
    if os.environ.get("NEVERTEST", "0") in ("0", "False", "false"):
        msg = f"Skipped (set NEVERTEST=1 to run it). {msg}"
        return unittest.skip(msg)
    return lambda x: x
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def measure_time(
    stmt: Union[str, Callable],
    context: Optional[Dict[str, Any]] = None,
    repeat: int = 10,
    number: int = 50,
    warmup: int = 1,
    div_by_number: bool = True,
    max_time: Optional[float] = None,
) -> Dict[str, Union[str, int, float]]:
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: string or callable
    :param context: variables made available to *stmt* when it is a string
        (used as the globals of :class:`timeit.Timer`); if it contains a key
        ``"values"``, its first dimension (or length) is reported as ``size``
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param warmup: number of iteration to do before starting the
        real measurement
    :param div_by_number: divide by the number of executions
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary with keys ``average``, ``deviation``, ``min_exec``,
        ``max_exec``, ``repeat``, ``number``, ``ttime``, ``warmup_time``
        and either ``size`` or ``context_size``
    :raises TypeError: if *stmt* is neither a string nor a callable
    :raises ValueError: if *max_time* is defined and *div_by_number* is False

    .. runpython::
        :showcode:

        from pprint import pprint
        from math import cos
        from onnx_diagnostic.ext_test_case import measure_time

        res = measure_time(lambda: cos(0.5))
        pprint(res)

    See `Timer.repeat <https://docs.python.org/3/library/
    timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    When *div_by_number* is True (the default), the durations correspond to
    one execution of the statement, otherwise to *number* executions.
    """
    if not callable(stmt) and not isinstance(stmt, str):
        raise TypeError(f"stmt is not callable or a string but is of type {type(stmt)!r}.")
    if context is None:
        context = {}

    tim = Timer(stmt, globals=context) if isinstance(stmt, str) else Timer(stmt)
    warmup_time = tim.timeit(warmup) if warmup > 0 else 0

    if max_time is not None:
        if not div_by_number:
            raise ValueError("div_by_number must be set to True if max_time is defined.")
        i = 1
        total_time = 0.0
        results = []
        while True:
            for j in (1, 2):
                n_exec = i * j
                time_taken = tim.timeit(n_exec)
                results.append((n_exec, time_taken))
                total_time += time_taken
                if total_time >= max_time:
                    break
            if total_time >= max_time:
                break
            # Scales the number of executions to get closer to max_time.
            # A timer with a coarse resolution may report 0 for a very fast
            # statement, the number of executions is doubled in that case.
            ratio = (max_time - total_time) / total_time if total_time > 0 else 2.0
            i = int(i * max(ratio, 1))

        res = numpy.array(results)
        tw = res[:, 0].sum()
        ttime = res[:, 1].sum()
        mean = ttime / tw
        ave = res[:, 1] / res[:, 0]
        # standard deviation of the per-execution time weighted by the
        # number of executions of every measure
        dev = (((ave - mean) ** 2 * res[:, 0]).sum() / tw) ** 0.5
        mes = dict(
            average=mean,
            deviation=dev,
            min_exec=numpy.min(ave),
            max_exec=numpy.max(ave),
            repeat=1,
            number=tw,
            ttime=ttime,
        )
    else:
        res = numpy.array(tim.repeat(repeat=repeat, number=number))
        if div_by_number:
            res /= number

        mean = numpy.mean(res)
        # numpy.std computes sqrt(mean((x - mean)**2)) which never becomes
        # negative, unlike mean(x**2) - mean**2 (NaN after rounding errors).
        dev = numpy.std(res)
        mes = dict(
            average=mean,
            deviation=dev,
            min_exec=numpy.min(res),
            max_exec=numpy.max(res),
            repeat=repeat,
            number=number,
            ttime=res.sum(),
        )

    if "values" in context:
        if hasattr(context["values"], "shape"):
            mes["size"] = context["values"].shape[0]
        else:
            mes["size"] = len(context["values"])
    else:
        mes["context_size"] = sys.getsizeof(context)
    mes["warmup_time"] = warmup_time
    return mes
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def statistics_on_folder(
|
|
311
|
+
folder: Union[str, List[str]],
|
|
312
|
+
pattern: str = ".*[.]((py|rst))$",
|
|
313
|
+
aggregation: int = 0,
|
|
314
|
+
) -> List[Dict[str, Union[int, float, str]]]:
|
|
315
|
+
"""
|
|
316
|
+
Computes statistics on files in a folder.
|
|
317
|
+
|
|
318
|
+
:param folder: folder or folders to investigate
|
|
319
|
+
:param pattern: file pattern
|
|
320
|
+
:param aggregation: show the first subfolders
|
|
321
|
+
:return: list of dictionaries
|
|
322
|
+
|
|
323
|
+
.. runpython::
|
|
324
|
+
:showcode:
|
|
325
|
+
:toggle:
|
|
326
|
+
|
|
327
|
+
import os
|
|
328
|
+
import pprint
|
|
329
|
+
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__
|
|
330
|
+
|
|
331
|
+
pprint.pprint(statistics_on_folder(os.path.dirname(__file__)))
|
|
332
|
+
|
|
333
|
+
Aggregated:
|
|
334
|
+
|
|
335
|
+
.. runpython::
|
|
336
|
+
:showcode:
|
|
337
|
+
:toggle:
|
|
338
|
+
|
|
339
|
+
import os
|
|
340
|
+
import pprint
|
|
341
|
+
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__
|
|
342
|
+
|
|
343
|
+
pprint.pprint(statistics_on_folder(os.path.dirname(__file__), aggregation=1))
|
|
344
|
+
"""
|
|
345
|
+
if isinstance(folder, list):
|
|
346
|
+
rows = []
|
|
347
|
+
for fold in folder:
|
|
348
|
+
last = fold.replace("\\", "/").split("/")[-1]
|
|
349
|
+
r = statistics_on_folder(
|
|
350
|
+
fold, pattern=pattern, aggregation=max(aggregation - 1, 0)
|
|
351
|
+
)
|
|
352
|
+
if aggregation == 0:
|
|
353
|
+
rows.extend(r)
|
|
354
|
+
continue
|
|
355
|
+
for line in r:
|
|
356
|
+
line["dir"] = os.path.join(last, line["dir"])
|
|
357
|
+
rows.extend(r)
|
|
358
|
+
return rows
|
|
359
|
+
|
|
360
|
+
rows = []
|
|
361
|
+
reg = re.compile(pattern)
|
|
362
|
+
for name in glob.glob("**/*", root_dir=folder, recursive=True):
|
|
363
|
+
if not reg.match(name):
|
|
364
|
+
continue
|
|
365
|
+
if os.path.isdir(os.path.join(folder, name)):
|
|
366
|
+
continue
|
|
367
|
+
n = name.replace("\\", "/")
|
|
368
|
+
spl = n.split("/")
|
|
369
|
+
level = len(spl)
|
|
370
|
+
stat = statistics_on_file(os.path.join(folder, name))
|
|
371
|
+
stat["name"] = name
|
|
372
|
+
if aggregation <= 0:
|
|
373
|
+
rows.append(stat)
|
|
374
|
+
continue
|
|
375
|
+
spl = os.path.dirname(name).replace("\\", "/").split("/")
|
|
376
|
+
level = "/".join(spl[:aggregation])
|
|
377
|
+
stat["dir"] = level
|
|
378
|
+
rows.append(stat)
|
|
379
|
+
return rows
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def get_figure(ax):
    """
    Returns the figure of a matplotlib axis or of an array of axes.

    :param ax: an object with a ``get_figure`` method (a matplotlib axis)
        or a numpy array of such objects with 0, 1 or 2 dimensions
        (e.g. what :func:`matplotlib.pyplot.subplots` returns)
    :return: the result of ``get_figure()`` called on the first axis
    :raises RuntimeError: if *ax* is an array with more than 2 dimensions
    """
    if hasattr(ax, "get_figure"):
        return ax.get_figure()
    if len(ax.shape) > 2:
        raise RuntimeError(f"Unexpected shape {ax.shape} for axis.")
    # ``flat[0]`` returns the first element whatever the number of
    # dimensions, including 0-d arrays which cannot be indexed with ax[0]
    # and have no get_figure method themselves.
    return ax.flat[0].get_figure()
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def has_cuda() -> bool:
    """
    Tells if :epkg:`pytorch` sees at least one CUDA device,
    i.e. ``torch.cuda.device_count() > 0``.
    """
    import torch

    device_count = torch.cuda.device_count()
    return device_count > 0
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def requires_python(version: Tuple[int, ...], msg: str = ""):
    """
    Skips a test if python is too old.

    :param msg: to overwrite the message
    :param version: minimum version, compared to the same number of leading
        items of ``sys.version_info``
    :return: a ``unittest.skip`` decorator or the identity
    """
    current = sys.version_info[: len(version)]
    if current >= version:
        return lambda x: x
    return unittest.skip(msg or f"python not recent enough {sys.version_info} < {version}")
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def requires_cuda(msg: str = "", version: str = "", memory: int = 0):
    """
    Skips a test if cuda is not available.

    :param msg: to overwrite the message
    :param version: minimum CUDA version, compared to ``torch.version.cuda``
    :param memory: minimum number of Gb of the first device to run the test
    :return: a ``unittest.skip`` decorator or the identity
    """
    import torch

    if torch.cuda.device_count() == 0:
        return unittest.skip(msg or "only runs on CUDA but torch does not have it")

    if version:
        # the module is ``packaging.version`` (``packaging.versions`` does not exist)
        import packaging.version as pv

        cuda_version = torch.version.cuda
        if cuda_version is None:
            # devices are visible but torch was not built with CUDA
            # (e.g. a ROCm build), the version cannot be compared
            return unittest.skip(
                msg or f"torch is not built with CUDA, cannot check version {version}"
            )
        if pv.Version(cuda_version) < pv.Version(version):
            return unittest.skip(msg or f"CUDA older than {version}")

    if memory:
        m = torch.cuda.get_device_properties(0).total_memory / 2**30
        if m < memory:
            return unittest.skip(msg or f"available memory is not enough {m} < {memory} (Gb)")

    return lambda x: x
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def requires_zoo(msg: str = "") -> Callable:
    """
    Skips a unit test unless the environment variable ``ZOO`` holds one of
    the values in ``BOOLEAN_VALUES`` (such as ``1`` or ``true``).

    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    if os.environ.get("ZOO", "0") not in BOOLEAN_VALUES:
        return unittest.skip(f"ZOO not set up or != 1. {msg}")
    return lambda x: x
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def requires_sklearn(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`scikit-learn` is not recent enough.

    :param version: minimum version
    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv
    import sklearn

    installed = sklearn.__version__
    if pv.Version(installed) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"scikit-learn version {installed} < {version}: {msg}")
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def requires_experimental(version: str = "0.0.0", msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`experimental-experiment` is not installed
    or not recent enough.

    :param version: minimum version
    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv

    try:
        import experimental_experiment
    except ImportError:
        return unittest.skip(f"experimental-experiment not installed: {msg}")

    installed = experimental_experiment.__version__
    if pv.Version(installed) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"experimental-experiment version {installed} < {version}: {msg}")
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def has_torch(version: str) -> bool:
    """
    Tells if the installed :epkg:`pytorch` version is greater than or equal
    to *version*.

    :param version: minimum version
    :return: boolean
    """
    import packaging.version as pv
    import torch

    return pv.Version(torch.__version__) >= pv.Version(version)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def has_transformers(version: str) -> bool:
    """
    Tells if the installed :epkg:`transformers` version is greater than or
    equal to *version*.

    :param version: minimum version
    :return: boolean
    """
    import packaging.version as pv
    import transformers

    installed = pv.Version(transformers.__version__)
    required = pv.Version(version)
    return installed >= required
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def requires_torch(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`pytorch` is not recent enough.

    :param version: minimum version
    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv
    import torch

    installed = torch.__version__
    if pv.Version(installed) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"torch version {installed} < {version}: {msg}")
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def requires_numpy(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`numpy` is not recent enough.

    :param version: minimum version
    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv
    import numpy

    installed = numpy.__version__
    if pv.Version(installed) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"numpy version {installed} < {version}: {msg}")
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def requires_transformers(
    version: str, msg: str = "", or_older_than: Optional[str] = None
) -> Callable:
    """
    Skips a unit test if :epkg:`transformers` is not installed, not recent
    enough, or more recent than *or_older_than*.

    :param version: minimum version
    :param msg: text appended to the skip message
    :param or_older_than: if specified, skips the test when the installed
        version is strictly greater than this one
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv

    try:
        import transformers
    except ImportError:
        msg = f"transformers not installed {msg}"
        return unittest.skip(msg)

    v = pv.Version(transformers.__version__)
    if v < pv.Version(version):
        msg = f"transformers version {transformers.__version__} < {version}: {msg}"
        return unittest.skip(msg)
    if or_older_than and v > pv.Version(or_older_than):
        msg = f"transformers version {transformers.__version__} > {or_older_than}: {msg}"
        return unittest.skip(msg)
    return lambda x: x
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def requires_diffusers(
    version: str, msg: str = "", or_older_than: Optional[str] = None
) -> Callable:
    """
    Skips a unit test if :epkg:`diffusers` is not installed, not recent
    enough, or more recent than *or_older_than*.

    :param version: minimum version
    :param msg: text appended to the skip message
    :param or_older_than: if specified, skips the test when the installed
        version is strictly greater than this one
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv

    try:
        import diffusers
    except ImportError:
        msg = f"diffusers not installed {msg}"
        return unittest.skip(msg)

    v = pv.Version(diffusers.__version__)
    if v < pv.Version(version):
        msg = f"diffusers version {diffusers.__version__} < {version} {msg}"
        return unittest.skip(msg)
    if or_older_than and v > pv.Version(or_older_than):
        msg = f"diffusers version {diffusers.__version__} > {or_older_than} {msg}"
        return unittest.skip(msg)
    return lambda x: x
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
def requires_onnxscript(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`onnxscript` is not recent enough.
    A development version without ``__version__`` is always accepted.

    :param version: minimum version
    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv
    import onnxscript

    if hasattr(onnxscript, "__version__"):
        installed = onnxscript.__version__
        if pv.Version(installed) < pv.Version(version):
            return unittest.skip(f"onnxscript version {installed} < {version}: {msg}")
    # development versions (no __version__) end up here as well
    return lambda x: x
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def has_onnxscript(version: str, msg: str = "") -> bool:
    """
    Tells if :epkg:`onnxscript` is installed and recent enough.

    :param version: minimum version
    :param msg: unused, kept for backward compatibility
    :return: True if onnxscript is a development version (no ``__version__``)
        or its version is greater than or equal to *version*, False if it is
        older or not installed
    """
    import packaging.version as pv

    try:
        import onnxscript
    except ImportError:
        # consistent with has_onnxruntime_genai: a missing package is not an error
        return False

    if not hasattr(onnxscript, "__version__"):
        # development version
        return True
    return pv.Version(onnxscript.__version__) >= pv.Version(version)
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
def requires_onnxruntime(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`onnxruntime` is not recent enough.

    :param version: minimum version
    :param msg: text appended to the skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    import packaging.version as pv
    import onnxruntime

    installed = onnxruntime.__version__
    if pv.Version(installed) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"onnxruntime version {installed} < {version}: {msg}")
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def has_onnxruntime_training(push_back_batch: bool = False):
    """
    Tells if onnxruntime_training is installed.

    :param push_back_batch: also requires ``OrtValueVector`` to expose
        a method ``push_back_batch``
    :return: boolean
    """
    try:
        from onnxruntime import training
    except ImportError:
        # onnxruntime missing or not the training build
        return False
    if training is None:
        return False
    if not push_back_batch:
        return True

    try:
        from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
    except ImportError:
        return False
    return hasattr(OrtValueVector, "push_back_batch")
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
def has_onnxruntime_genai():
    """
    Tells if onnxruntime_genai is installed, i.e. can be imported.

    :return: boolean
    """
    try:
        import onnxruntime_genai  # noqa: F401
    except ImportError:
        # onnxruntime_genai is not available
        return False
    return True
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def requires_onnxruntime_training(
    push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
) -> Callable:
    """
    Skips a unit test if :epkg:`onnxruntime` is not onnxruntime_training.

    :param push_back_batch: also requires ``OrtValueVector.push_back_batch``
    :param ortmodule: also requires ``onnxruntime.training.ortmodule``
        to be importable
    :param msg: overwrites the default skip message
    :return: a ``unittest.skip`` decorator or the identity
    """
    try:
        from onnxruntime import training
    except ImportError:
        # onnxruntime missing or not the training build
        training = None
    if training is None:
        return unittest.skip(msg or "onnxruntime_training is not installed")

    if push_back_batch:
        try:
            from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
        except ImportError:
            OrtValueVector = None
        if OrtValueVector is None or not hasattr(OrtValueVector, "push_back_batch"):
            return unittest.skip(msg or "OrtValue has no method push_back_batch")

    if ortmodule:
        try:
            import onnxruntime.training.ortmodule  # noqa: F401
        except (AttributeError, ImportError):
            return unittest.skip(msg or "ortmodule is missing in onnxruntime-training")

    return lambda x: x
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def requires_onnx(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`onnx` is not installed or not recent enough.

    :param version: minimal required version
    :param msg: text appended to the skip reason
    :return: a :func:`unittest.skip` decorator or the identity
    """
    import packaging.version as pv

    try:
        import onnx
    except ImportError:
        # The decorator is evaluated when the test module is imported:
        # raising here would break every test of the module, skip instead.
        reason = "onnx is not installed"
        return unittest.skip(f"{reason}: {msg}" if msg else reason)

    if pv.Version(onnx.__version__) < pv.Version(version):
        msg = f"onnx version {onnx.__version__} < {version}: {msg}"
        return unittest.skip(msg)
    return lambda x: x
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
def requires_onnx_array_api(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`onnx-array-api` is not installed
    or not recent enough.

    :param version: minimal required version
    :param msg: text appended to the skip reason
    :return: a :func:`unittest.skip` decorator or the identity
    """
    import packaging.version as pv

    try:
        import onnx_array_api
    except ImportError:
        # The decorator is evaluated when the test module is imported:
        # raising here would break every test of the module, skip instead.
        reason = "onnx-array-api is not installed"
        return unittest.skip(f"{reason}: {msg}" if msg else reason)

    if pv.Version(onnx_array_api.__version__) < pv.Version(version):
        msg = f"onnx-array-api version {onnx_array_api.__version__} < {version}: {msg}"
        return unittest.skip(msg)
    return lambda x: x
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def statistics_on_file(filename: str) -> Dict[str, Union[int, float, str]]:
    """
    Computes statistics on a file.

    For a text file (extension ``.py``, ``.rst``, ``.md`` or ``.txt``), it returns
    ``lines`` (number of non-empty lines containing at least one letter or digit;
    lines with only a bracket or a comma are ignored), ``chars`` (number of
    characters of the stripped non-empty lines, spaces excluded) and ``ext``.
    For any other file, it only returns ``size``, the file size in bytes.

    :param filename: file to analyse
    :return: dictionary of statistics
    :raises AssertionError: if the file does not exist

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.ext_test_case import statistics_on_file, __file__

        pprint.pprint(statistics_on_file(__file__))
    """
    # Explicit raise instead of ``assert`` so that the check survives ``python -O``.
    if not os.path.exists(filename):
        raise AssertionError(f"File {filename!r} does not exists.")

    ext = os.path.splitext(filename)[-1]
    if ext not in {".py", ".rst", ".md", ".txt"}:
        return {"size": os.stat(filename).st_size}

    alpha = set("abcdefghijklmnopqrstuvwxyz0123456789")
    n_line = 0
    n_ch = 0
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip("\n\r\t ")
            if not s:
                continue
            n_ch += len(s.replace(" ", ""))
            # Lines without any letter or digit (a lone bracket, a comma)
            # are not counted as lines.
            if set(s.lower()) & alpha:
                n_line += 1
    return dict(lines=n_line, chars=n_ch, ext=ext)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
class ExtTestCase(unittest.TestCase):
    """
    Inherits from :class:`unittest.TestCase` and adds specific comparison
    functions and other helpers.
    """

    # Warnings re-emitted by tearDownClass, stored as (name, line, warning).
    _warns: List[Tuple[str, int, Warning]] = []
    # Default empty todo list; :meth:`todo` gives every class its own list.
    _todos: List[Tuple[Callable, str]] = []

    @property
    def verbose(self):
        "Returns the value of environment variable ``VERBOSE`` (0 if not set)."
        return int(os.environ.get("VERBOSE", "0"))

    @classmethod
    def setUpClass(cls):
        """Silences the constant folding logger of onnxscript below ERROR level,
        then runs the parent class setup."""
        logger = logging.getLogger("onnxscript.optimizer.constant_folding")
        logger.setLevel(logging.ERROR)
        # super() rather than unittest.TestCase so that mixins in the MRO run too.
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        """Re-emits the collected warnings, writes the todos registered for this
        class on stderr, then runs the parent class teardown."""
        for name, line, w in cls._warns:
            warnings.warn(f"\n{name}:{line}: {type(w)}\n {w!s}", stacklevel=2)
        for f, msg in cls._todos:
            sys.stderr.write(f"TODO {cls.__name__}::{f.__name__}: {msg}\n")
        super().tearDownClass()

    @classmethod
    def todo(cls, f: Callable, msg: str):
        "Adds a todo printed when all tests of the class have run."
        # ``_todos`` is declared on ExtTestCase: appending to it directly would
        # share a single list between every subclass and print foreign todos.
        if "_todos" not in cls.__dict__:
            cls._todos = []
        cls._todos.append((f, msg))

    @classmethod
    def ort(cls):
        "Imports and returns module :epkg:`onnxruntime`."
        import onnxruntime

        return onnxruntime

    @classmethod
    def to_onnx(cls, *args, **kwargs):
        "Calls ``experimental_experiment.torch_interpreter.to_onnx`` with the same arguments."
        from experimental_experiment.torch_interpreter import to_onnx

        return to_onnx(*args, **kwargs)

    def print_model(self, model: "ModelProto"):  # noqa: F821
        "Prints a ModelProto"
        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

        print(pretty_onnx(model))

    def print_onnx(self, model: "ModelProto"):  # noqa: F821
        "Prints a ModelProto"
        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

        print(pretty_onnx(model))

    def get_dump_file(self, name: str, folder: Optional[str] = None) -> str:
        """
        Returns a filename to dump a model.

        :param name: file name
        :param folder: destination folder, ``dump_test`` if None, created with its
            parents if missing, an empty string means the current folder
        :return: ``os.path.join(folder, name)``
        """
        if folder is None:
            folder = "dump_test"
        if folder:
            # makedirs: os.mkdir failed when the parent folder did not exist
            os.makedirs(folder, exist_ok=True)
        return os.path.join(folder, name)

    def get_dump_folder(self, folder: str) -> str:
        """Returns folder ``dump_test/<folder>``, created if it does not exist."""
        folder = os.path.join("dump_test", folder)
        os.makedirs(folder, exist_ok=True)
        return folder

    def dump_onnx(
        self,
        name: str,
        proto: Any,
        folder: Optional[str] = None,
    ) -> str:
        """Dumps an onnx file (anything with ``SerializeToString``), returns its path."""
        fullname = self.get_dump_file(name, folder=folder)
        with open(fullname, "wb") as f:
            f.write(proto.SerializeToString())
        return fullname

    def assertExists(self, name):
        """Checks the existing of a file."""
        if not os.path.exists(name):
            raise AssertionError(f"File or folder {name!r} does not exists.")

    def assertGreaterOrEqual(self, a, b, msg=None):
        """Checks that ``a >= b``."""
        if a < b:
            # The error used to be returned instead of raised: the check never failed.
            raise AssertionError(f"{a} < {b}, a not greater or equal than b\n{msg or ''}")

    def assertInOr(self, tofind: Tuple[str, ...], text: str, msg: str = ""):
        """Checks that at least one string of *tofind* is in *text*."""
        for tof in tofind:
            if tof in text:
                return
        raise AssertionError(
            msg or f"Unable to find one string in the list {tofind!r} in\n--\n{text}"
        )

    def assertIn(self, tofind: str, text: str, msg: str = ""):
        """Checks that ``tofind in text``."""
        if tofind in text:
            return
        raise AssertionError(
            msg or f"Unable to find the list of strings {tofind!r} in\n--\n{text}"
        )

    def assertHasAttr(self, obj: Any, name: str, msg: Optional[str] = None):
        """Checks that *obj* has an attribute *name*, *msg* replaces the error message."""
        # Explicit raise instead of ``assert`` so that the check survives ``python -O``.
        if not hasattr(obj, name):
            raise AssertionError(
                msg or f"Unable to find attribute {name!r} in object type {type(obj)}"
            )

    def assertSetContained(self, set1, set2):
        "Checks that ``set1`` is contained in ``set2``."
        set1 = set(set1)
        set2 = set(set2)
        if not set1 <= set2:
            raise AssertionError(f"Set {set2} does not contain set {set1}.")

    def assertEqualArrays(
        self,
        expected: Sequence[numpy.ndarray],
        value: Sequence[numpy.ndarray],
        atol: float = 0,
        rtol: float = 0,
        msg: Optional[str] = None,
    ):
        """Compares two sequences of arrays element-wise with :meth:`assertEqualArray`,
        *msg* is added to every failure."""
        self.assertEqual(len(expected), len(value), msg=msg)
        for a, b in zip(expected, value):
            self.assertEqualArray(a, b, atol=atol, rtol=rtol, msg=msg)

    def _assert_same_dtype_and_shape(self, expected: Any, value: Any, msg: Optional[str]):
        """Checks dtype then shape, a non-empty *msg* replaces the error message
        (the original error is kept as the cause)."""
        for attr in ("dtype", "shape"):
            try:
                self.assertEqual(getattr(expected, attr), getattr(value, attr))
            except AssertionError as e:
                if msg:
                    raise AssertionError(msg) from e
                raise

    @staticmethod
    def _abs_max(array: Any) -> Any:
        """Returns ``abs(array).max()`` to decorate an error message, or a short
        note if it cannot be computed (empty array, unsupported dtype), so that
        the original assertion error is not masked."""
        try:
            return abs(array).max()
        except Exception as e:
            return f"<unavailable: {type(e).__name__}>"

    @staticmethod
    def _ratio_and_diff(expected: Any, value: Any) -> str:
        """Formats the element-wise ratio and difference of two numpy arrays for an
        error message, booleans are cast to int, floating point warnings silenced."""
        try:
            te = expected.astype(int) if expected.dtype == numpy.bool_ else expected
            tv = value.astype(int) if value.dtype == numpy.bool_ else value
            with numpy.errstate(all="ignore"):
                return f"ratio={te / tv}\ndiff={te - tv}"
        except Exception as e:
            return f"ratio/diff unavailable: {type(e).__name__}: {e}"

    def assertEqualArray(
        self,
        expected: Any,
        value: Any,
        atol: float = 0,
        rtol: float = 0,
        msg: Optional[str] = None,
    ):
        """
        Compares two arrays or tensors: dtype, shape, then values within
        the tolerances. Two tensors are compared with
        :func:`torch.testing.assert_close`, anything else is converted to numpy
        and compared with :func:`numpy.testing.assert_allclose`.

        :param expected: expected array or tensor
        :param value: obtained array or tensor
        :param atol: absolute tolerance
        :param rtol: relative tolerance
        :param msg: message added to the error
        :raises AssertionError: if the arrays are different
        """
        if hasattr(expected, "detach") and hasattr(value, "detach"):
            self._assert_same_dtype_and_shape(expected, value, msg)

            import torch

            try:
                torch.testing.assert_close(value, expected, atol=atol, rtol=rtol)
            except AssertionError as e:
                rows = [
                    f"{msg}\n{e}" if msg else str(e),
                    f"expected max value={self._abs_max(expected)}",
                    f"expected computed value={self._abs_max(value)}",
                ]
                raise AssertionError("\n".join(rows))  # noqa: B904
            return

        if hasattr(expected, "detach") or hasattr(value, "detach"):
            # The torch helpers are only needed when a tensor must be converted.
            from .helpers.torch_helper import to_numpy

            if hasattr(expected, "detach"):
                expected = to_numpy(expected.detach().cpu())
            if hasattr(value, "detach"):
                value = to_numpy(value.detach().cpu())
        self._assert_same_dtype_and_shape(expected, value, msg)

        try:
            assert_allclose(desired=expected, actual=value, atol=atol, rtol=rtol)
        except AssertionError as e:
            rows = [
                f"{msg}\n{e}" if msg else str(e),
                f"expected max value={self._abs_max(expected)}",
                f"expected computed value={self._abs_max(value)}\n",
                self._ratio_and_diff(expected, value),
            ]
            raise AssertionError("\n".join(rows))  # noqa: B904

    def assertEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are equal.
        Calls :func:`pandas.testing.assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal

        assert_frame_equal(d1, d2, **kwargs)

    def assertEqualTrue(self, value: Any, msg: str = ""):
        """Checks that *value* is the object True (not only truthy)."""
        if value is True:
            return
        raise AssertionError(msg or f"value is not True: {value!r}")

    def assertEqual(self, expected: Any, value: Any, msg: str = ""):
        """Overwrites the error message to get a more explicit message about what is what."""
        if msg:
            super().assertEqual(expected, value, msg)
        else:
            try:
                super().assertEqual(expected, value)
            except AssertionError as e:
                raise AssertionError(  # noqa: B904
                    f"expected is {expected!r}, value is {value!r}\n{e}"
                )

    def assertEqualAny(
        self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = ""
    ):
        """
        Compares two objects recursively. Supported types are ``BaseModelOutput``,
        tuples, lists, dictionaries, caches ``DynamicCache``, ``SlidingWindowCache``,
        ``HybridCache``, ``StaticCache``, ``EncoderDecoderCache`` (compared
        through some of their attributes), int, float, str, objects with a
        ``shape`` (compared with :meth:`assertEqualArray`), ``Dim`` and None.

        :param expected: expected object
        :param value: obtained object
        :param atol: absolute tolerance for arrays
        :param rtol: relative tolerance for arrays
        :param msg: message added to every failure, including nested ones
        :raises AssertionError: if the objects differ or the type is not supported
        """
        cls_name = expected.__class__.__name__
        if cls_name == "BaseModelOutput":
            self.assertEqual(type(expected), type(value), msg=msg)
            self.assertEqual(len(expected), len(value), msg=msg)
            self.assertEqual(list(expected), list(value), msg=msg)  # checks the order
            self.assertEqualAny(
                {k: v for k, v in expected.items()},  # noqa: C416
                {k: v for k, v in value.items()},  # noqa: C416
                atol=atol,
                rtol=rtol,
                msg=msg,
            )
        elif isinstance(expected, (tuple, list, dict)):
            self.assertIsInstance(value, type(expected), msg=msg)
            self.assertEqual(len(expected), len(value), msg=msg)
            if isinstance(expected, dict):
                for k in expected:
                    self.assertIn(k, value, msg=msg)
                    self.assertEqualAny(expected[k], value[k], msg=msg, atol=atol, rtol=rtol)
            else:
                for e, g in zip(expected, value):
                    self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
        elif cls_name in ("DynamicCache", "SlidingWindowCache", "HybridCache"):
            self.assertEqual(type(expected), type(value), msg=msg)
            atts = ["key_cache", "value_cache"]
            self.assertEqualAny(
                {k: expected.__dict__.get(k, None) for k in atts},
                {k: value.__dict__.get(k, None) for k in atts},
                atol=atol,
                rtol=rtol,
                msg=msg,
            )
        elif cls_name == "StaticCache":
            self.assertEqual(type(expected), type(value), msg=msg)
            self.assertEqual(expected.max_cache_len, value.max_cache_len, msg=msg)
            atts = ["key_cache", "value_cache"]
            self.assertEqualAny(
                {k: expected.__dict__.get(k, None) for k in atts},
                {k: value.__dict__.get(k, None) for k in atts},
                atol=atol,
                rtol=rtol,
                msg=msg,
            )
        elif cls_name == "EncoderDecoderCache":
            self.assertEqual(type(expected), type(value), msg=msg)
            atts = ["self_attention_cache", "cross_attention_cache"]
            self.assertEqualAny(
                {k: expected.__dict__.get(k, None) for k in atts},
                {k: value.__dict__.get(k, None) for k in atts},
                atol=atol,
                rtol=rtol,
                msg=msg,
            )
        elif isinstance(expected, (int, float, str)):
            self.assertEqual(expected, value, msg=msg)
        elif hasattr(expected, "shape"):
            self.assertEqual(type(expected), type(value), msg=msg)
            self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
        elif cls_name in ("Dim", "_Dim", "_DimHintType"):
            self.assertEqual(type(expected), type(value), msg=msg)
            self.assertEqual(expected.__name__, value.__name__, msg=msg)
        elif expected is None:
            self.assertEqual(expected, value, msg=msg)
        else:
            raise AssertionError(
                f"Comparison not implemented for types {type(expected)} and {type(value)}"
            )

    def assertEqualArrayAny(
        self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = ""
    ):
        """
        Compares nested containers of arrays. Unlike :meth:`assertEqualAny`,
        all positions of a tuple or a list are compared and every discrepancy
        is reported in a single error.

        :param expected: expected object
        :param value: obtained object
        :param atol: absolute tolerance for arrays
        :param rtol: relative tolerance for arrays
        :param msg: message added to every failure
        :raises AssertionError: if the objects differ or the type is not supported
        """
        if isinstance(expected, (tuple, list, dict)):
            self.assertIsInstance(value, type(expected), msg=msg)
            self.assertEqual(len(expected), len(value), msg=msg)
            if isinstance(expected, dict):
                for k in expected:
                    self.assertIn(k, value, msg=msg)
                    self.assertEqualArrayAny(
                        expected[k], value[k], msg=msg, atol=atol, rtol=rtol
                    )
            else:
                excs = []
                # distinct names: the exception must not shadow the expected item
                for i, (exp_item, got_item) in enumerate(zip(expected, value)):
                    try:
                        self.assertEqualArrayAny(
                            exp_item, got_item, msg=msg, atol=atol, rtol=rtol
                        )
                    except AssertionError as e:
                        excs.append(f"Error at position {i} due to {e}")
                if excs:
                    msg_ = "\n".join(excs)
                    msg = f"{msg}\n{msg_}" if msg else msg_
                    raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
        elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
            atts = {"key_cache", "value_cache"}
            self.assertEqualArrayAny(
                {k: expected.__dict__.get(k, None) for k in atts},
                {k: value.__dict__.get(k, None) for k in atts},
                atol=atol,
                rtol=rtol,
                msg=msg,
            )
        elif isinstance(expected, (int, float, str)):
            self.assertEqual(expected, value, msg=msg)
        elif hasattr(expected, "shape"):
            self.assertEqual(type(expected), type(value), msg=msg)
            self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
        elif expected is None:
            # Explicit raise instead of ``assert`` so that the check survives ``python -O``.
            if value is not None:
                raise AssertionError(f"Expected is None but value is of type {type(value)}")
        else:
            raise AssertionError(
                f"Comparison not implemented for types {type(expected)} and {type(value)}"
            )

    def assertAlmostEqual(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """Converts both inputs into numpy arrays (*value* is cast into the dtype
        of *expected*) and compares them with :meth:`assertEqualArray`."""
        if not isinstance(expected, numpy.ndarray):
            expected = numpy.array(expected)
        if not isinstance(value, numpy.ndarray):
            value = numpy.array(value).astype(expected.dtype)
        self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

    def check_ort(
        self, onx: "onnx.ModelProto"  # noqa: F821
    ) -> "onnxruntime.InferenceSession":  # noqa: F821
        """Creates and returns an :class:`onnxruntime.InferenceSession` running on CPU,
        the creation fails if onnxruntime cannot load the model."""
        from onnxruntime import InferenceSession

        return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])

    def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
        """
        Checks that *fct()* raises an exception of type *exc_type*
        whose message contains *msg* if not None. Exceptions of any other
        type are not caught.
        """
        try:
            fct()
        except exc_type as e:
            if msg is not None and msg not in str(e):
                raise AssertionError(f"Unexpected exception message {e!r}.")  # noqa: B904
            return
        raise AssertionError("No exception was raised.")

    def assertEmpty(self, value: Any):
        """Checks that *value* is None or falsy."""
        if value is None:
            return
        if not value:
            return
        raise AssertionError(f"value is not empty: {value!r}.")

    def assertNotEmpty(self, value: Any):
        """Checks that *value* is not None and, for a container or a string,
        that it is not empty; any other object passes."""
        if value is None:
            raise AssertionError(f"value is empty: {value!r}.")
        # strings, bytes and frozensets were silently accepted when empty
        if isinstance(value, (list, dict, tuple, set, frozenset, str, bytes)):
            if not value:
                raise AssertionError(f"value is empty: {value!r}.")

    def assertStartsWith(self, prefix: str, full: str):
        """Checks that *full* starts with *prefix*."""
        if not full.startswith(prefix):
            raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")

    def assertEndsWith(self, suffix: str, full: str):
        """Checks that *full* ends with *suffix*."""
        if not full.endswith(suffix):
            raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")

    def capture(self, fct: Callable):
        """
        Runs a function and capture standard output and error.

        :param fct: function to run
        :return: result of *fct*, output, error
        :raises AssertionError: if *fct* fails, the message includes what was captured
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout), redirect_stderr(serr):
            try:
                res = fct()
            except Exception as e:
                raise AssertionError(
                    f"function {fct} failed, stdout="
                    f"\n{sout.getvalue()}\n---\nstderr=\n{serr.getvalue()}"
                ) from e
        return res, sout.getvalue(), serr.getvalue()

    def tryCall(
        self, fct: Callable, msg: Optional[str] = None, none_if: Optional[str] = None
    ) -> Optional[Any]:
        """
        Calls the function, catch any error.

        :param fct: function to call
        :param msg: error message to display if failing
        :param none_if: returns None if this substring is found in the error message
        :return: output of *fct*
        """
        try:
            return fct()
        except Exception as e:
            if none_if is not None and none_if in str(e):
                return None
            if msg is None:
                raise
            raise AssertionError(msg) from e

    def assert_onnx_disc(
        self,
        test_name: str,
        proto: "onnx.ModelProto",  # noqa: F821
        model: "torch.nn.Module",  # noqa: F821
        inputs: Union[Tuple[Any], Dict[str, Any]],
        verbose: int = 0,
        atol: float = 1e-5,
        rtol: float = 1e-3,
        copy_inputs: bool = True,
        expected: Optional[Any] = None,
        use_ort: bool = False,
        **kwargs,
    ):
        """
        Checks for discrepancies.
        Runs the onnx models, computes expected outputs, in that order.
        The inputs may be modified by this functions if the torch model
        modifies them inplace.

        :param test_name: test name, dumps the model if not empty
        :param proto: onnx model
        :param model: torch model
        :param inputs: inputs, a tuple or a list of positional arguments,
            or a dictionary of named arguments
        :param verbose: verbosity
        :param atol: absolute tolerance
        :param rtol: relative tolerance
        :param expected: expected values
        :param copy_inputs: to copy the inputs
        :param use_ort: use :class:`onnxruntime.InferenceSession`
        :param kwargs: arguments sent to
            :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
        :raises AssertionError: if the discrepancies exceed the tolerances
        """
        from .helpers import string_type, string_diff, max_diff
        from .helpers.rt_helper import make_feeds
        from .helpers.ort_session import InferenceSessionForTorch

        kws = dict(with_shape=True, with_min_max=verbose > 1)
        # Defined unconditionally, it prefixes every verbose message.
        vname = test_name or "assert_onnx_disc"
        if test_name:
            name = f"{test_name}.onnx"
            if verbose:
                print(f"[{vname}] save the onnx model into {name!r}")
            name = self.dump_onnx(name, proto)
            if verbose:
                # true division, ``//`` truncated the size before formatting it
                print(f"[{vname}] file size {os.stat(name).st_size / 2**10:1.3f} kb")
        if verbose:
            print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
        if use_ort:
            feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
            if verbose:
                print(f"[{vname}] feeds {string_type(feeds, **kws)}")
            import onnxruntime

            sess = onnxruntime.InferenceSession(
                proto.SerializeToString(), providers=["CPUExecutionProvider"]
            )
            got = sess.run(None, feeds)
        else:
            feeds = make_feeds(proto, inputs, copy=True)
            if verbose:
                print(f"[{vname}] feeds {string_type(feeds, **kws)}")
            sess = InferenceSessionForTorch(proto, **kwargs)
            got = sess.run(None, feeds)
        if verbose:
            print(f"[{vname}] compute expected values")
        if expected is None:
            args = copy.deepcopy(inputs) if copy_inputs else inputs
            expected = model(*args) if isinstance(args, (tuple, list)) else model(**args)
        if verbose:
            print(f"[{vname}] expected {string_type(expected, **kws)}")
            print(f"[{vname}] obtained {string_type(got, **kws)}")
        diff = max_diff(expected, got, flatten=True)
        if verbose:
            print(f"[{vname}] diff {string_diff(diff)}")
        ok = (
            isinstance(diff["abs"], float)
            and isinstance(diff["rel"], float)
            and not numpy.isnan(diff["abs"])
            and diff["abs"] <= atol
            and not numpy.isnan(diff["rel"])
            and diff["rel"] <= rtol
        )
        # Explicit raise instead of ``assert`` so that the check survives ``python -O``.
        if not ok:
            raise AssertionError(f"discrepancies in {test_name!r}, diff={string_diff(diff)}")

    def _debug(self):
        "Tells if DEBUG=1 is set up."
        return os.environ.get("DEBUG") in BOOLEAN_VALUES

    def string_type(self, *args, **kwargs):
        "Calls :func:`onnx_diagnostic.helpers.string_type` with the same arguments."
        from .helpers import string_type

        return string_type(*args, **kwargs)

    def subloop(self, *args, verbose: int = 0):
        """Loops over elements and calls :meth:`unittest.TestCase.subTest`.
        With one argument, iterates over it, with several arguments, iterates
        over their cartesian product."""
        if len(args) == 1:
            for it in args[0]:
                with self.subTest(case=it):
                    if verbose:
                        print(f"[subloop] it={it!r}")
                    yield it
        else:
            for it in itertools.product(*args):
                with self.subTest(case=it):
                    if verbose:
                        print(f"[subloop] it={it!r}")
                    yield it
|