onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0

onnx_diagnostic/helpers/__init__.py
@@ -0,0 +1 @@
+from .helper import flatten_object, max_diff, string_diff, string_sig, string_type

onnx_diagnostic/helpers/_log_helper.py
@@ -0,0 +1,463 @@
+import datetime
+import glob
+import io
+import os
+import zipfile
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+import numpy as np
+import pandas
+
+BUCKET_SCALES_VALUES = np.array(
+    [-np.inf, -20, -10, -5, -2, 0, 2, 5, 10, 20, 100, 200, 300, 400, np.inf], dtype=float
+)
+
+
+BUCKET_SCALES = BUCKET_SCALES_VALUES / 100 + 1
+
+
+def mann_kendall(series: Sequence[float], threshold: float = 0.5):
+    """
+    Computes the test of Mann-Kendall.
+
+    :param series: series
+    :param threshold: 1.96 is the usual value, 0.5 means a short timeseries
+        ``(0, 1, 2, 3, 4)`` has a significant trend
+    :return: trend (-1, 0, +1), test value
+
+    .. math::
+
+        S = \\sum_{i=1}^{n}\\sum_{j=i+1}^{n} sign(x_j - x_i)
+
+    where the function *sign* is:
+
+    .. math::
+
+        sign(x) = \\left\\{ \\begin{array}{l} -1 if x < 0 \\\\ 0 if x = 0 \\\\ +1 otherwise
+        \\end{array} \\right.
+
+    And:
+
+    .. math::
+
+        Var(S) = \\frac{n(n-1)(2n+5) - \\sum_t t(t-1)(2t+5)}{18}
+    """
+    aseries = np.asarray(series)
+    stat = 0
+    n = len(aseries)
+    var = n * (n - 1) * (2 * n + 5)
+    for i in range(n - 1):
+        stat += np.sign(aseries[i + 1 :] - aseries[i]).sum()
+    var = var**0.5
+    test = (stat + (1 if stat < 0 else (0 if stat == 0 else -1))) / var
+    trend = np.sign(test) if np.abs(test) > threshold else 0
+    return trend, test
+
+
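
As a quick sanity check of `mann_kendall` above (a sketch; it assumes the wheel is installed and uses illustrative series):

from onnx_diagnostic.helpers._log_helper import mann_kendall

# Every pairwise difference of an increasing series is positive, so S = 10
# for n = 5 and the normalized statistic exceeds the 0.5 threshold.
print(mann_kendall([0, 1, 2, 3, 4]))  # (1.0, ~0.52): significant upward trend
# A constant series gives S = 0: no trend.
print(mann_kendall([1, 1, 1, 1, 1]))  # (0, 0.0)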
+def breaking_last_point(series: Sequence[float], threshold: float = 1.2):
+    """
+    Assuming a timeseries is constant, checks that the last value
+    is not an outlier.
+
+    :param series: series
+    :param threshold: number of standard deviations beyond which the last value is an outlier
+    :return: significant change (-1, 0, +1), test value
+    """
+    signal = np.asarray(series)
+    if not np.issubdtype(signal.dtype, np.number):
+        return 0, np.nan
+    assert len(signal.shape) == 1, f"Unexpected signal shape={signal.shape}, signal={signal}"
+    if signal.shape[0] <= 2:
+        return 0, 0
+
+    has_value = ~(np.isnan(signal).all()) and ~(np.isinf(signal).all())
+    if np.isnan(signal[-1]) or np.isinf(signal[-1]):
+        return (-1, np.inf) if has_value else (0, 0)
+
+    try:
+        m = np.mean(signal[:-1])
+    except (TypeError, ValueError):
+        # Not a numerical type
+        return 0, np.nan
+
+    if np.isnan(m) or np.isinf(m):
+        return (1, np.inf) if np.isinf(signal[-2]) or np.isnan(signal[-2]) else (0, 0)
+    v = np.std(signal[:-1])
+    if v == 0:
+        test = signal[-1] - m
+        assert not np.isnan(
+            test
+        ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+        trend = np.sign(test)
+        return trend, trend
+    test = (signal[-1] - m) / v
+    assert not np.isnan(
+        test
+    ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+    trend = np.sign(test) if np.abs(test) > threshold else 0
+    return trend, test
+
+
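
`breaking_last_point` can be exercised the same way (a sketch with made-up measurements):

from onnx_diagnostic.helpers._log_helper import breaking_last_point

# The last value sits far above the mean of the history relative to its
# standard deviation, well past the default threshold of 1.2.
print(breaking_last_point([10.0, 10.1, 9.9, 10.0, 13.0]))   # (1.0, ~42)
# A last value consistent with the history reports no break.
print(breaking_last_point([10.0, 10.1, 9.9, 10.0, 10.05]))  # (0, ~0.7)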
+def filter_data(
+    df: pandas.DataFrame,
+    filter_in: Optional[str] = None,
+    filter_out: Optional[str] = None,
+    verbose: int = 0,
+) -> pandas.DataFrame:
+    """
+    Arguments `filter_in` and `filter_out` follow the syntax
+    ``<column1>:<fmt1>//<column2>:<fmt2>``.
+
+    The format is the following:
+
+    * a value or a set of values separated by ``;``
+    """
+    if not filter_in and not filter_out:
+        return df
+
+    def _f(fmt):
+        cond = {}
+        if isinstance(fmt, str):
+            cols = fmt.split("//")
+            for c in cols:
+                assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}, cols={cols!r}"
+                spl = c.split(":")
+                assert (
+                    len(spl) == 2
+                ), f"Unexpected value {c!r} in fmt={fmt!r}, spl={spl}, cols={cols}"
+                name, fil = spl
+                cond[name] = set(fil.split(";"))
+        return cond
+
+    if filter_in:
+        cond = _f(filter_in)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_in!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter in column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[df[k].astype(str).isin(v)]
+
+    if filter_out:
+        cond = _f(filter_out)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_out!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter out column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[~df[k].astype(str).isin(v)]
+    return df
+
+
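
The filter syntax documented above translates to calls like this (a sketch; the column names and values are made up):

import pandas
from onnx_diagnostic.helpers._log_helper import filter_data

df = pandas.DataFrame(
    {"model": ["llama", "phi", "gpt2"], "exporter": ["onnx", "onnx", "script"]}
)
# Keep rows whose model is llama or phi, then drop rows whose exporter is script.
kept = filter_data(df, filter_in="model:llama;phi", filter_out="exporter:script")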
+def enumerate_csv_files(
+    data: Union[
+        pandas.DataFrame, List[Union[str, Tuple[str, str]]], str, Tuple[str, str, str, str]
+    ],
+    verbose: int = 0,
+    filtering: Optional[Callable[[str], bool]] = None,
+) -> Iterator[Union[pandas.DataFrame, str, Tuple[str, str, str, str]]]:
+    """
+    Enumerates files considered for the aggregation.
+    Only csv files are considered.
+    If a zip file is given, the function digs into the zip files and
+    loops over csv candidates.
+
+    :param data: dataframe with the raw data or a file or list of files
+    :param verbose: verbosity
+    :param filtering: function to filter in or out files in zip files,
+        must return true to keep the file, false to skip it.
+    :return: a generator yielding tuples with the filename, date, full path and zip file
+
+    data can contain:
+    * a dataframe
+    * a string for a filename, zip or csv
+    * a list of strings
+    * a tuple
+    """
+    if not isinstance(data, list):
+        data = [data]
+    for itn, filename in enumerate(data):
+        if isinstance(filename, pandas.DataFrame):
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}] is a dataframe")
+            yield filename
+            continue
+
+        if isinstance(filename, tuple):
+            # A file in a zipfile
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}] is {filename!r}")
+            yield filename
+            continue
+
+        if os.path.exists(filename):
+            ext = os.path.splitext(filename)[-1]
+            if ext == ".csv":
+                # We check the first line is ok.
+                if verbose:
+                    print(f"[enumerate_csv_files] data[{itn}] is a csv file: {filename!r}")
+                dt = datetime.datetime.fromtimestamp(os.stat(filename).st_mtime)
+                du = dt.strftime("%Y-%m-%d %H:%M:%S")
+                yield (os.path.split(filename)[-1], du, filename, "")
+                continue
+
+            if ext == ".zip":
+                if verbose:
+                    print(f"[enumerate_csv_files] data[{itn}] is a zip file: {filename!r}")
+                zf = zipfile.ZipFile(filename, "r")
+                for ii, info in enumerate(zf.infolist()):
+                    name = info.filename
+                    if filtering is None:
+                        ext = os.path.splitext(name)[-1]
+                        if ext != ".csv":
+                            continue
+                    elif not filtering(name):
+                        continue
+                    if verbose:
+                        print(
+                            f"[enumerate_csv_files] data[{itn}][{ii}] is a csv file: {name!r}"
+                        )
+                    with zf.open(name) as zzf:
+                        first_line = zzf.readline()
+                    if b"," not in first_line:
+                        continue
+                    yield (
+                        os.path.split(name)[-1],
+                        "%04d-%02d-%02d %02d:%02d:%02d" % info.date_time,
+                        name,
+                        filename,
+                    )
+                zf.close()
+                continue
+
+            raise AssertionError(f"Unexpected format {filename!r}, cannot read it.")
+
+        # filename is a pattern.
+        found = glob.glob(filename)
+        if verbose and not found:
+            print(f"[enumerate_csv_files] unable to find file in {filename!r}")
+        for ii, f in enumerate(found):
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}][{ii}] {f!r} from {filename!r}")
+            yield from enumerate_csv_files(f, verbose=verbose, filtering=filtering)
+
+
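
A sketch of how `enumerate_csv_files` is typically driven; `raw/*.csv` and `benchmark.zip` are hypothetical paths:

from onnx_diagnostic.helpers._log_helper import enumerate_csv_files

# Expands the glob pattern and digs into the zip file; each yielded item is a
# dataframe or a tuple (filename, date, full path, zip file).
for item in enumerate_csv_files(["raw/*.csv", "benchmark.zip"], verbose=1):
    print(item)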
+def open_dataframe(
+    data: Union[str, Tuple[str, str, str, str], pandas.DataFrame],
+) -> pandas.DataFrame:
+    """
+    Opens a filename yielded by function
+    :func:`onnx_diagnostic.helpers._log_helper.enumerate_csv_files`.
+
+    :param data: a dataframe, a filename, or a tuple indicating the file is coming
+        from a zip file
+    :return: a dataframe
+    """
+    if isinstance(data, pandas.DataFrame):
+        return data
+    if isinstance(data, str):
+        df = pandas.read_csv(data, low_memory=False)
+        df["RAWFILENAME"] = data
+        return df
+    if isinstance(data, tuple):
+        if not data[-1]:
+            df = pandas.read_csv(data[2], low_memory=False)
+            df["RAWFILENAME"] = data[2]
+            return df
+        zf = zipfile.ZipFile(data[-1])
+        with zf.open(data[2]) as f:
+            df = pandas.read_csv(f, low_memory=False)
+            df["RAWFILENAME"] = f"{data[-1]}/{data[2]}"
+        zf.close()
+        return df
+
+    raise ValueError(f"Unexpected value for data: {data!r}")
+
+
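
`open_dataframe` consumes those items directly (same hypothetical zip file as above):

from onnx_diagnostic.helpers._log_helper import enumerate_csv_files, open_dataframe

# Loads every csv found in the archive; the added RAWFILENAME column records
# where each dataframe came from.
dfs = [open_dataframe(item) for item in enumerate_csv_files("benchmark.zip")]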
+def align_dataframe_with(
+    df: pandas.DataFrame, baseline: pandas.DataFrame, fill_value: float = 0
+) -> Optional[pandas.DataFrame]:
+    """
+    Modifies the first dataframe *df* to get the exact same number of columns and rows
+    as *baseline*. They must share the same levels on both axes. Empty cells are filled
+    with *fill_value*. We only keep the numerical columns. The function returns None
+    if the output is empty.
+    """
+    df = df.select_dtypes(include="number")
+    if df.shape[1] == 0:
+        return None
+    bool_cols = list(df.select_dtypes(include="bool").columns)
+    if bool_cols:
+        df[bool_cols] = df[bool_cols].astype(int)
+    assert (
+        df.columns.names == baseline.columns.names or df.index.names == baseline.index.names
+    ), (
+        f"Levels mismatch, expected index.names={baseline.index.names}, "
+        f"expected columns.names={baseline.columns.names}, "
+        f"got index.names={df.index.names}, "
+        f"got columns.names={df.columns.names}"
+    )
+    dtypes = set(df[c].dtype for c in df.columns)
+    assert all(np.issubdtype(dt, np.number) for dt in dtypes), (
+        f"All columns in the first dataframe are expected to share "
+        f"the same type or be at least numerical but got {dtypes}\n{df}"
+    )
+    common_index = df.index.intersection(baseline.index)
+    cp = pandas.DataFrame(float(fill_value), index=baseline.index, columns=baseline.columns)
+    for c in df.columns:
+        if c not in cp.columns or not np.issubdtype(df[c].dtype, np.number):
+            continue
+        cp.loc[common_index, c] = df.loc[common_index, c].astype(cp[c].dtype)
+    return cp
+
+
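
A minimal sketch of the alignment behaviour (made-up frames):

import pandas
from onnx_diagnostic.helpers._log_helper import align_dataframe_with

baseline = pandas.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]}, index=["x", "y"])
df = pandas.DataFrame({"a": [10.0]}, index=["x"])
# df is reindexed on baseline's axes; the missing row and column are filled
# with fill_value (0.0 by default).
print(align_dataframe_with(df, baseline))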
+def apply_excel_style(
+    filename_or_writer: Any,
+    f_highlights: Optional[  # type: ignore[name-defined]
+        Dict[str, Callable[[Any], "CubeViewDef.HighLightKind"]]  # noqa: F821
+    ] = None,
+    time_mask_view: Optional[Dict[str, pandas.DataFrame]] = None,
+    verbose: int = 0,
+):
+    """
+    Applies styles on all sheets in a file unless the sheet is too big.
+
+    :param filename_or_writer: filename or ExcelWriter, modified in place
+    :param f_highlights: color functions to apply, one per sheet
+    :param time_mask_view: if specified, it contains dataframes with the same shape
+        and values in {-1, 0, +1} which indicate if a value is unexpectedly lower (-1)
+        or higher (+1), it changes the color of the background then.
+    :param verbose: shows a progress loop if positive
+    """
+    from openpyxl import load_workbook
+    from openpyxl.styles import Alignment
+    from openpyxl.utils import get_column_letter
+    from openpyxl.styles import Font, PatternFill
+    from .log_helper import CubeViewDef
+
+    if isinstance(filename_or_writer, str):
+        workbook = load_workbook(filename_or_writer)
+        save = True
+    else:
+        workbook = filename_or_writer.book
+        save = False
+
+    mask_low = PatternFill(fgColor="AAAAF0", fill_type="solid")
+    mask_high = PatternFill(fgColor="F0AAAA", fill_type="solid")
+
+    left = Alignment(horizontal="left")
+    left_shrink = Alignment(horizontal="left", shrink_to_fit=True)
+    right = Alignment(horizontal="right")
+    font_colors = {
+        CubeViewDef.HighLightKind.GREEN: Font(color="00AA00"),
+        CubeViewDef.HighLightKind.RED: Font(color="FF0000"),
+    }
+    if verbose:
+        from tqdm import tqdm
+
+        sheet_names = tqdm(list(workbook.sheetnames))
+    else:
+        sheet_names = workbook.sheetnames
+    for name in sheet_names:
+        if time_mask_view and name in time_mask_view:
+            mask = time_mask_view[name]
+            with pandas.ExcelWriter(io.BytesIO(), engine="openpyxl") as mask_writer:
+                mask.to_excel(mask_writer, sheet_name=name)
+                sheet_mask = mask_writer.sheets[name]
+        else:
+            sheet_mask = None
+
+        f_highlight = f_highlights.get(name, None) if f_highlights else None
+        sheet = workbook[name]
+        n_rows = sheet.max_row
+        n_cols = sheet.max_column
+        if n_rows * n_cols > 2**16 or n_rows > 2**13:
+            # Too big.
+            continue
+        co: Dict[int, int] = {}
+        sizes: Dict[int, int] = {}
+        cols = set()
+        for i in range(1, n_rows + 1):
+            for j, cell in enumerate(sheet[i]):
+                if j > n_cols:
+                    break
+                cols.add(cell.column)
+                if isinstance(cell.value, float):
+                    co[j] = co.get(j, 0) + 1
+                elif isinstance(cell.value, str):
+                    sizes[cell.column] = max(sizes.get(cell.column, 0), len(cell.value))
+
+        for k, v in sizes.items():
+            c = get_column_letter(k)
+            sheet.column_dimensions[c].width = min(max(8, v), 30)
+        for k in cols:
+            if k not in sizes:
+                c = get_column_letter(k)
+                sheet.column_dimensions[c].width = 15
+
+        for i in range(1, n_rows + 1):
+            for j, cell in enumerate(sheet[i]):
+                if j > n_cols:
+                    break
+                if isinstance(cell.value, pandas.Timestamp):
+                    cell.alignment = right
+                    dt = cell.value.to_pydatetime()
+                    cell.value = dt
+                    cell.number_format = (
+                        "YYYY-MM-DD"
+                        if (
+                            dt.hour == 0
+                            and dt.minute == 0
+                            and dt.second == 0
+                            and dt.microsecond == 0
+                        )
+                        else "YYYY-MM-DD 00:00:00"
+                    )
+                elif isinstance(cell.value, (float, int)):
+                    cell.alignment = right
+                    x = abs(cell.value)
+                    if int(x) == x:
+                        cell.number_format = "0"
+                    elif x > 5000:
+                        cell.number_format = "# ##0"
+                    elif x >= 500:
+                        cell.number_format = "0.0"
+                    elif x >= 50:
+                        cell.number_format = "0.00"
+                    elif x >= 5:
+                        cell.number_format = "0.000"
+                    elif x > 0.5:
+                        cell.number_format = "0.0000"
+                    elif x > 0.005:
+                        cell.number_format = "0.00000"
+                    else:
+                        cell.number_format = "0.000E+00"
+                    if f_highlight:
+                        h = f_highlight(cell.value)
+                        if h in font_colors:
+                            cell.font = font_colors[h]
+                elif isinstance(cell.value, str) and len(cell.value) > 70:
+                    cell.alignment = left_shrink
+                else:
+                    cell.alignment = left
+                    if f_highlight:
+                        h = f_highlight(cell.value)
+                        if h in font_colors:
+                            cell.font = font_colors[h]
+
+        if sheet_mask is not None:
+            for i in range(1, n_rows + 1):
+                for j, (cell, cell_mask) in enumerate(zip(sheet[i], sheet_mask[i])):
+                    if j > n_cols:
+                        break
+                    if cell_mask.value not in (1, -1):
+                        continue
+                    cell.fill = mask_low if cell_mask.value < 0 else mask_high
+
+    if save:
+        workbook.save(filename_or_writer)
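
`apply_excel_style` can be smoke-tested on any small spreadsheet (a sketch; `report.xlsx` is a hypothetical file and openpyxl must be installed):

import pandas
from onnx_diagnostic.helpers._log_helper import apply_excel_style

df = pandas.DataFrame({"metric": ["latency", "memory"], "value": [0.123456, 12345.6]})
df.to_excel("report.xlsx", sheet_name="summary", index=False)
# Rewrites the file in place with adjusted column widths and number formats.
apply_excel_style("report.xlsx")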

onnx_diagnostic/helpers/args_helper.py
@@ -0,0 +1,132 @@
+import subprocess
+from argparse import ArgumentParser, Namespace
+from typing import Dict, List, Optional, Tuple, Union
+
+
+def check_cuda_availability():
+    """
+    Checks if CUDA is available. Relies on torch when it is installed,
+    otherwise falls back to calling `nvidia-smi`.
+    """
+    try:
+        import torch
+
+        return torch.cuda.device_count() > 0
+    except ImportError:
+        pass
+    try:
+        result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
+        return result.returncode == 0
+    except FileNotFoundError:
+        return False
+
+
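
Calling it is a one-liner (a sketch):

from onnx_diagnostic.helpers.args_helper import check_cuda_availability

print(check_cuda_availability())  # True when torch or nvidia-smi sees a GPU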
+def get_parsed_args(
+    name: str,
+    scenarios: Optional[Dict[str, str]] = None,
+    description: Optional[str] = None,
+    epilog: Optional[str] = None,
+    number: int = 10,
+    repeat: int = 10,
+    warmup: int = 5,
+    sleep: float = 0.1,
+    tries: int = 2,
+    expose: Optional[str] = None,
+    new_args: Optional[List[str]] = None,
+    **kwargs: Dict[str, Tuple[Union[int, str, float], str]],
+) -> Namespace:
+    """
+    Returns parsed arguments for examples in this package.
+
+    :param name: script name
+    :param scenarios: dictionary of available scenarios
+    :param description: parser description
+    :param epilog: text at the end of the parser
+    :param number: default value for number parameter
+    :param repeat: default value for repeat parameter
+    :param warmup: default value for warmup parameter
+    :param sleep: default value for sleep parameter
+    :param tries: default value for tries parameter
+    :param expose: comma separated list of the default parameters to expose,
+        if None or empty, all of them are exposed
+    :param new_args: args to consider or None to take `sys.argv`
+    :param kwargs: additional parameters,
+        example: `n_trees=(10, "number of trees to train")`
+    :return: parsed arguments
+    """
+    if description is None:
+        description = f"Available options for {name}.py."
+    if epilog is None:
+        epilog = ""
+    parser = ArgumentParser(prog=name, description=description, epilog=epilog)
+    # initialized even when expose is None so the checks below cannot raise a NameError
+    to_publish = set(expose.split(",")) if expose else set()
+    if scenarios is not None:
+        rows = ", ".join(f"{k}: {v}" for k, v in scenarios.items())
+        parser.add_argument("-s", "--scenario", help=f"Available scenarios: {rows}.")
+    if not to_publish or "number" in to_publish:
+        parser.add_argument(
+            "-n",
+            "--number",
+            help=f"number of executions to measure, default is {number}",
+            type=int,
+            default=number,
+        )
+    if not to_publish or "repeat" in to_publish:
+        parser.add_argument(
+            "-r",
+            "--repeat",
+            help=f"number of times to repeat the measure, default is {repeat}",
+            type=int,
+            default=repeat,
+        )
+    if not to_publish or "warmup" in to_publish:
+        parser.add_argument(
+            "-w",
+            "--warmup",
+            help=f"number of warmup runs before measuring, default is {warmup}",
+            type=int,
+            default=warmup,
+        )
+    if not to_publish or "sleep" in to_publish:
+        parser.add_argument(
+            "-S",
+            "--sleep",
+            help=f"sleeping time between two configurations, default is {sleep}",
+            type=float,
+            default=sleep,
+        )
+    if not to_publish or "tries" in to_publish:
+        parser.add_argument(
+            "-t",
+            "--tries",
+            help=f"number of tries for each configuration, default is {tries}",
+            type=int,
+            default=tries,
+        )
+    for k, v in kwargs.items():
+        parser.add_argument(
+            f"--{k}",
+            help=f"{v[1]}, default is {v[0]}",
+            type=type(v[0]),
+            default=v[0],
+        )
+
+    res = parser.parse_args(args=new_args)
+    # casts every parsed value that looks numeric back to int or float
+    update: Dict[str, Union[int, float]] = {}
+    for k, v in res.__dict__.items():
+        try:
+            vi = int(v)
+            update[k] = vi
+            continue
+        except (ValueError, TypeError):
+            pass
+        try:
+            vf = float(v)
+            update[k] = vf
+            continue
+        except (ValueError, TypeError):
+            pass
+    if update:
+        res.__dict__.update(update)
+    return res