onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/ext_test_case.py
@@ -0,0 +1,1290 @@
1
+ """
2
+ The module contains the main class ``ExtTestCase`` which adds
3
+ specific functionalities to this project.
4
+ """
5
+
6
+ import copy
7
+ import glob
8
+ import itertools
9
+ import logging
10
+ import os
11
+ import re
12
+ import sys
13
+ import unittest
14
+ import warnings
15
+ from contextlib import redirect_stderr, redirect_stdout
16
+ from io import StringIO
17
+ from timeit import Timer
18
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
19
+ import numpy
20
+ from numpy.testing import assert_allclose
21
+
22
+ BOOLEAN_VALUES = (1, "1", True, "True", "true", "TRUE")
23
+
24
+
25
+ def is_azure() -> bool:
26
+ """Tells if the job is running on Azure DevOps."""
27
+ return os.environ.get("AZURE_HTTP_USER_AGENT", "undefined") != "undefined"
28
+
29
+
30
+ def is_windows() -> bool:
31
+ return sys.platform == "win32"
32
+
33
+
34
+ def is_apple() -> bool:
35
+ return sys.platform == "darwin"
36
+
37
+
38
+ def is_linux() -> bool:
39
+ return sys.platform == "linux"
40
+
41
+
42
+ def skipif_ci_windows(msg) -> Callable:
43
+ """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
44
+ if is_windows() and is_azure():
45
+ msg = f"Test does not work on azure pipeline (Windows). {msg}"
46
+ return unittest.skip(msg)
47
+ return lambda x: x
48
+
49
+
50
+ def skipif_ci_linux(msg) -> Callable:
51
+ """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Linux`."""
52
+ if is_linux() and is_azure():
53
+ msg = f"Takes too long (Linux). {msg}"
54
+ return unittest.skip(msg)
55
+ return lambda x: x
56
+
57
+
58
+ def skipif_ci_apple(msg) -> Callable:
59
+ """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
60
+ if is_apple() and is_azure():
61
+ msg = f"Test does not work on azure pipeline (Apple). {msg}"
62
+ return unittest.skip(msg)
63
+ return lambda x: x
64
+
65
+
66
+ def unit_test_going():
67
+ """
68
+ Tells whether the script is running from a unit test.
69
+ It lets scripts and examples reduce their workload so unit tests do not take too long.
70
+ """
71
+ going = int(os.environ.get("UNITTEST_GOING", 0))
72
+ return going == 1
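+
+ # Usage sketch (illustrative only): scripts and examples in this project can
+ # call ``unit_test_going()`` to shrink their workload when they are executed
+ # by the test suite (``UNITTEST_GOING=1``). The iteration counts are arbitrary.
+ def _example_unit_test_going() -> int:  # pragma: no cover
+     n_iterations = 2 if unit_test_going() else 1000
+     return n_iterations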
73
+
74
+
75
+ def ignore_warnings(warns: List[Warning]) -> Callable:
76
+ """
77
+ Catches and ignores the given warnings while the decorated test runs.
78
+
79
+ :param warns: warnings to ignore
80
+ """
81
+ if not isinstance(warns, (tuple, list)):
82
+ warns = (warns,)
83
+ new_list = []
84
+ for w in warns:
85
+ if w == "TracerWarning":
86
+ from torch.jit import TracerWarning
87
+
88
+ new_list.append(TracerWarning)
89
+ else:
90
+ new_list.append(w)
91
+ warns = tuple(new_list)
92
+
93
+ def wrapper(fct):
94
+ if warns is None:
95
+ raise AssertionError(f"warns cannot be None for '{fct}'.")
96
+
97
+ def call_f(self):
98
+ with warnings.catch_warnings():
99
+ warnings.simplefilter("ignore", warns)
100
+ return fct(self)
101
+
102
+ try: # noqa: SIM105
103
+ call_f.__name__ = fct.__name__
104
+ except AttributeError:
105
+ pass
106
+ return call_f
107
+
108
+ return wrapper
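+
+ # Usage sketch (illustrative only, the test class is hypothetical):
+ # ``ignore_warnings`` wraps a test method so the listed warning categories
+ # are silenced while it runs.
+ def _example_ignore_warnings() -> type:  # pragma: no cover
+     class _DemoCase(unittest.TestCase):
+         @ignore_warnings((DeprecationWarning, UserWarning))
+         def test_noisy(self):
+             warnings.warn("about to be silenced", DeprecationWarning)
+             self.assertTrue(True)
+
+     return _DemoCase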
109
+
110
+
111
+ def ignore_errors(errors: Union[Exception, Tuple[Exception]]) -> Callable:
112
+ """
113
+ Catches exceptions and skips the test when one of the expected errors is raised.
114
+
115
+ :param errors: errors to ignore
116
+ """
117
+
118
+ def wrapper(fct):
119
+ if errors is None:
120
+ raise AssertionError(f"errors cannot be None for '{fct}'.")
121
+
122
+ def call_f(self):
123
+ try:
124
+ return fct(self)
125
+ except errors as e:
126
+ raise unittest.SkipTest( # noqa: B904
127
+ f"expecting error {e.__class__.__name__}: {e}"
128
+ )
129
+
130
+ try: # noqa: SIM105
131
+ call_f.__name__ = fct.__name__
132
+ except AttributeError:
133
+ pass
134
+ return call_f
135
+
136
+ return wrapper
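+
+ # Usage sketch (illustrative only): skip a test instead of failing it when a
+ # flaky dependency raises one of the expected errors.
+ def _example_ignore_errors() -> type:  # pragma: no cover
+     class _DemoCase(unittest.TestCase):
+         @ignore_errors((ConnectionError, TimeoutError))
+         def test_download(self):
+             raise ConnectionError("temporary network failure")
+
+     return _DemoCase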
137
+
138
+
139
+ def hide_stdout(f: Optional[Callable] = None) -> Callable:
140
+ """
141
+ Catches warnings, hides standard output.
142
+ The function may be disabled by setting ``UNHIDE=1``
143
+ before running the unit test.
144
+
145
+ :param f: if specified, this function is called with the captured stdout as an argument
146
+ """
147
+
148
+ def wrapper(fct):
149
+ def call_f(self):
150
+ if os.environ.get("UNHIDE", ""):
151
+ fct(self)
152
+ return
153
+ st = StringIO()
154
+ with redirect_stdout(st), warnings.catch_warnings():
155
+ warnings.simplefilter("ignore", (UserWarning, DeprecationWarning))
156
+ try:
157
+ fct(self)
158
+ except AssertionError as e:
159
+ if "torch is not recent enough, file" in str(e):
160
+ raise unittest.SkipTest(str(e)) # noqa: B904
161
+ raise
162
+ if f is not None:
163
+ f(st.getvalue())
164
+ return None
165
+
166
+ try: # noqa: SIM105
167
+ call_f.__name__ = fct.__name__
168
+ except AttributeError:
169
+ pass
170
+ return call_f
171
+
172
+ return wrapper
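+
+ # Usage sketch (illustrative only): ``hide_stdout()`` (note the call) captures
+ # whatever the test prints unless ``UNHIDE=1`` is set; an optional callback
+ # receives the captured text.
+ def _example_hide_stdout() -> type:  # pragma: no cover
+     class _DemoCase(unittest.TestCase):
+         @hide_stdout()
+         def test_verbose(self):
+             print("this output is captured")
+             self.assertTrue(True)
+
+     return _DemoCase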
173
+
174
+
175
+ def long_test(msg: str = "") -> Callable:
176
+ """Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
177
+ if os.environ.get("LONGTEST", "0") in ("0", 0, False, "False", "false"):
178
+ msg = f"Skipped (set LONGTEST=1 to run it. {msg}"
179
+ return unittest.skip(msg)
180
+ return lambda x: x
181
+
182
+
183
+ def never_test(msg: str = "") -> Callable:
184
+ """Skips a unit test."""
185
+ if os.environ.get("NEVERTEST", "0") in ("0", 0, False, "False", "false"):
186
+ msg = f"Skipped (set NEVERTEST=1 to run it. {msg}"
187
+ return unittest.skip(msg)
188
+ return lambda x: x
189
+
190
+
191
+ def measure_time(
192
+ stmt: Union[str, Callable],
193
+ context: Optional[Dict[str, Any]] = None,
194
+ repeat: int = 10,
195
+ number: int = 50,
196
+ warmup: int = 1,
197
+ div_by_number: bool = True,
198
+ max_time: Optional[float] = None,
199
+ ) -> Dict[str, Union[str, int, float]]:
200
+ """
201
+ Measures a statement and returns the results as a dictionary.
202
+
203
+ :param stmt: string or callable
204
+ :param context: dictionary of variables used to evaluate *stmt* when it is a string
205
+ :param repeat: average over *repeat* experiments
206
+ :param number: number of executions of the statement for one measurement
207
+ :param warmup: number of iterations to run before starting the
208
+ real measurement
209
+ :param div_by_number: divide by the number of executions
210
+ :param max_time: execute the statement until the total goes
211
+ beyond this time (approximately), *repeat* is ignored,
212
+ *div_by_number* must be set to True
213
+ :return: dictionary
214
+
215
+ .. runpython::
216
+ :showcode:
217
+
218
+ from pprint import pprint
219
+ from math import cos
220
+ from onnx_diagnostic.ext_test_case import measure_time
221
+
222
+ res = measure_time(lambda: cos(0.5))
223
+ pprint(res)
224
+
225
+ See `Timer.repeat <https://docs.python.org/3/library/
226
+ timeit.html#timeit.Timer.repeat>`_
227
+ for a better understanding of parameter *repeat* and *number*.
228
+ The function returns a duration corresponding to
229
+ *number* times the execution of the main statement.
230
+ """
231
+ if not callable(stmt) and not isinstance(stmt, str):
232
+ raise TypeError(f"stmt is not callable or a string but is of type {type(stmt)!r}.")
233
+ if context is None:
234
+ context = {}
235
+
236
+ if isinstance(stmt, str):
237
+ tim = Timer(stmt, globals=context)
238
+ else:
239
+ tim = Timer(stmt)
240
+
241
+ if warmup > 0:
242
+ warmup_time = tim.timeit(warmup)
243
+ else:
244
+ warmup_time = 0
245
+
246
+ if max_time is not None:
247
+ if not div_by_number:
248
+ raise ValueError("div_by_number must be set to True if max_time is defined.")
249
+ i = 1
250
+ total_time = 0.0
251
+ results = []
252
+ while True:
253
+ for j in (1, 2):
254
+ number = i * j
255
+ time_taken = tim.timeit(number)
256
+ results.append((number, time_taken))
257
+ total_time += time_taken
258
+ if total_time >= max_time:
259
+ break
260
+ if total_time >= max_time:
261
+ break
262
+ ratio = (max_time - total_time) / total_time
263
+ ratio = max(ratio, 1)
264
+ i = int(i * ratio)
265
+
266
+ res = numpy.array(results)
267
+ tw = res[:, 0].sum()
268
+ ttime = res[:, 1].sum()
269
+ mean = ttime / tw
270
+ ave = res[:, 1] / res[:, 0]
271
+ dev = (((ave - mean) ** 2 * res[:, 0]).sum() / tw) ** 0.5
272
+ mes = dict(
273
+ average=mean,
274
+ deviation=dev,
275
+ min_exec=numpy.min(ave),
276
+ max_exec=numpy.max(ave),
277
+ repeat=1,
278
+ number=tw,
279
+ ttime=ttime,
280
+ )
281
+ else:
282
+ res = numpy.array(tim.repeat(repeat=repeat, number=number))
283
+ if div_by_number:
284
+ res /= number
285
+
286
+ mean = numpy.mean(res)
287
+ dev = numpy.mean(res**2)
288
+ dev = (dev - mean**2) ** 0.5
289
+ mes = dict(
290
+ average=mean,
291
+ deviation=dev,
292
+ min_exec=numpy.min(res),
293
+ max_exec=numpy.max(res),
294
+ repeat=repeat,
295
+ number=number,
296
+ ttime=res.sum(),
297
+ )
298
+
299
+ if "values" in context:
300
+ if hasattr(context["values"], "shape"):
301
+ mes["size"] = context["values"].shape[0]
302
+ else:
303
+ mes["size"] = len(context["values"])
304
+ else:
305
+ mes["context_size"] = sys.getsizeof(context)
306
+ mes["warmup_time"] = warmup_time
307
+ return mes
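+
+ # Usage sketch (illustrative only): timing a short statement with a fixed
+ # number of runs, then with a time budget (``max_time``), which ignores
+ # ``repeat`` and requires ``div_by_number=True``.
+ def _example_measure_time() -> dict:  # pragma: no cover
+     from math import cos
+
+     fixed = measure_time(lambda: cos(0.5), repeat=5, number=100)
+     budget = measure_time("cos(0.5)", context={"cos": cos}, max_time=0.1)
+     # 'average' is the mean duration of a single execution, in seconds.
+     return {"fixed": fixed["average"], "budget": budget["average"]}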
308
+
309
+
310
+ def statistics_on_folder(
311
+ folder: Union[str, List[str]],
312
+ pattern: str = ".*[.]((py|rst))$",
313
+ aggregation: int = 0,
314
+ ) -> List[Dict[str, Union[int, float, str]]]:
315
+ """
316
+ Computes statistics on files in a folder.
317
+
318
+ :param folder: folder or folders to investigate
319
+ :param pattern: file pattern
320
+ :param aggregation: aggregates the results by the first *aggregation* subfolder levels
321
+ :return: list of dictionaries
322
+
323
+ .. runpython::
324
+ :showcode:
325
+ :toggle:
326
+
327
+ import os
328
+ import pprint
329
+ from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__
330
+
331
+ pprint.pprint(statistics_on_folder(os.path.dirname(__file__)))
332
+
333
+ Aggregated:
334
+
335
+ .. runpython::
336
+ :showcode:
337
+ :toggle:
338
+
339
+ import os
340
+ import pprint
341
+ from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__
342
+
343
+ pprint.pprint(statistics_on_folder(os.path.dirname(__file__), aggregation=1))
344
+ """
345
+ if isinstance(folder, list):
346
+ rows = []
347
+ for fold in folder:
348
+ last = fold.replace("\\", "/").split("/")[-1]
349
+ r = statistics_on_folder(
350
+ fold, pattern=pattern, aggregation=max(aggregation - 1, 0)
351
+ )
352
+ if aggregation == 0:
353
+ rows.extend(r)
354
+ continue
355
+ for line in r:
356
+ line["dir"] = os.path.join(last, line["dir"])
357
+ rows.extend(r)
358
+ return rows
359
+
360
+ rows = []
361
+ reg = re.compile(pattern)
362
+ for name in glob.glob("**/*", root_dir=folder, recursive=True):
363
+ if not reg.match(name):
364
+ continue
365
+ if os.path.isdir(os.path.join(folder, name)):
366
+ continue
367
+ n = name.replace("\\", "/")
368
+ spl = n.split("/")
369
+ level = len(spl)
370
+ stat = statistics_on_file(os.path.join(folder, name))
371
+ stat["name"] = name
372
+ if aggregation <= 0:
373
+ rows.append(stat)
374
+ continue
375
+ spl = os.path.dirname(name).replace("\\", "/").split("/")
376
+ level = "/".join(spl[:aggregation])
377
+ stat["dir"] = level
378
+ rows.append(stat)
379
+ return rows
380
+
381
+
382
+ def get_figure(ax):
383
+ """Returns the figure of a matplotlib figure."""
384
+ if hasattr(ax, "get_figure"):
385
+ return ax.get_figure()
386
+ if len(ax.shape) == 0:
387
+ return ax.get_figure()
388
+ if len(ax.shape) == 1:
389
+ return ax[0].get_figure()
390
+ if len(ax.shape) == 2:
391
+ return ax[0, 0].get_figure()
392
+ raise RuntimeError(f"Unexpected shape {ax.shape} for axis.")
393
+
394
+
395
+ def has_cuda() -> bool:
396
+ """Returns ``torch.cuda.device_count() > 0``."""
397
+ import torch
398
+
399
+ return torch.cuda.device_count() > 0
400
+
401
+
402
+ def requires_python(version: Tuple[int, ...], msg: str = ""):
403
+ """
404
+ Skips a test if python is too old.
405
+
406
+ :param msg: to overwrite the message
407
+ :param version: minimum version
408
+ """
409
+ if sys.version_info[: len(version)] < version:
410
+ return unittest.skip(msg or f"python not recent enough {sys.version_info} < {version}")
411
+ return lambda x: x
412
+
413
+
414
+ def requires_cuda(msg: str = "", version: str = "", memory: int = 0):
415
+ """
416
+ Skips a test if cuda is not available.
417
+
418
+ :param msg: to overwrite the message
419
+ :param version: minimum version
420
+ :param memory: minimum number of Gb to run the test
421
+ """
422
+ import torch
423
+
424
+ if torch.cuda.device_count() == 0:
425
+ msg = msg or "only runs on CUDA but torch does not have it"
426
+ return unittest.skip(msg or "cuda not installed")
427
+ if version:
428
+ import packaging.version as pv
429
+
430
+ if pv.Version(torch.version.cuda) < pv.Version(version):
431
+ msg = msg or f"CUDA older than {version}"
432
+ return unittest.skip(msg or f"cuda not recent enough {torch.version.cuda} < {version}")
433
+
434
+ if memory:
435
+ m = torch.cuda.get_device_properties(0).total_memory / 2**30
436
+ if m < memory:
437
+ msg = msg or f"available memory is not enough {m} < {memory} (Gb)"
438
+ return unittest.skip(msg)
439
+
440
+ return lambda x: x
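+
+ # Usage sketch (illustrative only, the test class is hypothetical): the test
+ # below is skipped unless a CUDA device with at least 8 Gb of memory is
+ # available.
+ def _example_requires_cuda() -> type:  # pragma: no cover
+     class _DemoCase(unittest.TestCase):
+         @requires_cuda(memory=8)
+         def test_on_gpu(self):
+             self.assertTrue(has_cuda())
+
+     return _DemoCase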
441
+
442
+
443
+ def requires_zoo(msg: str = "") -> Callable:
444
+ """Skips a unit test if environment variable ZOO is not equal to 1."""
445
+ var = os.environ.get("ZOO", "0") in BOOLEAN_VALUES
446
+
447
+ if not var:
448
+ msg = f"ZOO not set up or != 1. {msg}"
449
+ return unittest.skip(msg or "zoo not installed")
450
+ return lambda x: x
451
+
452
+
453
+ def requires_sklearn(version: str, msg: str = "") -> Callable:
454
+ """Skips a unit test if :epkg:`scikit-learn` is not recent enough."""
455
+ import packaging.version as pv
456
+ import sklearn
457
+
458
+ if pv.Version(sklearn.__version__) < pv.Version(version):
459
+ msg = f"scikit-learn version {sklearn.__version__} < {version}: {msg}"
460
+ return unittest.skip(msg)
461
+ return lambda x: x
462
+
463
+
464
+ def requires_experimental(version: str = "0.0.0", msg: str = "") -> Callable:
465
+ """Skips a unit test if :epkg:`experimental-experiment` is not recent enough."""
466
+ import packaging.version as pv
467
+
468
+ try:
469
+ import experimental_experiment
470
+ except ImportError:
471
+ msg = f"experimental-experiment not installed: {msg}"
472
+ return unittest.skip(msg)
473
+
474
+ if pv.Version(experimental_experiment.__version__) < pv.Version(version):
475
+ msg = (
476
+ f"experimental-experiment version "
477
+ f"{experimental_experiment.__version__} < {version}: {msg}"
478
+ )
479
+ return unittest.skip(msg)
480
+ return lambda x: x
481
+
482
+
483
+ def has_torch(version: str) -> bool:
484
+ "Returns True if torch transformers is higher."
485
+ import packaging.version as pv
486
+ import torch
487
+
488
+ return pv.Version(torch.__version__) >= pv.Version(version)
489
+
490
+
491
+ def has_transformers(version: str) -> bool:
492
+ "Returns True if transformers version is higher."
493
+ import packaging.version as pv
494
+ import transformers
495
+
496
+ return pv.Version(transformers.__version__) >= pv.Version(version)
497
+
498
+
499
+ def requires_torch(version: str, msg: str = "") -> Callable:
500
+ """Skips a unit test if :epkg:`pytorch` is not recent enough."""
501
+ import packaging.version as pv
502
+ import torch
503
+
504
+ if pv.Version(torch.__version__) < pv.Version(version):
505
+ msg = f"torch version {torch.__version__} < {version}: {msg}"
506
+ return unittest.skip(msg)
507
+ return lambda x: x
508
+
509
+
510
+ def requires_numpy(version: str, msg: str = "") -> Callable:
511
+ """Skips a unit test if :epkg:`numpy` is not recent enough."""
512
+ import packaging.version as pv
513
+ import numpy
514
+
515
+ if pv.Version(numpy.__version__) < pv.Version(version):
516
+ msg = f"numpy version {numpy.__version__} < {version}: {msg}"
517
+ return unittest.skip(msg)
518
+ return lambda x: x
519
+
520
+
521
+ def requires_transformers(
522
+ version: str, msg: str = "", or_older_than: Optional[str] = None
523
+ ) -> Callable:
524
+ """Skips a unit test if :epkg:`transformers` is not recent enough."""
525
+ import packaging.version as pv
526
+
527
+ try:
528
+ import transformers
529
+ except ImportError:
530
+ msg = f"diffusers not installed {msg}"
531
+ return unittest.skip(msg)
532
+
533
+ v = pv.Version(transformers.__version__)
534
+ if v < pv.Version(version):
535
+ msg = f"transformers version {transformers.__version__} < {version}: {msg}"
536
+ return unittest.skip(msg)
537
+ if or_older_than and v > pv.Version(or_older_than):
538
+ msg = (
539
+ f"transformers version {or_older_than} < "
540
+ f"{transformers.__version__} < {version}: {msg}"
541
+ )
542
+ return unittest.skip(msg)
543
+ return lambda x: x
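+
+ # Usage sketch (illustrative only, the version number is arbitrary): a test
+ # gated on the installed transformers version.
+ def _example_requires_transformers() -> type:  # pragma: no cover
+     class _DemoCase(unittest.TestCase):
+         @requires_transformers("4.50")
+         def test_recent_transformers(self):
+             import transformers
+
+             self.assertTrue(hasattr(transformers, "__version__"))
+
+     return _DemoCase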
544
+
545
+
546
+ def requires_diffusers(
547
+ version: str, msg: str = "", or_older_than: Optional[str] = None
548
+ ) -> Callable:
549
+ """Skips a unit test if :epkg:`transformers` is not recent enough."""
550
+ import packaging.version as pv
551
+
552
+ try:
553
+ import diffusers
554
+ except ImportError:
555
+ msg = f"diffusers not installed {msg}"
556
+ return unittest.skip(msg)
557
+
558
+ v = pv.Version(diffusers.__version__)
559
+ if v < pv.Version(version):
560
+ msg = f"diffusers version {diffusers.__version__} < {version} {msg}"
561
+ return unittest.skip(msg)
562
+ if or_older_than and v > pv.Version(or_older_than):
563
+ msg = (
564
+ f"diffusers version {or_older_than} < "
565
+ f"{diffusers.__version__} < {version} {msg}"
566
+ )
567
+ return unittest.skip(msg)
568
+ return lambda x: x
569
+
570
+
571
+ def requires_onnxscript(version: str, msg: str = "") -> Callable:
572
+ """Skips a unit test if :epkg:`onnxscript` is not recent enough."""
573
+ import packaging.version as pv
574
+ import onnxscript
575
+
576
+ if not hasattr(onnxscript, "__version__"):
577
+ # development version
578
+ return lambda x: x
579
+
580
+ if pv.Version(onnxscript.__version__) < pv.Version(version):
581
+ msg = f"onnxscript version {onnxscript.__version__} < {version}: {msg}"
582
+ return unittest.skip(msg)
583
+ return lambda x: x
584
+
585
+
586
+ def has_onnxscript(version: str, msg: str = "") -> bool:
588
+ """Tells if :epkg:`onnxscript` is recent enough."""
588
+ import packaging.version as pv
589
+ import onnxscript
590
+
591
+ if not hasattr(onnxscript, "__version__"):
592
+ # development version
593
+ return True
594
+
595
+ if pv.Version(onnxscript.__version__) < pv.Version(version):
596
+ msg = f"onnxscript version {onnxscript.__version__} < {version}: {msg}"
597
+ return False
598
+ return True
599
+
600
+
601
+ def requires_onnxruntime(version: str, msg: str = "") -> Callable:
602
+ """Skips a unit test if :epkg:`onnxruntime` is not recent enough."""
603
+ import packaging.version as pv
604
+ import onnxruntime
605
+
606
+ if pv.Version(onnxruntime.__version__) < pv.Version(version):
607
+ msg = f"onnxruntime version {onnxruntime.__version__} < {version}: {msg}"
608
+ return unittest.skip(msg)
609
+ return lambda x: x
610
+
611
+
612
+ def has_onnxruntime_training(push_back_batch: bool = False):
613
+ """Tells if onnxruntime_training is installed."""
614
+ try:
615
+ from onnxruntime import training
616
+ except ImportError:
617
+ # onnxruntime-training is not installed
618
+ training = None
619
+ if training is None:
620
+ return False
621
+
622
+ if push_back_batch:
623
+ try:
624
+ from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
625
+ except ImportError:
626
+ return False
627
+
628
+ if not hasattr(OrtValueVector, "push_back_batch"):
629
+ return False
630
+ return True
631
+
632
+
633
+ def has_onnxruntime_genai():
634
+ """Tells if onnxruntime_genai is installed."""
635
+ try:
636
+ import onnxruntime_genai # noqa: F401
637
+
638
+ return True
639
+ except ImportError:
640
+ # onnxruntime_genai is not installed
641
+ return False
642
+
643
+
644
+ def requires_onnxruntime_training(
645
+ push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
646
+ ) -> Callable:
647
+ """Skips a unit test if :epkg:`onnxruntime` is not onnxruntime_training."""
648
+ try:
649
+ from onnxruntime import training
650
+ except ImportError:
651
+ # onnxruntime-training is not installed
652
+ training = None
653
+ if training is None:
654
+ msg = msg or "onnxruntime_training is not installed"
655
+ return unittest.skip(msg)
656
+
657
+ if push_back_batch:
658
+ try:
659
+ from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
660
+ except ImportError:
661
+ msg = msg or "OrtValue has no method push_back_batch"
662
+ return unittest.skip(msg)
663
+
664
+ if not hasattr(OrtValueVector, "push_back_batch"):
665
+ msg = msg or "OrtValue has no method push_back_batch"
666
+ return unittest.skip(msg)
667
+ if ortmodule:
668
+ try:
669
+ import onnxruntime.training.ortmodule # noqa: F401
670
+ except (AttributeError, ImportError):
671
+ msg = msg or "ortmodule is missing in onnxruntime-training"
672
+ return unittest.skip(msg)
673
+ return lambda x: x
674
+
675
+
676
+ def requires_onnx(version: str, msg: str = "") -> Callable:
677
+ """Skips a unit test if :epkg:`onnx` is not recent enough."""
678
+ import packaging.version as pv
679
+ import onnx
680
+
681
+ if pv.Version(onnx.__version__) < pv.Version(version):
682
+ msg = f"onnx version {onnx.__version__} < {version}: {msg}"
683
+ return unittest.skip(msg)
684
+ return lambda x: x
685
+
686
+
687
+ def requires_onnx_array_api(version: str, msg: str = "") -> Callable:
688
+ """Skips a unit test if :epkg:`onnx-array-api` is not recent enough."""
689
+ import packaging.version as pv
690
+ import onnx_array_api
691
+
692
+ if pv.Version(onnx_array_api.__version__) < pv.Version(version):
693
+ msg = f"onnx-array-api version {onnx_array_api.__version__} < {version}: {msg}"
694
+ return unittest.skip(msg)
695
+ return lambda x: x
696
+
697
+
698
+ def statistics_on_file(filename: str) -> Dict[str, Union[int, float, str]]:
699
+ """
700
+ Computes statistics on a file.
701
+
702
+ .. runpython::
703
+ :showcode:
704
+
705
+ import pprint
706
+ from onnx_diagnostic.ext_test_case import statistics_on_file, __file__
707
+
708
+ pprint.pprint(statistics_on_file(__file__))
709
+ """
710
+ assert os.path.exists(filename), f"File {filename!r} does not exist."
711
+
712
+ ext = os.path.splitext(filename)[-1]
713
+ if ext not in {".py", ".rst", ".md", ".txt"}:
714
+ size = os.stat(filename).st_size
715
+ return {"size": size}
716
+ alpha = set("abcdefghijklmnopqrstuvwxyz0123456789")
717
+ with open(filename, "r", encoding="utf-8") as f:
718
+ n_line = 0
719
+ n_ch = 0
720
+ for line in f.readlines():
721
+ s = line.strip("\n\r\t ")
722
+ if s:
723
+ n_ch += len(s.replace(" ", ""))
724
+ ch = set(s.lower()) & alpha
725
+ if ch:
726
+ # It avoids counting lines with only a bracket or a comma.
727
+ n_line += 1
728
+
729
+ stat = dict(lines=n_line, chars=n_ch, ext=ext)
730
+ if ext != ".py":
731
+ return stat
732
+ # add statistics on python syntax?
733
+ return stat
734
+
735
+
736
+ class ExtTestCase(unittest.TestCase):
737
+ """
738
+ Inherits from :class:`unittest.TestCase` and adds specific comparison
739
+ functions and other helpers.
740
+ """
741
+
742
+ _warns: List[Tuple[str, int, Warning]] = []
743
+ _todos: List[Tuple[Callable, str]] = []
744
+
745
+ @property
746
+ def verbose(self):
747
+ "Returns the the value of environment variable ``VERBOSE``."
748
+ return int(os.environ.get("VERBOSE", "0"))
749
+
750
+ @classmethod
751
+ def setUpClass(cls):
752
+ logger = logging.getLogger("onnxscript.optimizer.constant_folding")
753
+ logger.setLevel(logging.ERROR)
754
+ unittest.TestCase.setUpClass()
755
+
756
+ @classmethod
757
+ def tearDownClass(cls):
758
+ for name, line, w in cls._warns:
759
+ warnings.warn(f"\n{name}:{line}: {type(w)}\n {w!s}", stacklevel=2)
760
+ if not cls._todos:
761
+ return
762
+ for f, msg in cls._todos:
763
+ sys.stderr.write(f"TODO {cls.__name__}::{f.__name__}: {msg}\n")
764
+
765
+ @classmethod
766
+ def todo(cls, f: Callable, msg: str):
767
+ "Adds a todo printed when all test are run."
768
+ cls._todos.append((f, msg))
769
+
770
+ @classmethod
771
+ def ort(cls):
772
+ import onnxruntime
773
+
774
+ return onnxruntime
775
+
776
+ @classmethod
777
+ def to_onnx(cls, *args, **kwargs):
778
+ from experimental_experiment.torch_interpreter import to_onnx
779
+
780
+ return to_onnx(*args, **kwargs)
781
+
782
+ def print_model(self, model: "ModelProto"): # noqa: F821
783
+ "Prints a ModelProto"
784
+ from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
785
+
786
+ print(pretty_onnx(model))
787
+
788
+ def print_onnx(self, model: "ModelProto"): # noqa: F821
789
+ "Prints a ModelProto"
790
+ from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
791
+
792
+ print(pretty_onnx(model))
793
+
794
+ def get_dump_file(self, name: str, folder: Optional[str] = None) -> str:
795
+ """Returns a filename to dump a model."""
796
+ if folder is None:
797
+ folder = "dump_test"
798
+ if folder and not os.path.exists(folder):
799
+ os.mkdir(folder)
800
+ return os.path.join(folder, name)
801
+
802
+ def get_dump_folder(self, folder: str) -> str:
803
+ """Returns a folder."""
804
+ folder = os.path.join("dump_test", folder)
805
+ if not os.path.exists(folder):
806
+ os.makedirs(folder)
807
+ return folder
808
+
809
+ def dump_onnx(
810
+ self,
811
+ name: str,
812
+ proto: Any,
813
+ folder: Optional[str] = None,
814
+ ) -> str:
815
+ """Dumps an onnx file."""
816
+ fullname = self.get_dump_file(name, folder=folder)
817
+ with open(fullname, "wb") as f:
818
+ f.write(proto.SerializeToString())
819
+ return fullname
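+
+     # Usage sketch (illustrative only, inside a test deriving from
+     # ExtTestCase): persist a model under ``dump_test/`` for later inspection.
+     #
+     #     def test_export(self):
+     #         proto = build_model()  # hypothetical helper returning a ModelProto
+     #         path = self.dump_onnx("test_export.onnx", proto)
+     #         self.assertExists(path)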
820
+
821
+ def assertExists(self, name):
822
+ """Checks the existing of a file."""
823
+ if not os.path.exists(name):
824
+ raise AssertionError(f"File or folder {name!r} does not exists.")
825
+
826
+ def assertGreaterOrEqual(self, a, b, msg=None):
827
+ """In the name"""
828
+ if a < b:
829
+ return AssertionError(f"{a} < {b}, a not greater or equal than b\n{msg or ''}")
830
+
831
+ def assertInOr(self, tofind: Tuple[str, ...], text: str, msg: str = ""):
832
+ for tof in tofind:
833
+ if tof in text:
834
+ return
835
+ raise AssertionError(
836
+ msg or f"Unable to find one string in the list {tofind!r} in\n--\n{text}"
837
+ )
838
+
839
+ def assertIn(self, tofind: str, text: str, msg: str = ""):
840
+ if tofind in text:
841
+ return
842
+ raise AssertionError(
843
+ msg or f"Unable to find the list of strings {tofind!r} in\n--\n{text}"
844
+ )
845
+
846
+ def assertHasAttr(self, obj: Any, name: str):
847
+ assert hasattr(
848
+ obj, name
849
+ ), f"Unable to find attribute {name!r} in object type {type(obj)}"
850
+
851
+ def assertSetContained(self, set1, set2):
852
+ "Checks that ``set1`` is contained in ``set2``."
853
+ set1 = set(set1)
854
+ set2 = set(set2)
855
+ if set1 & set2 != set1:
856
+ raise AssertionError(f"Set {set2} does not contain set {set1}.")
857
+
858
+ def assertEqualArrays(
859
+ self,
860
+ expected: Sequence[numpy.ndarray],
861
+ value: Sequence[numpy.ndarray],
862
+ atol: float = 0,
863
+ rtol: float = 0,
864
+ msg: Optional[str] = None,
865
+ ):
866
+ """In the name"""
867
+ self.assertEqual(len(expected), len(value))
868
+ for a, b in zip(expected, value):
869
+ self.assertEqualArray(a, b, atol=atol, rtol=rtol)
870
+
871
+ def assertEqualArray(
872
+ self,
873
+ expected: Any,
874
+ value: Any,
875
+ atol: float = 0,
876
+ rtol: float = 0,
877
+ msg: Optional[str] = None,
878
+ ):
879
+ """In the name"""
880
+ if hasattr(expected, "detach") and hasattr(value, "detach"):
881
+ if msg:
882
+ try:
883
+ self.assertEqual(expected.dtype, value.dtype)
884
+ except AssertionError as e:
885
+ raise AssertionError(msg) from e
886
+ try:
887
+ self.assertEqual(expected.shape, value.shape)
888
+ except AssertionError as e:
889
+ raise AssertionError(msg) from e
890
+ else:
891
+ self.assertEqual(expected.dtype, value.dtype)
892
+ self.assertEqual(expected.shape, value.shape)
893
+
894
+ import torch
895
+
896
+ try:
897
+ torch.testing.assert_close(value, expected, atol=atol, rtol=rtol)
898
+ except AssertionError as e:
899
+ expected_max = torch.abs(expected).max()
900
+ expected_value = torch.abs(value).max()
901
+ rows = [
902
+ f"{msg}\n{e}" if msg else str(e),
903
+ f"expected max value={expected_max}",
904
+ f"expected computed value={expected_value}",
905
+ ]
906
+ raise AssertionError("\n".join(rows)) # noqa: B904
907
+ return
908
+
909
+ from .helpers.torch_helper import to_numpy
910
+
911
+ if hasattr(expected, "detach"):
912
+ expected = to_numpy(expected.detach().cpu())
913
+ if hasattr(value, "detach"):
914
+ value = to_numpy(value.detach().cpu())
915
+ if msg:
916
+ try:
917
+ self.assertEqual(expected.dtype, value.dtype)
918
+ except AssertionError as e:
919
+ raise AssertionError(msg) from e
920
+ try:
921
+ self.assertEqual(expected.shape, value.shape)
922
+ except AssertionError as e:
923
+ raise AssertionError(msg) from e
924
+ else:
925
+ self.assertEqual(expected.dtype, value.dtype)
926
+ self.assertEqual(expected.shape, value.shape)
927
+
928
+ try:
929
+ assert_allclose(desired=expected, actual=value, atol=atol, rtol=rtol)
930
+ except AssertionError as e:
931
+ expected_max = numpy.abs(expected).max()
932
+ expected_value = numpy.abs(value).max()
933
+ te = expected.astype(int) if expected.dtype == numpy.bool_ else expected
934
+ tv = value.astype(int) if value.dtype == numpy.bool_ else value
935
+ rows = [
936
+ f"{msg}\n{e}" if msg else str(e),
937
+ f"expected max value={expected_max}",
938
+ f"expected computed value={expected_value}\n",
939
+ f"ratio={te / tv}\ndiff={te - tv}",
940
+ ]
941
+ raise AssertionError("\n".join(rows)) # noqa: B904
942
+
943
+ def assertEqualDataFrame(self, d1, d2, **kwargs):
944
+ """
945
+ Checks that two dataframes are equal.
946
+ Calls :func:`pandas.testing.assert_frame_equal`.
947
+ """
948
+ from pandas.testing import assert_frame_equal
949
+
950
+ assert_frame_equal(d1, d2, **kwargs)
951
+
952
+ def assertEqualTrue(self, value: Any, msg: str = ""):
953
+ if value is True:
954
+ return
955
+ raise AssertionError(msg or f"value is not True: {value!r}")
956
+
957
+ def assertEqual(self, expected: Any, value: Any, msg: str = ""):
958
+ """Overwrites the error message to get a more explicit message about what is what."""
959
+ if msg:
960
+ super().assertEqual(expected, value, msg)
961
+ else:
962
+ try:
963
+ super().assertEqual(expected, value)
964
+ except AssertionError as e:
965
+ raise AssertionError( # noqa: B904
966
+ f"expected is {expected!r}, value is {value!r}\n{e}"
967
+ )
968
+
969
+ def assertEqualAny(
970
+ self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = ""
971
+ ):
972
+ if expected.__class__.__name__ == "BaseModelOutput":
973
+ self.assertEqual(type(expected), type(value), msg=msg)
974
+ self.assertEqual(len(expected), len(value), msg=msg)
975
+ self.assertEqual(list(expected), list(value), msg=msg) # checks the order
976
+ self.assertEqualAny(
977
+ {k: v for k, v in expected.items()}, # noqa: C416
978
+ {k: v for k, v in value.items()}, # noqa: C416
979
+ atol=atol,
980
+ rtol=rtol,
981
+ msg=msg,
982
+ )
983
+ elif isinstance(expected, (tuple, list, dict)):
984
+ self.assertIsInstance(value, type(expected), msg=msg)
985
+ self.assertEqual(len(expected), len(value), msg=msg)
986
+ if isinstance(expected, dict):
987
+ for k in expected:
988
+ self.assertIn(k, value, msg=msg)
989
+ self.assertEqualAny(expected[k], value[k], msg=msg, atol=atol, rtol=rtol)
990
+ else:
991
+ for e, g in zip(expected, value):
992
+ self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
993
+ elif expected.__class__.__name__ in (
994
+ "DynamicCache",
995
+ "SlidingWindowCache",
996
+ "HybridCache",
997
+ ):
998
+ self.assertEqual(type(expected), type(value), msg=msg)
999
+ atts = ["key_cache", "value_cache"]
1000
+ self.assertEqualAny(
1001
+ {k: expected.__dict__.get(k, None) for k in atts},
1002
+ {k: value.__dict__.get(k, None) for k in atts},
1003
+ atol=atol,
1004
+ rtol=rtol,
1005
+ )
1006
+ elif expected.__class__.__name__ == "StaticCache":
1007
+ self.assertEqual(type(expected), type(value), msg=msg)
1008
+ self.assertEqual(expected.max_cache_len, value.max_cache_len)
1009
+ atts = ["key_cache", "value_cache"]
1010
+ self.assertEqualAny(
1011
+ {k: expected.__dict__.get(k, None) for k in atts},
1012
+ {k: value.__dict__.get(k, None) for k in atts},
1013
+ atol=atol,
1014
+ rtol=rtol,
1015
+ )
1016
+ elif expected.__class__.__name__ == "EncoderDecoderCache":
1017
+ self.assertEqual(type(expected), type(value), msg=msg)
1018
+ atts = ["self_attention_cache", "cross_attention_cache"]
1019
+ self.assertEqualAny(
1020
+ {k: expected.__dict__.get(k, None) for k in atts},
1021
+ {k: value.__dict__.get(k, None) for k in atts},
1022
+ atol=atol,
1023
+ rtol=rtol,
1024
+ )
1025
+ elif isinstance(expected, (int, float, str)):
1026
+ self.assertEqual(expected, value, msg=msg)
1027
+ elif hasattr(expected, "shape"):
1028
+ self.assertEqual(type(expected), type(value), msg=msg)
1029
+ self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
1030
+ elif expected.__class__.__name__ in ("Dim", "_Dim", "_DimHintType"):
1031
+ self.assertEqual(type(expected), type(value), msg=msg)
1032
+ self.assertEqual(expected.__name__, value.__name__, msg=msg)
1033
+ elif expected is None:
1034
+ self.assertEqual(expected, value, msg=msg)
1035
+ else:
1036
+ raise AssertionError(
1037
+ f"Comparison not implemented for types {type(expected)} and {type(value)}"
1038
+ )
1039
+
1040
+ def assertEqualArrayAny(
1041
+ self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = ""
1042
+ ):
1043
+ if isinstance(expected, (tuple, list, dict)):
1044
+ self.assertIsInstance(value, type(expected), msg=msg)
1045
+ self.assertEqual(len(expected), len(value), msg=msg)
1046
+ if isinstance(expected, dict):
1047
+ for k in expected:
1048
+ self.assertIn(k, value, msg=msg)
1049
+ self.assertEqualArrayAny(
1050
+ expected[k], value[k], msg=msg, atol=atol, rtol=rtol
1051
+ )
1052
+ else:
1053
+ excs = []
1054
+ for i, (e, g) in enumerate(zip(expected, value)):
1055
+ try:
1056
+ self.assertEqualArrayAny(e, g, msg=msg, atol=atol, rtol=rtol)
1057
+ except AssertionError as e:
1058
+ excs.append(f"Error at position {i} due to {e}")
1059
+ if excs:
1060
+ msg_ = "\n".join(excs)
1061
+ msg = f"{msg}\n{msg_}" if msg else msg_
1062
+ raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
1063
+ elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
1064
+ atts = {"key_cache", "value_cache"}
1065
+ self.assertEqualArrayAny(
1066
+ {k: expected.__dict__.get(k, None) for k in atts},
1067
+ {k: value.__dict__.get(k, None) for k in atts},
1068
+ atol=atol,
1069
+ rtol=rtol,
1070
+ )
1071
+ elif isinstance(expected, (int, float, str)):
1072
+ self.assertEqual(expected, value, msg=msg)
1073
+ elif hasattr(expected, "shape"):
1074
+ self.assertEqual(type(expected), type(value), msg=msg)
1075
+ self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
1076
+ elif expected is None:
1077
+ assert value is None, f"Expected is None but value is of type {type(value)}"
1078
+ else:
1079
+ raise AssertionError(
1080
+ f"Comparison not implemented for types {type(expected)} and {type(value)}"
1081
+ )
1082
+
1083
+ def assertAlmostEqual(
1084
+ self,
1085
+ expected: numpy.ndarray,
1086
+ value: numpy.ndarray,
1087
+ atol: float = 0,
1088
+ rtol: float = 0,
1089
+ ):
1090
+ """In the name"""
1091
+ if not isinstance(expected, numpy.ndarray):
1092
+ expected = numpy.array(expected)
1093
+ if not isinstance(value, numpy.ndarray):
1094
+ value = numpy.array(value).astype(expected.dtype)
1095
+ self.assertEqualArray(expected, value, atol=atol, rtol=rtol)
1096
+
1097
+ def check_ort(self, onx: "onnx.ModelProto") -> "onnxruntime.InferenceSession": # noqa: F821
1098
+ from onnxruntime import InferenceSession
1099
+
1100
+ return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
1101
+
1102
+ def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
1103
+ """In the name"""
1104
+ try:
1105
+ fct()
1106
+ except exc_type as e:
1107
+ if not isinstance(e, exc_type):
1108
+ raise AssertionError(f"Unexpected exception {type(e)!r}.") # noqa: B904
1109
+ if msg is not None and msg not in str(e):
1110
+ raise AssertionError(f"Unexpected exception message {e!r}.") # noqa: B904
1111
+ return
1112
+ raise AssertionError("No exception was raised.") # noqa: B904
1113
+
1114
+ def assertEmpty(self, value: Any):
1115
+ """In the name"""
1116
+ if value is None:
1117
+ return
1118
+ if not value:
1119
+ return
1120
+ raise AssertionError(f"value is not empty: {value!r}.")
1121
+
1122
+ def assertNotEmpty(self, value: Any):
1123
+ """In the name"""
1124
+ if value is None:
1125
+ raise AssertionError(f"value is empty: {value!r}.")
1126
+ if isinstance(value, (list, dict, tuple, set)):
1127
+ if not value:
1128
+ raise AssertionError(f"value is empty: {value!r}.")
1129
+
1130
+ def assertStartsWith(self, prefix: str, full: str):
1131
+ """In the name"""
1132
+ if not full.startswith(prefix):
1133
+ raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")
1134
+
1135
+ def assertEndsWith(self, suffix: str, full: str):
1136
+ """In the name"""
1137
+ if not full.endswith(suffix):
1138
+ raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
1139
+
1140
+ def capture(self, fct: Callable):
1141
+ """
1142
+ Runs a function and captures its standard output and error.
1143
+
1144
+ :param fct: function to run
1145
+ :return: result of *fct*, output, error
1146
+ """
1147
+ sout = StringIO()
1148
+ serr = StringIO()
1149
+ with redirect_stdout(sout), redirect_stderr(serr):
1150
+ try:
1151
+ res = fct()
1152
+ except Exception as e:
1153
+ raise AssertionError(
1154
+ f"function {fct} failed, stdout="
1155
+ f"\n{sout.getvalue()}\n---\nstderr=\n{serr.getvalue()}"
1156
+ ) from e
1157
+ return res, sout.getvalue(), serr.getvalue()
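+
+     # Usage sketch (illustrative only): run a chatty function and make
+     # assertions on what it printed.
+     #
+     #     def test_logging(self):
+     #         res, out, err = self.capture(lambda: print("hello"))
+     #         self.assertIn("hello", out)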
1158
+
1159
+ def tryCall(
1160
+ self, fct: Callable, msg: Optional[str] = None, none_if: Optional[str] = None
1161
+ ) -> Optional[Any]:
1162
+ """
1163
+ Calls the function and catches any error.
1164
+
1165
+ :param fct: function to call
1166
+ :param msg: error message to display if failing
1167
+ :param none_if: returns None if this substring is found in the error message
1168
+ :return: output of *fct*
1169
+ """
1170
+ try:
1171
+ return fct()
1172
+ except Exception as e:
1173
+ if none_if is not None and none_if in str(e):
1174
+ return None
1175
+ if msg is None:
1176
+ raise
1177
+ raise AssertionError(msg) from e
1178
+
1179
+ def assert_onnx_disc(
1180
+ self,
1181
+ test_name: str,
1182
+ proto: "onnx.ModelProto", # noqa: F821
1183
+ model: "torch.nn.Module", # noqa: F821
1184
+ inputs: Union[Tuple[Any], Dict[str, Any]],
1185
+ verbose: int = 0,
1186
+ atol: float = 1e-5,
1187
+ rtol: float = 1e-3,
1188
+ copy_inputs: bool = True,
1189
+ expected: Optional[Any] = None,
1190
+ use_ort: bool = False,
1191
+ **kwargs,
1192
+ ):
1193
+ """
1194
+ Checks for discrepancies.
1195
+ Runs the onnx model, then computes the expected outputs, in that order.
1196
+ The inputs may be modified by this function if the torch model
1197
+ modifies them in place.
1198
+
1199
+ :param test_name: test name, dumps the model if not empty
1200
+ :param proto: onnx model
1201
+ :param model: torch model
1202
+ :param inputs: inputs
1203
+ :param verbose: verbosity
1204
+ :param atol: absolute tolerance
1205
+ :param rtol: relative tolerance
1206
+ :param expected: expected values
1207
+ :param copy_inputs: to copy the inputs
1208
+ :param use_ort: use :class:`onnxruntime.InferenceSession`
1209
+ :param kwargs: arguments sent to
1210
+ :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
1211
+ """
1212
+ from .helpers import string_type, string_diff, max_diff
1213
+ from .helpers.rt_helper import make_feeds
1214
+ from .helpers.ort_session import InferenceSessionForTorch
1215
+
1216
+ kws = dict(with_shape=True, with_min_max=verbose > 1)
1217
+ if verbose:
1218
+ vname = test_name or "assert_onnx_disc"
1219
+ if test_name:
1220
+ name = f"{test_name}.onnx"
1221
+ print(f"[{vname}] save the onnx model into {name!r}")
1222
+ name = self.dump_onnx(name, proto)
1223
+ print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
1224
+ if verbose:
1225
+ print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
1226
+ if use_ort:
1227
+ feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
1228
+ if verbose:
1229
+ print(f"[{vname}] feeds {string_type(feeds, **kws)}")
1230
+ import onnxruntime
1231
+
1232
+ sess = onnxruntime.InferenceSession(
1233
+ proto.SerializeToString(), providers=["CPUExecutionProvider"]
1234
+ )
1235
+ got = sess.run(None, feeds)
1236
+ else:
1237
+ feeds = make_feeds(proto, inputs, copy=True)
1238
+ if verbose:
1239
+ print(f"[{vname}] feeds {string_type(feeds, **kws)}")
1240
+ sess = InferenceSessionForTorch(proto, **kwargs)
1241
+ got = sess.run(None, feeds)
1242
+ if verbose:
1243
+ print(f"[{vname}] compute expected values")
1244
+ if expected is None:
1245
+ if copy_inputs:
1246
+ expected = (
1247
+ model(*copy.deepcopy(inputs))
1248
+ if isinstance(inputs, tuple)
1249
+ else model(**copy.deepcopy(inputs))
1250
+ )
1251
+ else:
1252
+ expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
1253
+ if verbose:
1254
+ print(f"[{vname}] expected {string_type(expected, **kws)}")
1255
+ print(f"[{vname}] obtained {string_type(got, **kws)}")
1256
+ diff = max_diff(expected, got, flatten=True)
1257
+ if verbose:
1258
+ print(f"[{vname}] diff {string_diff(diff)}")
1259
+ assert (
1260
+ isinstance(diff["abs"], float)
1261
+ and isinstance(diff["rel"], float)
1262
+ and not numpy.isnan(diff["abs"])
1263
+ and diff["abs"] <= atol
1264
+ and not numpy.isnan(diff["rel"])
1265
+ and diff["rel"] <= rtol
1266
+ ), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
1267
+
1268
+ def _debug(self):
1269
+ "Tells if DEBUG=1 is set up."
1270
+ return os.environ.get("DEBUG") in BOOLEAN_VALUES
1271
+
1272
+ def string_type(self, *args, **kwargs):
1273
+ from .helpers import string_type
1274
+
1275
+ return string_type(*args, **kwargs)
1276
+
1277
+ def subloop(self, *args, verbose: int = 0):
1278
+ "Loops over elements and calls :meth:`unittests.TestCase.subTest`."
1279
+ if len(args) == 1:
1280
+ for it in args[0]:
1281
+ with self.subTest(case=it):
1282
+ if verbose:
1283
+ print(f"[subloop] it={it!r}")
1284
+ yield it
1285
+ else:
1286
+ for it in itertools.product(*args):
1287
+ with self.subTest(case=it):
1288
+ if verbose:
1289
+ print(f"[subloop] it={it!r}")
1290
+ yield it
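+
+
+ # End-to-end usage sketch (illustrative only): a test case deriving from
+ # ``ExtTestCase`` combining ``subloop`` and ``assertEqualArray``. The class is
+ # hypothetical and only created when this helper is called.
+ def _example_ext_test_case() -> type:  # pragma: no cover
+     class _DemoCase(ExtTestCase):
+         def test_square(self):
+             for dtype in self.subloop([numpy.float32, numpy.float64]):
+                 x = numpy.arange(4).astype(dtype)
+                 self.assertEqualArray(x * x, x**2)
+
+     return _DemoCase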