onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +66 -8
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +461 -0
- onnx_diagnostic/helpers/cache_helper.py +250 -15
- onnx_diagnostic/helpers/helper.py +146 -10
- onnx_diagnostic/helpers/log_helper.py +404 -315
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/feature_extraction.py +86 -5
- onnx_diagnostic/tasks/image_text_to_text.py +260 -56
- onnx_diagnostic/tasks/mask_generation.py +139 -0
- onnx_diagnostic/tasks/text2text_generation.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
- onnx_diagnostic/torch_models/validate.py +26 -3
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED

onnx_diagnostic/_command_lines_parser.py
CHANGED

@@ -483,6 +483,12 @@ def get_parser_validate() -> ArgumentParser:
     parser.add_argument(
         "--warmup", default=0, type=int, help="number of times to run the model to do warmup"
     )
+    parser.add_argument(
+        "--outnames",
+        help="This comma separated list defines the output names "
+        "the onnx exporter should use.",
+        default="",
+    )
     return parser
 
 
@@ -542,6 +548,9 @@ def _cmd_validate(argv: List[Any]):
         repeat=args.repeat,
         warmup=args.warmup,
         inputs2=args.inputs2,
+        output_names=(
+            None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
+        ),
     )
     print("")
     print("-- summary --")
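
The new --outnames flag feeds the output_names argument above. A minimal sketch of the same parsing expression as a standalone function (the function name is illustrative, not part of the package):

    def parse_outnames(outnames: str):
        # Mirrors the expression added in _cmd_validate: fewer than two
        # characters (including the empty default) means "not set".
        s = outnames.strip()
        return None if len(s) < 2 else s.split(",")

    print(parse_outnames(""))           # None
    print(parse_outnames("out0,out1"))  # ['out0', 'out1']

Note that a single-character output name also maps to None because of the len(...) < 2 test.
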
@@ -645,6 +654,27 @@ def _cmd_stats(argv: List[Any]):
     print("done.")
 
 
+class _ParseNamedDict(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        assert ":" in values, f"':' missing from {values!r}"
+        namespace_key, rest = values.split(":", 1)
+        pairs = rest.split(",")
+        inner_dict = {}
+
+        for pair in pairs:
+            if "=" not in pair:
+                raise argparse.ArgumentError(self, f"Expected '=' in pair '{pair}'")
+            key, value = pair.split("=", 1)
+            inner_dict[key] = value
+        assert inner_dict, f"Unable to parse {rest!r} into a dictionary"
+        if not hasattr(namespace, self.dest) or getattr(namespace, self.dest) is None:
+            setattr(namespace, self.dest, {})
+        assert isinstance(
+            getattr(namespace, self.dest), dict
+        ), f"Unexpected type for namespace.{self.dest}={getattr(namespace, self.dest)}"
+        getattr(namespace, self.dest).update({namespace_key: inner_dict})
+
+
 def get_parser_agg() -> ArgumentParser:
     parser = ArgumentParser(
         prog="agg",
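
A short sketch of how _ParseNamedDict behaves: each repeated option of the form "name:key=value,key=value" is merged into a single dictionary of dictionaries (the values mirror the --sbs examples shown further down; importing the private helper assumes onnx-diagnostic 0.7.6 is installed):

    import argparse
    from onnx_diagnostic._command_lines_parser import _ParseNamedDict  # private helper

    p = argparse.ArgumentParser()
    p.add_argument("--sbs", action=_ParseNamedDict)
    args = p.parse_args(
        ["--sbs", "dynamo:exporter=onnx-dynamo,opt=ir",
         "--sbs", "custom:exporter=custom,opt=default"]
    )
    print(args.sbs)
    # {'dynamo': {'exporter': 'onnx-dynamo', 'opt': 'ir'},
    #  'custom': {'exporter': 'custom', 'opt': 'default'}}
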
@@ -653,13 +683,23 @@ def get_parser_agg() -> ArgumentParser:
     Aggregates statistics coming from benchmarks.
     Every run is a row. Every row is indexed by some keys,
     and produces values. Every row has a date.
+    The data can come from any csv file produced by benchmarks;
+    it can concatenate many csv files, or csv files inside zip files.
+    It produces an excel file with many tabs, one per view.
     """
         ),
         epilog=textwrap.dedent(
             """
-    examples
+    examples:
 
     python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
+    python -m onnx_diagnostic agg agg.xlsx raw/*.zip raw/*.csv -v 1 \\
+        --no-raw --keep-last-date --filter-out "exporter:test-exporter"
+
+    Another one to create timeseries:
+
+    python -m onnx_diagnostic agg history.xlsx raw/*.csv -v 1 --no-raw \\
+        --no-recent
     """
         ),
         formatter_class=RawTextHelpFormatter,
@@ -737,7 +777,15 @@ def get_parser_agg() -> ArgumentParser:
         "--views",
         default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
         "bucket-speedup,raw-short,counts,peak-gpu,onnx",
-        help=
+        help=textwrap.dedent(
+            """
+            Views to add to the output files. Each view becomes a tab.
+            A view is defined by its name, among
+            agg-suite, agg-all, disc, speedup, time, time_export, err,
+            cmd, bucket-speedup, raw-short, counts, peak-gpu, onnx.
+            Their definition is part of class CubeLogsPerformance.
+            """
+        ),
     )
     parser.add_argument(
         "--csv",
@@ -757,16 +805,24 @@ def get_parser_agg() -> ArgumentParser:
         help="adds a filter to filter out data, syntax is\n"
         '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
     )
+    parser.add_argument(
+        "--sbs",
+        help=textwrap.dedent(
+            """
+            Defines an exporter to compare to another; there must be at least
+            two arguments defined with --sbs. Example:
+            --sbs dynamo:exporter=onnx-dynamo,opt=ir,attn_impl=eager
+            --sbs custom:exporter=custom,opt=default,attn_impl=eager
+            """
+        ),
+        action=_ParseNamedDict,
+    )
     return parser
 
 
 def _cmd_agg(argv: List[Any]):
-    from .helpers.log_helper import (
-        CubeLogsPerformance,
-        open_dataframe,
-        enumerate_csv_files,
-        filter_data,
-    )
+    from .helpers._log_helper import open_dataframe, enumerate_csv_files, filter_data
+    from .helpers.log_helper import CubeLogsPerformance
 
     parser = get_parser_agg()
     args = parser.parse_args(argv[1:])
@@ -812,6 +868,8 @@ def _cmd_agg(argv: List[Any]):
         verbose=args.verbose,
         csv=args.csv.split(","),
         raw=args.raw,
+        time_mask=True,
+        sbs=args.sbs,
     )
     if args.verbose:
         print(f"Wrote {args.output!r}")

onnx_diagnostic/ext_test_case.py
CHANGED
@@ -1058,6 +1058,8 @@ class ExtTestCase(unittest.TestCase):
         elif hasattr(expected, "shape"):
             self.assertEqual(type(expected), type(value), msg=msg)
             self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
+        elif expected is None:
+            assert value is None, f"Expected is None but value is of type {type(value)}"
         else:
             raise AssertionError(
                 f"Comparison not implemented for types {type(expected)} and {type(value)}"

onnx_diagnostic/helpers/_log_helper.py
ADDED

@@ -0,0 +1,461 @@
+import datetime
+import glob
+import io
+import os
+import zipfile
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+import numpy as np
+import pandas
+
+BUCKET_SCALES_VALUES = np.array(
+    [-np.inf, -20, -10, -5, -2, 0, 2, 5, 10, 20, 100, 200, 300, 400, np.inf], dtype=float
+)
+
+
+BUCKET_SCALES = BUCKET_SCALES_VALUES / 100 + 1
+
+
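
BUCKET_SCALES turns those percentage bounds into multiplicative buckets, presumably for the bucket-speedup view: dividing by 100 and adding 1 maps -20 to 0.8, 0 to 1.0 and 100 to 2.0, with the infinities left open-ended. Plain arithmetic on the constant above:

    import numpy as np

    values = np.array([-np.inf, -20, -10, -5, -2, 0, 2, 5, 10, 20, 100, 200, 300, 400, np.inf])
    print(values / 100 + 1)
    # [-inf 0.8 0.9 0.95 0.98 1. 1.02 1.05 1.1 1.2 2. 3. 4. 5. inf]
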
+def mann_kendall(series: Sequence[float], threshold: float = 0.5):
+    """
+    Computes the test of Mann-Kendall.
+
+    :param series: series
+    :param threshold: 1.96 is the usual value, 0.5 means a short timeseries
+        ``(0, 1, 2, 3, 4)`` has a significant trend
+    :return: trend (-1, 0, +1), test value
+
+    .. math::
+
+        S =\\sum_{i=1}^{n}\\sum_{j=i+1}^{n} sign(x_j - x_i)
+
+    where the function *sign* is:
+
+    .. math::
+
+        sign(x) = \\left\\{ \\begin{array}{l} -1 if x < 0 \\\\ 0 if x = 0 \\\\ +1 otherwise
+        \\end{array} \\right.
+
+    And:
+
+    .. math::
+
+        Var(S)= \\frac{n(n-1)(2n+5) - \\sum_t t(t-1)(2t+5)}{18}
+    """
+    aseries = np.asarray(series)
+    stat = 0
+    n = len(aseries)
+    var = n * (n - 1) * (2 * n + 5)
+    for i in range(n - 1):
+        stat += np.sign(aseries[i + 1 :] - aseries[i]).sum()
+    var = var**0.5
+    test = (stat + (1 if stat < 0 else (0 if stat == 0 else -1))) / var
+    trend = np.sign(test) if np.abs(test) > threshold else 0
+    return trend, test
+
+
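
As the docstring says, the 0.5 default threshold makes even a five-point monotone series significant. A quick sketch (assuming the module is importable from an installed 0.7.6):

    from onnx_diagnostic.helpers._log_helper import mann_kendall

    trend, test = mann_kendall([0, 1, 2, 3, 4])
    print(trend)  # 1.0: significant increasing trend
    trend, test = mann_kendall([1, 0, 2, 0, 1])
    print(trend)  # 0: no significant trend
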
+def breaking_last_point(series: Sequence[float], threshold: float = 1.2):
+    """
+    Assuming a timeseries is constant, we check the last value
+    is not an outlier.
+
+    :param series: series
+    :return: significant change (-1, 0, +1), test value
+    """
+    signal = np.asarray(series)
+    if not np.issubdtype(signal.dtype, np.number):
+        return 0, np.nan
+    assert len(signal.shape) == 1, f"Unexpected signal shape={signal.shape}, signal={signal}"
+    if signal.shape[0] <= 2:
+        return 0, 0
+
+    has_value = ~(np.isnan(signal).all()) and ~(np.isinf(signal).all())
+    if np.isnan(signal[-1]) or np.isinf(signal[-1]):
+        return (-1, np.inf) if has_value else (0, 0)
+
+    try:
+        m = np.mean(signal[:-1])
+    except (TypeError, ValueError):
+        # Not a numerical type
+        return 0, np.nan
+
+    if np.isnan(m) or np.isinf(m):
+        return (1, np.inf) if np.isinf(signal[-2]) or np.isnan(signal[-2]) else (0, 0)
+    v = np.std(signal[:-1])
+    if v == 0:
+        test = signal[-1] - m
+        assert not np.isnan(
+            test
+        ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+        trend = np.sign(test)
+        return trend, trend
+    test = (signal[-1] - m) / v
+    assert not np.isnan(
+        test
+    ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+    trend = np.sign(test) if np.abs(test) > threshold else 0
+    return trend, test
+
+
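
A sketch of the last-point test: the history is treated as roughly constant and the last value is flagged when it sits more than threshold standard deviations away from the mean of the previous points:

    from onnx_diagnostic.helpers._log_helper import breaking_last_point

    print(breaking_last_point([10.0, 10.1, 9.9, 10.0, 15.0])[0])  # 1.0: last point jumped up
    print(breaking_last_point([10.0, 10.1, 9.9, 10.0, 10.0])[0])  # 0: nothing unusual
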
+def filter_data(
+    df: pandas.DataFrame,
+    filter_in: Optional[str] = None,
+    filter_out: Optional[str] = None,
+    verbose: int = 0,
+) -> pandas.DataFrame:
+    """
+    Argument `filter` follows the syntax
+    ``<column1>:<fmt1>//<column2>:<fmt2>``.
+
+    The format is the following:
+
+    * a value or a set of values separated by ``;``
+    """
+    if not filter_in and not filter_out:
+        return df
+
+    def _f(fmt):
+        cond = {}
+        if isinstance(fmt, str):
+            cols = fmt.split("//")
+            for c in cols:
+                assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}"
+                spl = c.split(":")
+                assert len(spl) == 2, f"Unexpected value {c!r} in fmt={fmt!r}"
+                name, fil = spl
+                cond[name] = set(fil.split(";"))
+        return cond
+
+    if filter_in:
+        cond = _f(filter_in)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_in!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter in column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[df[k].astype(str).isin(v)]
+
+    if filter_out:
+        cond = _f(filter_out)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_out!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter out column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[~df[k].astype(str).isin(v)]
+    return df
+
+
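
A sketch of the filter syntax on a small frame; the same strings work for command-line options such as --filter-out, and columns absent from the frame are silently ignored:

    import pandas
    from onnx_diagnostic.helpers._log_helper import filter_data

    df = pandas.DataFrame(
        {"exporter": ["onnx-dynamo", "custom", "test-exporter"], "time": [1.0, 2.0, 3.0]}
    )
    print(filter_data(df, filter_out="exporter:test-exporter").exporter.tolist())
    # ['onnx-dynamo', 'custom']
    print(filter_data(df, filter_in="exporter:onnx-dynamo;custom").exporter.tolist())
    # ['onnx-dynamo', 'custom']
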
+def enumerate_csv_files(
+    data: Union[
+        pandas.DataFrame, List[Union[str, Tuple[str, str]]], str, Tuple[str, str, str, str]
+    ],
+    verbose: int = 0,
+    filtering: Optional[Callable[[str], bool]] = None,
+) -> Iterator[Union[pandas.DataFrame, str, Tuple[str, str, str, str]]]:
+    """
+    Enumerates files considered for the aggregation.
+    Only csv files are considered.
+    If a zip file is given, the function digs into the zip files and
+    loops over csv candidates.
+
+    :param data: dataframe with the raw data or a file or list of files
+    :param verbose: verbosity
+    :param filtering: function to filter in or out files in zip files,
+        must return true to keep the file, false to skip it.
+    :return: a generator yielding tuples with the filename, date, full path and zip file
+
+    data can contain:
+    * a dataframe
+    * a string for a filename, zip or csv
+    * a list of strings
+    * a tuple
+    """
+    if not isinstance(data, list):
+        data = [data]
+    for itn, filename in enumerate(data):
+        if isinstance(filename, pandas.DataFrame):
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}] is a dataframe")
+            yield filename
+            continue
+
+        if isinstance(filename, tuple):
+            # A file in a zipfile
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}] is {filename!r}")
+            yield filename
+            continue
+
+        if os.path.exists(filename):
+            ext = os.path.splitext(filename)[-1]
+            if ext == ".csv":
+                # We check the first line is ok.
+                if verbose:
+                    print(f"[enumerate_csv_files] data[{itn}] is a csv file: {filename!r}]")
+                dt = datetime.datetime.fromtimestamp(os.stat(filename).st_mtime)
+                du = dt.strftime("%Y-%m-%d %H:%M:%S")
+                yield (os.path.split(filename)[-1], du, filename, "")
+                continue
+
+            if ext == ".zip":
+                if verbose:
+                    print(f"[enumerate_csv_files] data[{itn}] is a zip file: {filename!r}]")
+                zf = zipfile.ZipFile(filename, "r")
+                for ii, info in enumerate(zf.infolist()):
+                    name = info.filename
+                    if filtering is None:
+                        ext = os.path.splitext(name)[-1]
+                        if ext != ".csv":
+                            continue
+                    elif not filtering(name):
+                        continue
+                    if verbose:
+                        print(
+                            f"[enumerate_csv_files] data[{itn}][{ii}] is a csv file: {name!r}]"
+                        )
+                    with zf.open(name) as zzf:
+                        first_line = zzf.readline()
+                    if b"," not in first_line:
+                        continue
+                    yield (
+                        os.path.split(name)[-1],
+                        "%04d-%02d-%02d %02d:%02d:%02d" % info.date_time,
+                        name,
+                        filename,
+                    )
+                zf.close()
+                continue
+
+            raise AssertionError(f"Unexpected format {filename!r}, cannot read it.")
+
+        # filename is a pattern.
+        found = glob.glob(filename)
+        if verbose and not found:
+            print(f"[enumerate_csv_files] unable to find file in {filename!r}")
+        for ii, f in enumerate(found):
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}][{ii}] {f!r} from {filename!r}")
+            yield from enumerate_csv_files(f, verbose=verbose, filtering=filtering)
+
+
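
A sketch of the enumeration: csv paths, glob patterns and zip archives all come out as (basename, date, path, zip) tuples, the last element being empty for a plain csv file (the patterns below are hypothetical):

    from onnx_diagnostic.helpers._log_helper import enumerate_csv_files

    for name, date, path, zipped in enumerate_csv_files(["raw/*.zip", "raw/*.csv"]):
        print(name, date, path, zipped or "<not zipped>")
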
+def open_dataframe(
+    data: Union[str, Tuple[str, str, str, str], pandas.DataFrame],
+) -> pandas.DataFrame:
+    """
+    Opens a filename defined by function
+    :func:`onnx_diagnostic.helpers._log_helper.enumerate_csv_files`.
+
+    :param data: a dataframe, a filename, a tuple indicating the file is coming
+        from a zip file
+    :return: a dataframe
+    """
+    if isinstance(data, pandas.DataFrame):
+        return data
+    if isinstance(data, str):
+        df = pandas.read_csv(data, low_memory=False)
+        df["RAWFILENAME"] = data
+        return df
+    if isinstance(data, tuple):
+        if not data[-1]:
+            df = pandas.read_csv(data[2], low_memory=False)
+            df["RAWFILENAME"] = data[2]
+            return df
+        zf = zipfile.ZipFile(data[-1])
+        with zf.open(data[2]) as f:
+            df = pandas.read_csv(f, low_memory=False)
+            df["RAWFILENAME"] = f"{data[-1]}/{data[2]}"
+        zf.close()
+        return df
+
+    raise ValueError(f"Unexpected value for data: {data!r}")
+
+
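
Chaining the two functions gives the concatenation step the agg command relies on; a minimal sketch (RAWFILENAME is the provenance column added above, the path pattern is hypothetical):

    import pandas
    from onnx_diagnostic.helpers._log_helper import enumerate_csv_files, open_dataframe

    frames = [open_dataframe(item) for item in enumerate_csv_files(["raw/*.csv"])]
    if frames:
        df = pandas.concat(frames, ignore_index=True)
        print(df["RAWFILENAME"].unique())  # one entry per source file
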
+def align_dataframe_with(
+    df: pandas.DataFrame, baseline: pandas.DataFrame, fill_value: float = 0
+) -> Optional[pandas.DataFrame]:
+    """
+    Modifies the first dataframe *df* to get the exact same number of columns and rows.
+    They must share the same levels on both axes. Empty cells are filled with 0.
+    We only keep the numerical columns. The function returns None if the output is empty.
+    """
+    df = df.select_dtypes(include="number")
+    if df.shape[1] == 0:
+        return None
+    bool_cols = list(df.select_dtypes(include="bool").columns)
+    if bool_cols:
+        df[bool_cols] = df[bool_cols].astype(int)
+    assert (
+        df.columns.names == baseline.columns.names or df.index.names == baseline.index.names
+    ), (
+        f"Levels mismatch, expected index.names={baseline.index.names}, "
+        f"expected columns.names={baseline.columns.names}, "
+        f"got index.names={df.index.names}, "
+        f"got columns.names={df.columns.names}"
+    )
+    dtypes = set(df[c].dtype for c in df.columns)
+    assert all(np.issubdtype(dt, np.number) for dt in dtypes), (
+        f"All columns in the first dataframe are expected to share "
+        f"the same type or be at least numerical but got {dtypes}\n{df}"
+    )
+    common_index = df.index.intersection(baseline.index)
+    cp = pandas.DataFrame(float(fill_value), index=baseline.index, columns=baseline.columns)
+    for c in df.columns:
+        if c not in cp.columns or not np.issubdtype(df[c].dtype, np.number):
+            continue
+        cp.loc[common_index, c] = df.loc[common_index, c].astype(cp[c].dtype)
+    return cp
+
+
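
A sketch of the alignment: the first frame is projected onto the baseline's index and columns, missing cells take fill_value and non-numeric columns are dropped:

    import pandas
    from onnx_diagnostic.helpers._log_helper import align_dataframe_with

    baseline = pandas.DataFrame({"a": [0.0, 0.0, 0.0], "b": [0.0, 0.0, 0.0]})
    df = pandas.DataFrame({"a": [10.0, 20.0]})
    print(align_dataframe_with(df, baseline))
    #       a    b
    # 0  10.0  0.0
    # 1  20.0  0.0
    # 2   0.0  0.0
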
+def apply_excel_style(
+    filename_or_writer: Any,
+    f_highlights: Optional[  # type: ignore[name-defined]
+        Dict[str, Callable[[Any], "CubeViewDef.HighLightKind"]]  # noqa: F821
+    ] = None,
+    time_mask_view: Optional[Dict[str, pandas.DataFrame]] = None,
+    verbose: int = 0,
+):
+    """
+    Applies styles on all sheets in a file unless the sheet is too big.
+
+    :param filename_or_writer: filename, modified inplace
+    :param f_highlights: color function to apply, one per sheet
+    :param time_mask_view: if specified, it contains dataframes with the same shape
+        and values in {-1, 0, +1} which indicate if a value is unexpectedly lower (-1)
+        or higher (+1); it changes the color of the background then.
+    :param verbose: progress loop
+    """
+    from openpyxl import load_workbook
+    from openpyxl.styles import Alignment
+    from openpyxl.utils import get_column_letter
+    from openpyxl.styles import Font, PatternFill
+    from .log_helper import CubeViewDef
+
+    if isinstance(filename_or_writer, str):
+        workbook = load_workbook(filename_or_writer)
+        save = True
+    else:
+        workbook = filename_or_writer.book
+        save = False
+
+    mask_low = PatternFill(fgColor="AAAAF0", fill_type="solid")
+    mask_high = PatternFill(fgColor="F0AAAA", fill_type="solid")
+
+    left = Alignment(horizontal="left")
+    left_shrink = Alignment(horizontal="left", shrink_to_fit=True)
+    right = Alignment(horizontal="right")
+    font_colors = {
+        CubeViewDef.HighLightKind.GREEN: Font(color="00AA00"),
+        CubeViewDef.HighLightKind.RED: Font(color="FF0000"),
+    }
+    if verbose:
+        from tqdm import tqdm
+
+        sheet_names = tqdm(list(workbook.sheetnames))
+    else:
+        sheet_names = workbook.sheetnames
+    for name in sheet_names:
+        if time_mask_view and name in time_mask_view:
+            mask = time_mask_view[name]
+            with pandas.ExcelWriter(io.BytesIO(), engine="openpyxl") as mask_writer:
+                mask.to_excel(mask_writer, sheet_name=name)
+                sheet_mask = mask_writer.sheets[name]
+        else:
+            sheet_mask = None
+
+        f_highlight = f_highlights.get(name, None) if f_highlights else None
+        sheet = workbook[name]
+        n_rows = sheet.max_row
+        n_cols = sheet.max_column
+        if n_rows * n_cols > 2**16 or n_rows > 2**13:
+            # Too big.
+            continue
+        co: Dict[int, int] = {}
+        sizes: Dict[int, int] = {}
+        cols = set()
+        for i in range(1, n_rows + 1):
+            for j, cell in enumerate(sheet[i]):
+                if j > n_cols:
+                    break
+                cols.add(cell.column)
+                if isinstance(cell.value, float):
+                    co[j] = co.get(j, 0) + 1
+                elif isinstance(cell.value, str):
+                    sizes[cell.column] = max(sizes.get(cell.column, 0), len(cell.value))
+
+        for k, v in sizes.items():
+            c = get_column_letter(k)
+            sheet.column_dimensions[c].width = min(max(8, v), 30)
+        for k in cols:
+            if k not in sizes:
+                c = get_column_letter(k)
+                sheet.column_dimensions[c].width = 15
+
+        for i in range(1, n_rows + 1):
+            for j, cell in enumerate(sheet[i]):
+                if j > n_cols:
+                    break
+                if isinstance(cell.value, pandas.Timestamp):
+                    cell.alignment = right
+                    dt = cell.value.to_pydatetime()
+                    cell.value = dt
+                    cell.number_format = (
+                        "YYYY-MM-DD"
+                        if (
+                            dt.hour == 0
+                            and dt.minute == 0
+                            and dt.second == 0
+                            and dt.microsecond == 0
+                        )
+                        else "YYYY-MM-DD 00:00:00"
+                    )
+                elif isinstance(cell.value, (float, int)):
+                    cell.alignment = right
+                    x = abs(cell.value)
+                    if int(x) == x:
+                        cell.number_format = "0"
+                    elif x > 5000:
+                        cell.number_format = "# ##0"
+                    elif x >= 500:
+                        cell.number_format = "0.0"
+                    elif x >= 50:
+                        cell.number_format = "0.00"
+                    elif x >= 5:
+                        cell.number_format = "0.000"
+                    elif x > 0.5:
+                        cell.number_format = "0.0000"
+                    elif x > 0.005:
+                        cell.number_format = "0.00000"
+                    else:
+                        cell.number_format = "0.000E+00"
+                    if f_highlight:
+                        h = f_highlight(cell.value)
+                        if h in font_colors:
+                            cell.font = font_colors[h]
+                elif isinstance(cell.value, str) and len(cell.value) > 70:
+                    cell.alignment = left_shrink
+                else:
+                    cell.alignment = left
+                    if f_highlight:
+                        h = f_highlight(cell.value)
+                        if h in font_colors:
+                            cell.font = font_colors[h]
+
+        if sheet_mask is not None:
+            for i in range(1, n_rows + 1):
+                for j, (cell, cell_mask) in enumerate(zip(sheet[i], sheet_mask[i])):
+                    if j > n_cols:
+                        break
+                    if cell_mask.value not in (1, -1):
+                        continue
+                    cell.fill = mask_low if cell_mask.value < 0 else mask_high
+
+    if save:
+        workbook.save(filename_or_writer)
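
Typical use once a workbook is written: dump the views with pandas.ExcelWriter, then restyle the file in place (a sketch; the file and sheet names are hypothetical):

    import pandas
    from onnx_diagnostic.helpers._log_helper import apply_excel_style

    with pandas.ExcelWriter("report.xlsx", engine="openpyxl") as writer:
        pandas.DataFrame({"speedup": [1.02, 0.97]}).to_excel(writer, sheet_name="speedup")
    apply_excel_style("report.xlsx")  # column widths, number formats, alignments
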