onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import multiprocessing
|
|
3
|
+
import os
|
|
4
|
+
import platform
|
|
5
|
+
import re
|
|
6
|
+
import subprocess
|
|
7
|
+
import sys
|
|
8
|
+
import time
|
|
9
|
+
import warnings
|
|
10
|
+
from argparse import Namespace
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Default truncation length applied to string values when building
# non-detailed result dataframes (see make_dataframe_from_benchmark_data).
_DEFAULT_STRING_LIMIT = 2000
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BenchmarkError(RuntimeError):
    """Raised when a benchmark run completes without producing any metric."""

    pass
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _clean_string(s: str) -> str:
    """Keep only printable ASCII characters, additionally dropping commas.

    Used to sanitize strings (stderr, stdout, ...) before they are stored
    into CSV-like benchmark outputs.
    """
    keep = []
    for ch in s:
        code = ord(ch)
        if 32 <= code < 127 and ch != ",":
            keep.append(ch)
    return "".join(keep)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_processor_name():
    """Returns the processor name.

    On Windows and macOS, :func:`platform.processor` is sufficient.
    On Linux, the ``model name`` entry of ``/proc/cpuinfo`` is returned.

    :raises AssertionError: on platforms with no implemented strategy,
        or on Linux when no ``model name`` line is found
    """
    if platform.system() in ("Windows", "Darwin"):
        return platform.processor()
    if platform.system() == "Linux":
        # Read the pseudo-file directly instead of spawning
        # ``cat /proc/cpuinfo`` through a shell (previous implementation
        # used subprocess with shell=True, which is slower and unnecessary).
        with open("/proc/cpuinfo", "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                if "model name" in line:
                    return re.sub(".*model name.*:", "", line, count=1).strip()
    raise AssertionError("get_process_name not implemented on this platform.")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_machine(
    capability_as_str: bool = True,
) -> Dict[str, Union[str, int, float, Tuple[int, int]]]:
    """Returns the machine specifications.

    :param capability_as_str: expose the CUDA compute capability as a
        ``"major.minor"`` string instead of a tuple
    :return: dictionary describing python, CPU and, when torch with CUDA
        is available, GPU characteristics
    """
    arch = platform.architecture()
    if isinstance(arch, (list, tuple)):
        arch_str = "/".join(str(a) for a in arch)
    else:
        arch_str = str(arch)
    config: Dict[str, Union[str, int, float, Tuple[int, int]]] = {
        "machine": str(platform.machine()),
        "architecture": arch_str,
        "processor": str(platform.processor()),
        "version": str(sys.version).split()[0],
        "cpu": int(multiprocessing.cpu_count()),
        "executable": str(sys.executable),
        "processor_name": get_processor_name(),
        "system": str(platform.system()),
    }
    try:
        import torch.cuda
    except ImportError:
        # torch is optional: without it, no GPU information is reported.
        return config

    config["has_cuda"] = bool(torch.cuda.device_count() > 0)
    if config["has_cuda"]:
        capability = torch.cuda.get_device_capability(0)
        if capability_as_str:
            capability = ".".join(map(str, capability))
        config["capability"] = capability
        config["device_name"] = str(torch.cuda.get_device_name(0))
    return config
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _cmd_line(script_name: str, **kwargs: Dict[str, Union[str, int, float]]) -> List[str]:
    """Builds the command line ``python -m <script_name> --k v ...``.

    Keyword arguments whose value is ``None`` are skipped.
    """
    args = [sys.executable, "-m", script_name]
    for key, value in kwargs.items():
        if value is not None:
            args.extend((f"--{key}", str(value)))
    return args
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _extract_metrics(text: str) -> Dict[str, Union[str, int, float]]:
    """Extracts metrics from text following the pattern ``:<metric>,<value>;``.

    Values are converted into ``int`` or ``float`` whenever possible,
    otherwise kept as strings. Long values (>= 500 characters) are skipped
    with a warning unless the key is a known long-value or error key.
    """
    pattern = re.compile(":(.*?),(.*.?);")
    found = pattern.findall(text)
    if not found:
        return {}
    raw = dict(found)
    # Keys allowed to carry long values without triggering a warning.
    long_value_keys = {
        "onnx_output_names",
        "onnx_input_names",
        "filename",
        "time_latency_t_detail",
        "time_latency_t_qu",
        "time_latency_t_qu_10t",
        "time_latency_eager_t_detail",
        "time_latency_eager_t_qu",
        "time_latency_eager_t_qu_10t",
    }
    metrics: Dict[str, Any] = {}
    for key, value in raw.items():
        assert isinstance(key, str) and isinstance(
            value, str
        ), f"Unexpected type for k={key!r}, types={type(key)}, {type(value)})."
        assert "\n" not in value, f"Unexpected multi-line value for k={key!r}, value is\n{value}"
        accepted = "err" in key.lower() or key in long_value_keys or len(value) < 500
        if not accepted:
            warnings.warn(
                f"Unexpected long value for model={raw.get('model_name', '?')}, "
                f"k={key!r}, value has length {len(value)} is\n{value}",
                stacklevel=2,
            )
            continue
        # Try numeric conversions in order: int first, then float.
        for convert in (int, float):
            try:
                metrics[key] = convert(value)
                break
            except ValueError:
                continue
        else:
            metrics[key] = value
    return metrics
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _make_prefix(script_name: str, index: int) -> str:
    """Builds a dump-file prefix from a script name and a loop index."""
    stem, _ext = os.path.splitext(script_name)
    return f"{stem}_dort_c{index}_"
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _cmd_string(s: str) -> str:
    """Escapes double quotes in *s*; an empty string is rendered as two quotes."""
    return s.replace('"', '\\"') if s else '""'
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def run_benchmark(
    script_name: str,
    configs: List[Dict[str, Union[str, int, float]]],
    verbose: int = 0,
    stop_if_exception: bool = True,
    dump: bool = False,
    temp_output_data: Optional[str] = None,
    dump_std: Optional[str] = None,
    start: int = 0,
    summary: Optional[Callable] = None,
    timeout: int = 600,
    missing: Optional[Dict[str, Union[str, Callable]]] = None,
) -> List[Dict[str, Union[str, int, float]]]:
    """
    Runs a script multiple times and extract information from the output
    following the pattern ``:<metric>,<value>;``.

    :param script_name: python script to run
    :param configs: list of execution to do
    :param stop_if_exception: stop if one experiment failed, otherwise continue
    :param verbose: use tqdm to follow the progress
    :param dump: dump onnx file, sets variable ONNXRT_DUMP_PATH
    :param temp_output_data: to save the data after every run to avoid losing data
    :param dump_std: dumps stdout and stderr in this folder
    :param start: start at this iteration
    :param summary: function to call on the temporary data and the final data
    :param timeout: timeout for the subprocesses
    :param missing: populate with this missing value if not found
    :return: values
    """
    assert (
        temp_output_data is None or "temp" in temp_output_data
    ), f"Unexpected value for {temp_output_data!r}"
    assert configs, f"No configuration was given (script_name={script_name!r})"
    if verbose:
        from tqdm import tqdm

        loop = tqdm(configs)
    else:
        loop = configs

    data: List[Dict[str, Union[str, int, float]]] = []
    for iter_loop, config in enumerate(loop):
        if iter_loop < start:
            continue
        if hasattr(loop, "set_description"):
            # Show the configuration name (or model) in the progress bar.
            for c in ["name", "model"]:
                if c not in config:
                    continue
                loop.set_description(f"[{config[c]}]")
                break
        cmd = _cmd_line(script_name, **config)
        begin = time.perf_counter()

        # ONNXRT_DUMP_PATH triggers onnx dumps in the child process.
        if dump:
            os.environ["ONNXRT_DUMP_PATH"] = _make_prefix(script_name, iter_loop)
        else:
            os.environ["ONNXRT_DUMP_PATH"] = ""
        if verbose > 3:
            print(f"[run_benchmark] cmd={cmd if isinstance(cmd, str) else ' '.join(cmd)}")
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        timeout_error = ""
        try:
            res = p.communicate(timeout=timeout)
        except subprocess.TimeoutExpired as e:
            # see https://docs.python.org/3/library/subprocess.html#subprocess.Popen.communicate
            timeout_error = str(e)
            if verbose:
                print(f"[run_benchmark] timeout {e} for cmd={cmd}")
            p.terminate()
            try:
                # Use communicate with a timeout to prevent hanging
                res = p.communicate(timeout=10)
            except subprocess.TimeoutExpired:
                # Force kill if terminate doesn't work
                if verbose:
                    print(f"[run_benchmark] force killing cmd={cmd}")
                p.kill()
                res = p.communicate()
        out, err = res
        sout = out.decode("utf-8", errors="ignore")
        serr = err.decode("utf-8", errors="ignore")

        if dump_std:
            if not os.path.exists(dump_std):
                os.makedirs(dump_std)
            # NOTE(review): split(".")[-1] keeps the extension, not the stem;
            # this looks like it was meant to be the stem -- confirm before changing.
            root = os.path.split(script_name)[-1].split(".")[-1]
            filename = os.path.join(dump_std, f"{root}.{iter_loop}")
            # Bug fix: ``filename`` was previously computed then ignored and
            # both paths were hard-coded to "(unknown).stdout"/"(unknown).stderr",
            # writing the dumps into the current directory under a fixed name.
            filename_out = f"{filename}.stdout"
            filename_err = f"{filename}.stderr"
            if out.strip(b"\n \r\t"):
                with open(filename_out, "w") as f:
                    f.write(sout)
            if err.strip(b"\n \r\t"):
                with open(filename_err, "w") as f:
                    f.write(serr)
        else:
            filename_out, filename_err = None, None

        if "ONNXRuntimeError" in serr or "ONNXRuntimeError" in sout:
            if stop_if_exception:
                raise RuntimeError(
                    f"Unable to continue with config {config} due to the "
                    f"following error\n{serr}"
                    f"\n----OUTPUT--\n{sout}"
                )

        metrics = _extract_metrics(sout)
        if len(metrics) == 0:
            if stop_if_exception:
                raise BenchmarkError(
                    f"Unable (2) to continue with config {config}, no metric was "
                    f"collected.\n--ERROR--\n{serr}\n--OUTPUT--\n{sout}"
                )
            else:
                metrics = {}
        metrics.update(config)
        if filename_out and os.path.exists(filename_out):
            if "model_name" in metrics:
                assert isinstance(
                    metrics["model_name"], str
                ), f"unexpected type {type(metrics['model_name'])}"
                # Tag the dump file with the model name.
                new_name = f"{filename_out}.{_clean_string(metrics['model_name'])}"
                os.rename(filename_out, new_name)
                filename_out = new_name
            metrics["file.stdout"] = filename_out
        if filename_err and os.path.exists(filename_err):
            if "model_name" in metrics:
                assert isinstance(
                    metrics["model_name"], str
                ), f"unexpected type {type(metrics['model_name'])}"
                new_name = f"{filename_err}.{_clean_string(metrics['model_name'])}"
                os.rename(filename_err, new_name)
                filename_err = new_name
            metrics["file.stderr"] = filename_err
        metrics["DATE"] = f"{datetime.now():%Y-%m-%d}"
        metrics["ITER"] = str(iter_loop)
        metrics["TIME_ITER"] = time.perf_counter() - begin
        metrics["ERROR"] = _clean_string(serr)
        metrics["ERR_stdout"] = _clean_string(sout)
        if metrics["ERROR"]:
            metrics["ERR_std"] = metrics["ERROR"]
        assert isinstance(
            metrics["ERROR"], str
        ), f"unexpected type {type(metrics['ERROR'])}"
        # Flag well-known failure signatures so they are easy to aggregate.
        if "CUDA out of memory" in metrics["ERROR"]:
            metrics["ERR_CUDA_OOM"] = 1
        if "Cannot access gated repo for url" in metrics["ERROR"]:
            metrics["ERR_ACCESS"] = 1
        if timeout_error:
            metrics["ERR_timeout"] = _clean_string(timeout_error)
        metrics["OUTPUT"] = _clean_string(sout)
        for k, v in config.items():
            metrics[f"config_{k}"] = str(v).replace("\n", " ")
        if missing:
            update_missing = {}
            for k, v in missing.items():
                if k not in metrics:
                    if isinstance(v, str):
                        update_missing[k] = v
                        continue
                    if callable(v):
                        update_missing.update(v(missing, config))
                        continue
                    raise AssertionError(
                        f"Unable to interpret {type(v)} for k={k!r}, config={config!r}"
                    )
            if update_missing:
                metrics.update(update_missing)
        metrics["CMD"] = f"[{' '.join(map(_cmd_string, cmd))}]"
        data.append(metrics)
        if verbose > 5:
            print(f"--------------- ITER={iter_loop} in {metrics['TIME_ITER']}")
            print("--------------- ERROR")
            print(serr)
        if verbose >= 10:
            print("--------------- OUTPUT")
            print(sout)
        if temp_output_data:
            # Save after every iteration so a crash does not lose everything.
            df = make_dataframe_from_benchmark_data(data, detailed=False)
            if verbose > 2:
                print(f"Prints out the results into file {temp_output_data!r}")
            fold, _ = os.path.split(temp_output_data)
            # fold could be empty string
            if fold and not os.path.exists(fold):
                os.makedirs(fold)
            df.to_csv(temp_output_data, index=False, errors="ignore")
            try:
                df.to_excel(temp_output_data + ".xlsx", index=False)
            except Exception:
                # Best effort: failing to write the Excel copy skips the
                # partial summary but keeps benchmarking.
                continue
            if summary:
                fn = f"{temp_output_data}.summary-partial.xlsx"
                if verbose > 2:
                    print(f"Prints out the results into file {fn!r}")
                summary(df, excel_output=fn, exc=False)

    return data
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def multi_run(kwargs: Namespace) -> bool:
    """Checks if multiple values were sent for one argument.

    A string value containing a comma means several values were given.
    """
    for value in kwargs.__dict__.values():
        if isinstance(value, str) and "," in value:
            return True
    return False
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def make_configs(
    kwargs: Union[Namespace, Dict[str, Any]],
    drop: Optional[Set[str]] = None,
    replace: Optional[Dict[str, str]] = None,
    last: Optional[List[str]] = None,
    filter_function: Optional[Callable[[Dict[str, Union[str, int, float]]], bool]] = None,
) -> List[Dict[str, Union[str, int, float]]]:
    """
    Creates all the configurations based on the command line arguments.

    :param kwargs: parameters the command line,
        every value having a comma means multiple values,
        it multiplies the number of configurations to try by the number of comma
        separated values
    :param drop: keys to drop in kwargs if specified
    :param replace: values to replace for a particular key
    :param last: to change the order of the loop created the configuration,
        if ``last == ["part"]`` and ``kwargs[part] == "0,1"``,
        then configuration where ``part==0`` is always followed by a configuration
        having ``part==1``
    :param filter_function: function taking a configuration and returning True
        if it is must be kept
    :return: list of configurations
    """
    kwargs_ = kwargs if isinstance(kwargs, dict) else kwargs.__dict__
    args = []
    slast = set(last) if last else set()
    for k, v in kwargs_.items():
        if (drop and k in drop) or k in slast:
            continue
        if replace and k in replace:
            v = replace[k]
        if isinstance(v, str):
            # A comma separated string expands into multiple candidate values.
            args.append([(k, s) for s in v.split(",")])
        else:
            args.append([(k, v)])
    if last:
        # Keys listed in ``last`` are appended at the end so they vary fastest
        # in the cartesian product below.
        for k in last:
            if k not in kwargs_:
                continue
            # Bug fix: index the normalized dict view ``kwargs_``;
            # indexing ``kwargs`` directly raised TypeError for a Namespace.
            v = kwargs_[k]
            if isinstance(v, str):
                args.append([(k, s) for s in v.split(",")])
            else:
                args.append([(k, v)])

    configs = list(itertools.product(*args))
    confs: List[Dict[str, Union[int, float, str]]] = [dict(c) for c in configs]
    if filter_function:
        confs = [c for c in confs if filter_function(c)]
    return confs
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def make_dataframe_from_benchmark_data(
    data: List[Dict], detailed: bool = True, string_limit: int = _DEFAULT_STRING_LIMIT
) -> Any:
    """
    Creates a dataframe from the received data.

    :param data: list of dictionaries for every run
    :param detailed: keep raw values; when False, multi-line and long
        string values are flattened and truncated
    :param string_limit: truncate the strings
    :return: dataframe
    """
    import pandas

    if detailed:
        return pandas.DataFrame(data)

    flattened = []
    for row in data:
        clean_row = {}
        for key, value in row.items():
            if isinstance(value, str):
                # Flatten newlines and commas so the CSV stays one row per run.
                value = value.replace("\n", " -- ").replace(",", "_")
                if len(value) > string_limit:
                    value = value[:string_limit] + "..."
            clean_row[key] = value
        flattened.append(clean_row)
    df = pandas.DataFrame(flattened)
    sorted_columns = sorted(df.columns)
    if "_index" in sorted_columns:
        # Put "_index" first and the verbose columns (ERROR/OUTPUT/CMD) last.
        present = {"_index", "CMD", "OUTPUT", "ERROR"} & set(df.columns)
        ordered = []
        if "_index" in present:
            ordered.append("_index")
        ordered.extend(c for c in sorted_columns if c not in present)
        ordered.extend(c for c in ["ERROR", "OUTPUT", "CMD"] if c in present)
        sorted_columns = ordered
    return df[sorted_columns].copy()
|