onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,450 @@
1
+ import itertools
2
+ import multiprocessing
3
+ import os
4
+ import platform
5
+ import re
6
+ import subprocess
7
+ import sys
8
+ import time
9
+ import warnings
10
+ from argparse import Namespace
11
+ from datetime import datetime
12
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
13
+
14
+
15
# Default truncation length for string values stored in the summary
# dataframe (see make_dataframe_from_benchmark_data).
_DEFAULT_STRING_LIMIT = 2000
16
+
17
+
18
class BenchmarkError(RuntimeError):
    """Raised when a benchmark run completes without producing any metric."""
20
+
21
+
22
+ def _clean_string(s: str) -> str:
23
+ cleaned = [c for c in s if 32 <= ord(c) < 127 and c not in {","}]
24
+ return "".join(cleaned)
25
+
26
+
27
def get_processor_name() -> str:
    """Returns the processor name.

    On Windows and macOS the value comes from :func:`platform.processor`.
    On Linux it is extracted from the ``model name`` entry of
    ``/proc/cpuinfo``.

    :raises AssertionError: if the platform is not supported or the
        model name cannot be found
    """
    if platform.system() in ("Windows", "Darwin"):
        return platform.processor()
    if platform.system() == "Linux":
        command = "cat /proc/cpuinfo"
        all_info = subprocess.check_output(command, shell=True).decode().strip()
        for line in all_info.split("\n"):
            if "model name" in line:
                return re.sub(".*model name.*:", "", line, count=1).strip()
    # NOTE: ``sysctl -n machdep.cpu.brand_string`` was tried for macOS but
    # failed; macOS is handled by platform.processor() above instead.
    # FIX: the error message previously referred to a misspelled
    # "get_process_name".
    raise AssertionError("get_processor_name not implemented on this platform.")
44
+
45
+
46
def get_machine(
    capability_as_str: bool = True,
) -> Dict[str, Union[str, int, float, Tuple[int, int]]]:
    """Returns the machine specifications."""
    architecture = platform.architecture()
    if isinstance(architecture, (list, tuple)):
        arch_text = "/".join(str(a) for a in architecture)
    else:
        arch_text = str(architecture)
    specs: Dict[str, Union[str, int, float, Tuple[int, int]]] = {
        "machine": str(platform.machine()),
        "architecture": arch_text,
        "processor": str(platform.processor()),
        "version": str(sys.version).split()[0],
        "cpu": int(multiprocessing.cpu_count()),
        "executable": str(sys.executable),
        "processor_name": get_processor_name(),
        "system": str(platform.system()),
    }
    try:
        import torch.cuda
    except ImportError:
        # torch is optional: without it no CUDA information is reported.
        return specs

    specs["has_cuda"] = bool(torch.cuda.device_count() > 0)
    if specs["has_cuda"]:
        capability = torch.cuda.get_device_capability(0)
        specs["capability"] = (
            ".".join(map(str, capability)) if capability_as_str else capability
        )
        specs["device_name"] = str(torch.cuda.get_device_name(0))
    return specs
77
+
78
+
79
+ def _cmd_line(script_name: str, **kwargs: Dict[str, Union[str, int, float]]) -> List[str]:
80
+ args = [sys.executable, "-m", script_name]
81
+ for k, v in kwargs.items():
82
+ if v is None:
83
+ continue
84
+ args.append(f"--{k}")
85
+ args.append(str(v))
86
+ return args
87
+
88
+
89
+ def _extract_metrics(text: str) -> Dict[str, Union[str, int, float]]:
90
+ reg = re.compile(":(.*?),(.*.?);")
91
+ res = reg.findall(text)
92
+ if len(res) == 0:
93
+ return {}
94
+ kw = dict(res)
95
+ new_kw: Dict[str, Any] = {}
96
+ for k, w in kw.items():
97
+ assert isinstance(k, str) and isinstance(
98
+ w, str
99
+ ), f"Unexpected type for k={k!r}, types={type(k)}, {type(w)})."
100
+ assert "\n" not in w, f"Unexpected multi-line value for k={k!r}, value is\n{w}"
101
+ if not (
102
+ "err" in k.lower()
103
+ or k
104
+ in {
105
+ "onnx_output_names",
106
+ "onnx_input_names",
107
+ "filename",
108
+ "time_latency_t_detail",
109
+ "time_latency_t_qu",
110
+ "time_latency_t_qu_10t",
111
+ "time_latency_eager_t_detail",
112
+ "time_latency_eager_t_qu",
113
+ "time_latency_eager_t_qu_10t",
114
+ }
115
+ or len(w) < 500
116
+ ):
117
+ warnings.warn(
118
+ f"Unexpected long value for model={kw.get('model_name', '?')}, "
119
+ f"k={k!r}, value has length {len(w)} is\n{w}",
120
+ stacklevel=2,
121
+ )
122
+ continue
123
+ try:
124
+ wi = int(w)
125
+ new_kw[k] = wi
126
+ continue
127
+ except ValueError:
128
+ pass
129
+ try:
130
+ wf = float(w)
131
+ new_kw[k] = wf
132
+ continue
133
+ except ValueError:
134
+ pass
135
+ new_kw[k] = w
136
+ return new_kw
137
+
138
+
139
+ def _make_prefix(script_name: str, index: int) -> str:
140
+ name = os.path.splitext(script_name)[0]
141
+ return f"{name}_dort_c{index}_"
142
+
143
+
144
+ def _cmd_string(s: str) -> str:
145
+ if s == "":
146
+ return '""'
147
+ return s.replace('"', '\\"')
148
+
149
+
150
def run_benchmark(
    script_name: str,
    configs: List[Dict[str, Union[str, int, float]]],
    verbose: int = 0,
    stop_if_exception: bool = True,
    dump: bool = False,
    temp_output_data: Optional[str] = None,
    dump_std: Optional[str] = None,
    start: int = 0,
    summary: Optional[Callable] = None,
    timeout: int = 600,
    missing: Optional[Dict[str, Union[str, Callable]]] = None,
) -> List[Dict[str, Union[str, int, float]]]:
    """
    Runs a script multiple times and extract information from the output
    following the pattern ``:<metric>,<value>;``.

    :param script_name: python script to run
    :param configs: list of execution to do
    :param stop_if_exception: stop if one experiment failed, otherwise continue
    :param verbose: use tqdm to follow the progress
    :param dump: dump onnx file, sets variable ONNXRT_DUMP_PATH
    :param temp_output_data: to save the data after every run to avoid losing data
    :param dump_std: dumps stdout and stderr in this folder
    :param start: start at this iteration
    :param summary: function to call on the temporary data and the final data
    :param timeout: timeout for the subprocesses
    :param missing: populate with this missing value if not found
    :return: values
    """
    assert (
        temp_output_data is None or "temp" in temp_output_data
    ), f"Unexpected value for {temp_output_data!r}"
    assert configs, f"No configuration was given (script_name={script_name!r})"
    if verbose:
        from tqdm import tqdm

        loop = tqdm(configs)
    else:
        loop = configs

    data: List[Dict[str, Union[str, int, float]]] = []
    for iter_loop, config in enumerate(loop):
        if iter_loop < start:
            continue
        if hasattr(loop, "set_description"):
            # Label the progress bar with the first identifying key found.
            for c in ["name", "model"]:
                if c not in config:
                    continue
                loop.set_description(f"[{config[c]}]")
                break
        cmd = _cmd_line(script_name, **config)
        begin = time.perf_counter()

        if dump:
            os.environ["ONNXRT_DUMP_PATH"] = _make_prefix(script_name, iter_loop)
        else:
            os.environ["ONNXRT_DUMP_PATH"] = ""
        if verbose > 3:
            print(f"[run_benchmark] cmd={cmd if isinstance(cmd, str) else ' '.join(cmd)}")
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        timeout_error = ""
        try:
            res = p.communicate(timeout=timeout)
        except subprocess.TimeoutExpired as e:
            # see https://docs.python.org/3/library/subprocess.html#subprocess.Popen.communicate
            timeout_error = str(e)
            if verbose:
                print(f"[run_benchmark] timeout {e} for cmd={cmd}")
            p.terminate()
            try:
                # Use communicate with a timeout to prevent hanging
                res = p.communicate(timeout=10)
            except subprocess.TimeoutExpired:
                # Force kill if terminate doesn't work
                if verbose:
                    print(f"[run_benchmark] force killing cmd={cmd}")
                p.kill()
                res = p.communicate()
        out, err = res
        sout = out.decode("utf-8", errors="ignore")
        serr = err.decode("utf-8", errors="ignore")

        if dump_std:
            if not os.path.exists(dump_std):
                os.makedirs(dump_std)
            root = os.path.split(script_name)[-1].split(".")[-1]
            filename = os.path.join(dump_std, f"{root}.{iter_loop}")
            # FIX: these were previously hard-coded to "(unknown).stdout" /
            # "(unknown).stderr", ignoring ``filename`` entirely so every
            # iteration overwrote the same two files in the current directory.
            filename_out = f"{filename}.stdout"
            filename_err = f"{filename}.stderr"
            if out.strip(b"\n \r\t"):
                with open(filename_out, "w") as f:
                    f.write(sout)
            if err.strip(b"\n \r\t"):
                with open(filename_err, "w") as f:
                    f.write(serr)
        else:
            filename_out, filename_err = None, None

        if "ONNXRuntimeError" in serr or "ONNXRuntimeError" in sout:
            if stop_if_exception:
                raise RuntimeError(
                    f"Unable to continue with config {config} due to the "
                    f"following error\n{serr}"
                    f"\n----OUTPUT--\n{sout}"
                )

        metrics = _extract_metrics(sout)
        if len(metrics) == 0:
            if stop_if_exception:
                raise BenchmarkError(
                    f"Unable (2) to continue with config {config}, no metric was "
                    f"collected.\n--ERROR--\n{serr}\n--OUTPUT--\n{sout}"
                )
            else:
                metrics = {}
        metrics.update(config)
        if filename_out and os.path.exists(filename_out):
            if "model_name" in metrics:
                assert isinstance(
                    metrics["model_name"], str
                ), f"unexpected type {type(metrics['model_name'])}"
                # Rename the dump file so the model name is part of it.
                new_name = f"{filename_out}.{_clean_string(metrics['model_name'])}"
                os.rename(filename_out, new_name)
                filename_out = new_name
            metrics["file.stdout"] = filename_out
        if filename_err and os.path.exists(filename_err):
            if "model_name" in metrics:
                assert isinstance(
                    metrics["model_name"], str
                ), f"unexpected type {type(metrics['model_name'])}"
                new_name = f"{filename_err}.{_clean_string(metrics['model_name'])}"
                os.rename(filename_err, new_name)
                filename_err = new_name
            metrics["file.stderr"] = filename_err
        metrics["DATE"] = f"{datetime.now():%Y-%m-%d}"
        metrics["ITER"] = str(iter_loop)
        metrics["TIME_ITER"] = time.perf_counter() - begin
        metrics["ERROR"] = _clean_string(serr)
        metrics["ERR_stdout"] = _clean_string(sout)
        if metrics["ERROR"]:
            metrics["ERR_std"] = metrics["ERROR"]
        assert isinstance(
            metrics["ERROR"], str
        ), f"unexpected type {type(metrics['ERROR'])}"
        if "CUDA out of memory" in metrics["ERROR"]:
            metrics["ERR_CUDA_OOM"] = 1
        if "Cannot access gated repo for url" in metrics["ERROR"]:
            metrics["ERR_ACCESS"] = 1
        if timeout_error:
            metrics["ERR_timeout"] = _clean_string(timeout_error)
        metrics["OUTPUT"] = _clean_string(sout)
        for k, v in config.items():
            metrics[f"config_{k}"] = str(v).replace("\n", " ")
        if missing:
            update_missing = {}
            for k, v in missing.items():
                if k not in metrics:
                    if isinstance(v, str):
                        update_missing[k] = v
                        continue
                    if callable(v):
                        update_missing.update(v(missing, config))
                        continue
                    raise AssertionError(
                        f"Unable to interpret {type(v)} for k={k!r}, config={config!r}"
                    )
            if update_missing:
                metrics.update(update_missing)
        metrics["CMD"] = f"[{' '.join(map(_cmd_string, cmd))}]"
        data.append(metrics)
        if verbose > 5:
            print(f"--------------- ITER={iter_loop} in {metrics['TIME_ITER']}")
            print("--------------- ERROR")
            print(serr)
        if verbose >= 10:
            print("--------------- OUTPUT")
            print(sout)
        if temp_output_data:
            # Persist partial results after every run so data survives a crash.
            df = make_dataframe_from_benchmark_data(data, detailed=False)
            if verbose > 2:
                print(f"Prints out the results into file {temp_output_data!r}")
            fold, _ = os.path.split(temp_output_data)
            # fold could be empty string
            if fold and not os.path.exists(fold):
                os.makedirs(fold)
            df.to_csv(temp_output_data, index=False, errors="ignore")
            try:
                df.to_excel(temp_output_data + ".xlsx", index=False)
            except Exception:
                # Excel export is best-effort; the CSV above already holds the data.
                continue
            if summary:
                fn = f"{temp_output_data}.summary-partial.xlsx"
                if verbose > 2:
                    print(f"Prints out the results into file {fn!r}")
                summary(df, excel_output=fn, exc=False)

    return data
348
+
349
+
350
def multi_run(kwargs: Namespace) -> bool:
    """Checks if multiple values were sent for one argument."""
    for value in kwargs.__dict__.values():
        if isinstance(value, str) and "," in value:
            return True
    return False
353
+
354
+
355
def make_configs(
    kwargs: Union[Namespace, Dict[str, Any]],
    drop: Optional[Set[str]] = None,
    replace: Optional[Dict[str, str]] = None,
    last: Optional[List[str]] = None,
    filter_function: Optional[Callable[[Dict[str, Union[str, int, float]]], bool]] = None,
) -> List[Dict[str, Union[str, int, float]]]:
    """
    Creates all the configurations based on the command line arguments.

    :param kwargs: parameters the command line,
        every value having a comma means multiple values,
        it multiplies the number of configurations to try by the number of comma
        separated values
    :param drop: keys to drop in kwargs if specified
    :param replace: values to replace for a particular key
    :param last: to change the order of the loop created the configuration,
        if ``last == ["part"]`` and ``kwargs[part] == "0,1"``,
        then configuration where ``part==0`` is always followed by a configuration
        having ``part==1``
    :param filter_function: function taking a configuration and returning True
        if it is must be kept
    :return: list of configurations
    """
    kwargs_ = kwargs if isinstance(kwargs, dict) else kwargs.__dict__
    args = []
    slast = set(last) if last else set()
    for k, v in kwargs_.items():
        if (drop and k in drop) or k in slast:
            continue
        if replace and k in replace:
            v = replace[k]
        if isinstance(v, str):
            # A comma-separated string multiplies the configurations.
            args.append([(k, s) for s in v.split(",")])
        else:
            args.append([(k, v)])
    if last:
        # Append the ``last`` keys at the end so they vary fastest in the product.
        for k in last:
            if k not in kwargs_:
                continue
            # FIX: indexing ``kwargs`` raised TypeError when a Namespace was
            # passed (Namespace does not support item access); use the
            # normalized dictionary instead.
            v = kwargs_[k]
            if isinstance(v, str):
                args.append([(k, s) for s in v.split(",")])
            else:
                args.append([(k, v)])

    configs = list(itertools.product(*args))
    confs: List[Dict[str, Union[int, float, str]]] = [dict(c) for c in configs]
    if filter_function:
        confs = [c for c in confs if filter_function(c)]
    return confs
406
+
407
+
408
def make_dataframe_from_benchmark_data(
    data: List[Dict], detailed: bool = True, string_limit: int = _DEFAULT_STRING_LIMIT
) -> Any:
    """
    Creates a dataframe from the received data.

    :param data: list of dictionaries for every run
    :param detailed: remove multi line and long values
    :param string_limit: truncate the strings
    :return: dataframe
    """
    import pandas

    if detailed:
        return pandas.DataFrame(data)

    # Non-detailed mode: flatten newlines, escape commas and truncate
    # long strings so the frame stays CSV-friendly.
    rows = []
    for record in data:
        row = {}
        for key, value in record.items():
            if isinstance(value, str):
                value = value.replace("\n", " -- ").replace(",", "_")
                if len(value) > string_limit:
                    value = value[:string_limit] + "..."
            row[key] = value
        rows.append(row)
    df = pandas.DataFrame(rows)
    ordered = sorted(df.columns)
    if "_index" in ordered:
        # Move "_index" first and the bulky text columns last.
        present = {"_index", "CMD", "OUTPUT", "ERROR"} & set(df.columns)
        reordered = ["_index"] if "_index" in present else []
        reordered.extend(col for col in ordered if col not in present)
        reordered.extend(col for col in ("ERROR", "OUTPUT", "CMD") if col in present)
        ordered = reordered

    return df[ordered].copy()