onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +66 -8
  3. onnx_diagnostic/ext_test_case.py +2 -0
  4. onnx_diagnostic/helpers/_log_helper.py +461 -0
  5. onnx_diagnostic/helpers/cache_helper.py +250 -15
  6. onnx_diagnostic/helpers/helper.py +146 -10
  7. onnx_diagnostic/helpers/log_helper.py +404 -315
  8. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  9. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  10. onnx_diagnostic/helpers/torch_helper.py +33 -11
  11. onnx_diagnostic/tasks/__init__.py +2 -0
  12. onnx_diagnostic/tasks/feature_extraction.py +86 -5
  13. onnx_diagnostic/tasks/image_text_to_text.py +260 -56
  14. onnx_diagnostic/tasks/mask_generation.py +139 -0
  15. onnx_diagnostic/tasks/text2text_generation.py +2 -2
  16. onnx_diagnostic/tasks/text_generation.py +6 -2
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
  18. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  19. onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
  21. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
  24. onnx_diagnostic/torch_models/validate.py +26 -3
  25. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py

@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.7.4"
+__version__ = "0.7.6"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py

@@ -483,6 +483,12 @@ def get_parser_validate() -> ArgumentParser:
     parser.add_argument(
         "--warmup", default=0, type=int, help="number of times to run the model to do warmup"
     )
+    parser.add_argument(
+        "--outnames",
+        help="This comma separated list defines the output names "
+        "the onnx exporter should use.",
+        default="",
+    )
     return parser


@@ -542,6 +548,9 @@ def _cmd_validate(argv: List[Any]):
         repeat=args.repeat,
         warmup=args.warmup,
         inputs2=args.inputs2,
+        output_names=(
+            None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
+        ),
     )
     print("")
     print("-- summary --")
@@ -645,6 +654,27 @@ def _cmd_stats(argv: List[Any]):
     print("done.")


+class _ParseNamedDict(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        assert ":" in values, f"':' missing from {values!r}"
+        namespace_key, rest = values.split(":", 1)
+        pairs = rest.split(",")
+        inner_dict = {}
+
+        for pair in pairs:
+            if "=" not in pair:
+                raise argparse.ArgumentError(self, f"Expected '=' in pair '{pair}'")
+            key, value = pair.split("=", 1)
+            inner_dict[key] = value
+        assert inner_dict, f"Unable to parse {rest!r} into a dictionary"
+        if not hasattr(namespace, self.dest) or getattr(namespace, self.dest) is None:
+            setattr(namespace, self.dest, {})
+        assert isinstance(
+            getattr(namespace, self.dest), dict
+        ), f"Unexpected type for namespace.{self.dest}={getattr(namespace, self.dest)}"
+        getattr(namespace, self.dest).update({namespace_key: inner_dict})
+
+
 def get_parser_agg() -> ArgumentParser:
     parser = ArgumentParser(
         prog="agg",
@@ -653,13 +683,23 @@ def get_parser_agg() -> ArgumentParser:
             Aggregates statistics coming from benchmarks.
             Every run is a row. Every row is indexed by some keys,
             and produces values. Every row has a date.
+            The data can come any csv files produces by benchmarks,
+            it can concatenates many csv files, or csv files inside zip files.
+            It produces an excel file with many tabs, one per view.
             """
         ),
         epilog=textwrap.dedent(
             """
-            examples:\n
+            examples:

            python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
+            python -m onnx_diagnostic agg agg.xlsx raw/*.zip raw/*.csv -v 1 \\
+                --no-raw --keep-last-date --filter-out "exporter:test-exporter"
+
+            Another to create timeseries:
+
+            python -m onnx_diagnostic agg history.xlsx raw/*.csv -v 1 --no-raw \\
+                --no-recent
             """
         ),
         formatter_class=RawTextHelpFormatter,
@@ -737,7 +777,15 @@
         "--views",
         default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
         "bucket-speedup,raw-short,counts,peak-gpu,onnx",
-        help="Views to add to the output files.",
+        help=textwrap.dedent(
+            """
+            Views to add to the output files. Each view becomes a tab.
+            A view is defined by its name, among
+            agg-suite, agg-all, disc, speedup, time, time_export, err,
+            cmd, bucket-speedup, raw-short, counts, peak-gpu, onnx.
+            Their definition is part of class CubeLogsPerformance.
+            """
+        ),
     )
     parser.add_argument(
         "--csv",
@@ -757,16 +805,24 @@
         help="adds a filter to filter out data, syntax is\n"
         '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
     )
+    parser.add_argument(
+        "--sbs",
+        help=textwrap.dedent(
+            """
+            Defines an exporter to compare to another, there must be at least
+            two arguments defined with --sbs. Example:
+            --sbs dynamo:exporter=onnx-dynamo,opt=ir,attn_impl=eager
+            --sbs custom:exporter=custom,opt=default,attn_impl=eager
+            """
+        ),
+        action=_ParseNamedDict,
+    )
     return parser


 def _cmd_agg(argv: List[Any]):
-    from .helpers.log_helper import (
-        CubeLogsPerformance,
-        open_dataframe,
-        enumerate_csv_files,
-        filter_data,
-    )
+    from .helpers._log_helper import open_dataframe, enumerate_csv_files, filter_data
+    from .helpers.log_helper import CubeLogsPerformance

     parser = get_parser_agg()
     args = parser.parse_args(argv[1:])
@@ -812,6 +868,8 @@ def _cmd_agg(argv: List[Any]):
         verbose=args.verbose,
         csv=args.csv.split(","),
         raw=args.raw,
+        time_mask=True,
+        sbs=args.sbs,
     )
     if args.verbose:
         print(f"Wrote {args.output!r}")
onnx_diagnostic/ext_test_case.py

@@ -1058,6 +1058,8 @@ class ExtTestCase(unittest.TestCase):
         elif hasattr(expected, "shape"):
             self.assertEqual(type(expected), type(value), msg=msg)
             self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
+        elif expected is None:
+            assert value is None, f"Expected is None but value is of type {type(value)}"
         else:
             raise AssertionError(
                 f"Comparison not implemented for types {type(expected)} and {type(value)}"
onnx_diagnostic/helpers/_log_helper.py (new file)

@@ -0,0 +1,461 @@
+import datetime
+import glob
+import io
+import os
+import zipfile
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+import numpy as np
+import pandas
+
+BUCKET_SCALES_VALUES = np.array(
+    [-np.inf, -20, -10, -5, -2, 0, 2, 5, 10, 20, 100, 200, 300, 400, np.inf], dtype=float
+)
+
+
+BUCKET_SCALES = BUCKET_SCALES_VALUES / 100 + 1
+
+
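
The percentage boundaries above translate into multiplicative scales. A quick worked evaluation, same arithmetic as the module (the constant name suggests these feed the bucket-speedup view listed under --views, which is an assumption):

    import numpy as np

    values = np.array([-np.inf, -20, -10, -5, -2, 0, 2, 5, 10, 20, 100, 200, 300, 400, np.inf])
    scales = values / 100 + 1
    # [-inf, 0.8, 0.9, 0.95, 0.98, 1.0, 1.02, 1.05, 1.1, 1.2, 2.0, 3.0, 4.0, 5.0, inf]
    # i.e. a -20% change maps to a 0.8x boundary, +100% to 2.0x.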
+def mann_kendall(series: Sequence[float], threshold: float = 0.5):
+    """
+    Computes the test of Mann-Kendall.
+
+    :param series: series
+    :param threshold: 1.96 is the usual value, 0.5 means a short timeseries
+        ``(0, 1, 2, 3, 4)`` has a significant trend
+    :return: trend (-1, 0, +1), test value
+
+    .. math::
+
+        S =\\sum_{i=1}^{n}\\sum_{j=i+1}^{n} sign(x_j - x_i)
+
+    where the function *sign* is:
+
+    .. math::
+
+        sign(x) = \\left\\{ \\begin{array}{l} -1 if x < 0 \\\\ 0 if x = 0 \\\\ +1 otherwise
+        \\end{array} \\right.
+
+    And:
+
+    .. math::
+
+        Var(S)= \\frac{n(n-1)(2n+5) - \\sum_t t(t-1)(2t+5)}{18}
+    """
+    aseries = np.asarray(series)
+    stat = 0
+    n = len(aseries)
+    var = n * (n - 1) * (2 * n + 5)
+    for i in range(n - 1):
+        stat += np.sign(aseries[i + 1 :] - aseries[i]).sum()
+    var = var**0.5
+    test = (stat + (1 if stat < 0 else (0 if stat == 0 else -1))) / var
+    trend = np.sign(test) if np.abs(test) > threshold else 0
+    return trend, test
+
+
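
The docstring's remark about short series can be checked directly; a minimal sketch (the import path is taken from the cross-reference in open_dataframe further down):

    from onnx_diagnostic.helpers._log_helper import mann_kendall

    trend, test = mann_kendall([0, 1, 2, 3, 4])
    # S = 10 for n = 5 points, Var(S)**0.5 ~ 17.3, so test ~ 0.52 > 0.5:
    # trend == 1, an increasing trend is reported.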
+def breaking_last_point(series: Sequence[float], threshold: float = 1.2):
+    """
+    Assuming a timeseries is constant, we check the last value
+    is not an outlier.
+
+    :param series: series
+    :return: significant change (-1, 0, +1), test value
+    """
+    signal = np.asarray(series)
+    if not np.issubdtype(signal.dtype, np.number):
+        return 0, np.nan
+    assert len(signal.shape) == 1, f"Unexpected signal shape={signal.shape}, signal={signal}"
+    if signal.shape[0] <= 2:
+        return 0, 0
+
+    has_value = ~(np.isnan(signal).all()) and ~(np.isinf(signal).all())
+    if np.isnan(signal[-1]) or np.isinf(signal[-1]):
+        return (-1, np.inf) if has_value else (0, 0)
+
+    try:
+        m = np.mean(signal[:-1])
+    except (TypeError, ValueError):
+        # Not a numerical type
+        return 0, np.nan
+
+    if np.isnan(m) or np.isinf(m):
+        return (1, np.inf) if np.isinf(signal[-2]) or np.isnan(signal[-2]) else (0, 0)
+    v = np.std(signal[:-1])
+    if v == 0:
+        test = signal[-1] - m
+        assert not np.isnan(
+            test
+        ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+        trend = np.sign(test)
+        return trend, trend
+    test = (signal[-1] - m) / v
+    assert not np.isnan(
+        test
+    ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+    trend = np.sign(test) if np.abs(test) > threshold else 0
+    return trend, test
+
+
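
A sketch of the intended use, flagging a regression in the latest run of an otherwise stable series (numbers invented):

    from onnx_diagnostic.helpers._log_helper import breaking_last_point

    trend, test = breaking_last_point([1.0, 1.1, 0.9, 1.0, 3.0])
    # history mean ~1.0, std ~0.07, so test ~ 28 >> 1.2: trend == 1,
    # the last point is reported as unexpectedly high.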
+def filter_data(
+    df: pandas.DataFrame,
+    filter_in: Optional[str] = None,
+    filter_out: Optional[str] = None,
+    verbose: int = 0,
+) -> pandas.DataFrame:
+    """
+    Argument `filter` follows the syntax
+    ``<column1>:<fmt1>//<column2>:<fmt2>``.
+
+    The format is the following:
+
+    * a value or a set of values separated by ``;``
+    """
+    if not filter_in and not filter_out:
+        return df
+
+    def _f(fmt):
+        cond = {}
+        if isinstance(fmt, str):
+            cols = fmt.split("//")
+            for c in cols:
+                assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}"
+                spl = c.split(":")
+                assert len(spl) == 2, f"Unexpected value {c!r} in fmt={fmt!r}"
+                name, fil = spl
+                cond[name] = set(fil.split(";"))
+        return cond
+
+    if filter_in:
+        cond = _f(filter_in)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_in!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter in column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[df[k].astype(str).isin(v)]
+
+    if filter_out:
+        cond = _f(filter_out)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_out!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter out column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[~df[k].astype(str).isin(v)]
+    return df
+
+
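
A sketch of the filter syntax on a toy frame (column and values invented); the same strings are what the --filter-out option shown earlier accepts on the command line:

    import pandas
    from onnx_diagnostic.helpers._log_helper import filter_data

    df = pandas.DataFrame({"exporter": ["onnx-dynamo", "custom", "test-exporter"]})
    kept = filter_data(df, filter_in="exporter:onnx-dynamo;custom")
    # keeps the first two rows; filter_out="exporter:test-exporter" drops the last one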
+def enumerate_csv_files(
+    data: Union[
+        pandas.DataFrame, List[Union[str, Tuple[str, str]]], str, Tuple[str, str, str, str]
+    ],
+    verbose: int = 0,
+    filtering: Optional[Callable[[str], bool]] = None,
+) -> Iterator[Union[pandas.DataFrame, str, Tuple[str, str, str, str]]]:
+    """
+    Enumerates files considered for the aggregation.
+    Only csv files are considered.
+    If a zip file is given, the function digs into the zip files and
+    loops over csv candidates.
+
+    :param data: dataframe with the raw data or a file or list of files
+    :param vrbose: verbosity
+    :param filtering: function to filter in or out files in zip files,
+        must return true to keep the file, false to skip it.
+    :return: a generator yielding tuples with the filename, date, full path and zip file
+
+    data can contains:
+    * a dataframe
+    * a string for a filename, zip or csv
+    * a list of string
+    * a tuple
+    """
+    if not isinstance(data, list):
+        data = [data]
+    for itn, filename in enumerate(data):
+        if isinstance(filename, pandas.DataFrame):
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}] is a dataframe")
+            yield filename
+            continue
+
+        if isinstance(filename, tuple):
+            # A file in a zipfile
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}] is {filename!r}")
+            yield filename
+            continue
+
+        if os.path.exists(filename):
+            ext = os.path.splitext(filename)[-1]
+            if ext == ".csv":
+                # We check the first line is ok.
+                if verbose:
+                    print(f"[enumerate_csv_files] data[{itn}] is a csv file: {filename!r}]")
+                dt = datetime.datetime.fromtimestamp(os.stat(filename).st_mtime)
+                du = dt.strftime("%Y-%m-%d %H:%M:%S")
+                yield (os.path.split(filename)[-1], du, filename, "")
+                continue
+
+            if ext == ".zip":
+                if verbose:
+                    print(f"[enumerate_csv_files] data[{itn}] is a zip file: {filename!r}]")
+                zf = zipfile.ZipFile(filename, "r")
+                for ii, info in enumerate(zf.infolist()):
+                    name = info.filename
+                    if filtering is None:
+                        ext = os.path.splitext(name)[-1]
+                        if ext != ".csv":
+                            continue
+                    elif not filtering(name):
+                        continue
+                    if verbose:
+                        print(
+                            f"[enumerate_csv_files] data[{itn}][{ii}] is a csv file: {name!r}]"
+                        )
+                    with zf.open(name) as zzf:
+                        first_line = zzf.readline()
+                    if b"," not in first_line:
+                        continue
+                    yield (
+                        os.path.split(name)[-1],
+                        "%04d-%02d-%02d %02d:%02d:%02d" % info.date_time,
+                        name,
+                        filename,
+                    )
+                zf.close()
+                continue
+
+            raise AssertionError(f"Unexpected format {filename!r}, cannot read it.")
+
+        # filename is a pattern.
+        found = glob.glob(filename)
+        if verbose and not found:
+            print(f"[enumerate_csv_files] unable to find file in {filename!r}")
+        for ii, f in enumerate(found):
+            if verbose:
+                print(f"[enumerate_csv_files] data[{itn}][{ii}] {f!r} from {filename!r}")
+            yield from enumerate_csv_files(f, verbose=verbose, filtering=filtering)
+
+
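
A sketch of the iteration pattern (paths invented). Each yielded item is either a dataframe or a 4-tuple of short name, date, path, and enclosing zip file (an empty string when the csv sits directly on disk):

    from onnx_diagnostic.helpers._log_helper import enumerate_csv_files

    for item in enumerate_csv_files(["raw/*.csv", "raw/benchmarks.zip"], verbose=1):
        print(item)
        # e.g. ('run1.csv', '2024-01-01 12:00:00', 'run1.csv', 'raw/benchmarks.zip')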
+def open_dataframe(
+    data: Union[str, Tuple[str, str, str, str], pandas.DataFrame],
+) -> pandas.DataFrame:
+    """
+    Opens a filename defined by function
+    :func:`onnx_diagnostic.helpers._log_helper.enumerate_csv_files`.
+
+    :param data: a dataframe, a filename, a tuple indicating the file is coming
+        from a zip file
+    :return: a dataframe
+    """
+    if isinstance(data, pandas.DataFrame):
+        return data
+    if isinstance(data, str):
+        df = pandas.read_csv(data, low_memory=False)
+        df["RAWFILENAME"] = data
+        return df
+    if isinstance(data, tuple):
+        if not data[-1]:
+            df = pandas.read_csv(data[2], low_memory=False)
+            df["RAWFILENAME"] = data[2]
+            return df
+        zf = zipfile.ZipFile(data[-1])
+        with zf.open(data[2]) as f:
+            df = pandas.read_csv(f, low_memory=False)
+        df["RAWFILENAME"] = f"{data[-1]}/{data[2]}"
+        zf.close()
+        return df
+
+    raise ValueError(f"Unexpected value for data: {data!r}")
+
+
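
Chained together, the two helpers perform the concatenation the agg command's description mentions; a minimal sketch (pattern invented):

    import pandas
    from onnx_diagnostic.helpers._log_helper import enumerate_csv_files, open_dataframe

    frames = [open_dataframe(item) for item in enumerate_csv_files("raw/*.zip")]
    data = pandas.concat(frames)
    # every row keeps its origin in the RAWFILENAME column added by open_dataframe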
+def align_dataframe_with(
+    df: pandas.DataFrame, baseline: pandas.DataFrame, fill_value: float = 0
+) -> Optional[pandas.DataFrame]:
+    """
+    Modifies the first dataframe *df* to get the exact same number of columns and rows.
+    They must share the same levels on both axes. Empty cells are filled with 0.
+    We only keep the numerical columns. The function return None if the output is empty.
+    """
+    df = df.select_dtypes(include="number")
+    if df.shape[1] == 0:
+        return None
+    bool_cols = list(df.select_dtypes(include="bool").columns)
+    if bool_cols:
+        df[bool_cols] = df[bool_cols].astype(int)
+    assert (
+        df.columns.names == baseline.columns.names or df.index.names == baseline.index.names
+    ), (
+        f"Levels mismatch, expected index.names={baseline.index.names}, "
+        f"expected columns.names={baseline.columns.names}, "
+        f"got index.names={df.index.names}, "
+        f"got columns.names={df.columns.names}"
+    )
+    dtypes = set(df[c].dtype for c in df.columns)
+    assert all(np.issubdtype(dt, np.number) for dt in dtypes), (
+        f"All columns in the first dataframe are expected to share "
+        f"the same type or be at least numerical but got {dtypes}\n{df}"
+    )
+    common_index = df.index.intersection(baseline.index)
+    cp = pandas.DataFrame(float(fill_value), index=baseline.index, columns=baseline.columns)
+    for c in df.columns:
+        if c not in cp.columns or not np.issubdtype(df[c].dtype, np.number):
+            continue
+        cp.loc[common_index, c] = df.loc[common_index, c].astype(cp[c].dtype)
+    return cp
+
+
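
A sketch of the alignment on toy frames: rows and columns absent from the baseline are dropped, cells missing from df take fill_value:

    import pandas
    from onnx_diagnostic.helpers._log_helper import align_dataframe_with

    baseline = pandas.DataFrame({"a": [1.0, 2.0]}, index=["x", "y"])
    df = pandas.DataFrame({"a": [10.0], "b": [5.0]}, index=["x"])
    aligned = align_dataframe_with(df, baseline)
    # aligned has the baseline's shape: a == 10.0 at 'x', 0.0 at 'y'; column 'b' is gone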
+def apply_excel_style(
+    filename_or_writer: Any,
+    f_highlights: Optional[  # type: ignore[name-defined]
+        Dict[str, Callable[[Any], "CubeViewDef.HighLightKind"]]  # noqa: F821
+    ] = None,
+    time_mask_view: Optional[Dict[str, pandas.DataFrame]] = None,
+    verbose: int = 0,
+):
+    """
+    Applies styles on all sheets in a file unless the sheet is too big.
+
+    :param filename_or_writer: filename, modified inplace
+    :param f_highlight: color function to apply, one per sheet
+    :param time_mask_view: if specified, it contains dataframe with the same shape
+        and values in {-1, 0, +1} which indicates if a value is unexpectedly lower (-1)
+        or higher (+1), it changes the color of the background then.
+    :param verbosity: progress loop
+    """
+    from openpyxl import load_workbook
+    from openpyxl.styles import Alignment
+    from openpyxl.utils import get_column_letter
+    from openpyxl.styles import Font, PatternFill
+    from .log_helper import CubeViewDef
+
+    if isinstance(filename_or_writer, str):
+        workbook = load_workbook(filename_or_writer)
+        save = True
+    else:
+        workbook = filename_or_writer.book
+        save = False
+
+    mask_low = PatternFill(fgColor="AAAAF0", fill_type="solid")
+    mask_high = PatternFill(fgColor="F0AAAA", fill_type="solid")
+
+    left = Alignment(horizontal="left")
+    left_shrink = Alignment(horizontal="left", shrink_to_fit=True)
+    right = Alignment(horizontal="right")
+    font_colors = {
+        CubeViewDef.HighLightKind.GREEN: Font(color="00AA00"),
+        CubeViewDef.HighLightKind.RED: Font(color="FF0000"),
+    }
+    if verbose:
+        from tqdm import tqdm
+
+        sheet_names = tqdm(list(workbook.sheetnames))
+    else:
+        sheet_names = workbook.sheetnames
+    for name in sheet_names:
+        if time_mask_view and name in time_mask_view:
+            mask = time_mask_view[name]
+            with pandas.ExcelWriter(io.BytesIO(), engine="openpyxl") as mask_writer:
+                mask.to_excel(mask_writer, sheet_name=name)
+                sheet_mask = mask_writer.sheets[name]
+        else:
+            sheet_mask = None
+
+        f_highlight = f_highlights.get(name, None) if f_highlights else None
+        sheet = workbook[name]
+        n_rows = sheet.max_row
+        n_cols = sheet.max_column
+        if n_rows * n_cols > 2**16 or n_rows > 2**13:
+            # Too big.
+            continue
+        co: Dict[int, int] = {}
+        sizes: Dict[int, int] = {}
+        cols = set()
+        for i in range(1, n_rows + 1):
+            for j, cell in enumerate(sheet[i]):
+                if j > n_cols:
+                    break
+                cols.add(cell.column)
+                if isinstance(cell.value, float):
+                    co[j] = co.get(j, 0) + 1
+                elif isinstance(cell.value, str):
+                    sizes[cell.column] = max(sizes.get(cell.column, 0), len(cell.value))
+
+        for k, v in sizes.items():
+            c = get_column_letter(k)
+            sheet.column_dimensions[c].width = min(max(8, v), 30)
+        for k in cols:
+            if k not in sizes:
+                c = get_column_letter(k)
+                sheet.column_dimensions[c].width = 15
+
+        for i in range(1, n_rows + 1):
+            for j, cell in enumerate(sheet[i]):
+                if j > n_cols:
+                    break
+                if isinstance(cell.value, pandas.Timestamp):
+                    cell.alignment = right
+                    dt = cell.value.to_pydatetime()
+                    cell.value = dt
+                    cell.number_format = (
+                        "YYYY-MM-DD"
+                        if (
+                            dt.hour == 0
+                            and dt.minute == 0
+                            and dt.second == 0
+                            and dt.microsecond == 0
+                        )
+                        else "YYYY-MM-DD 00:00:00"
+                    )
+                elif isinstance(cell.value, (float, int)):
+                    cell.alignment = right
+                    x = abs(cell.value)
+                    if int(x) == x:
+                        cell.number_format = "0"
+                    elif x > 5000:
+                        cell.number_format = "# ##0"
+                    elif x >= 500:
+                        cell.number_format = "0.0"
+                    elif x >= 50:
+                        cell.number_format = "0.00"
+                    elif x >= 5:
+                        cell.number_format = "0.000"
+                    elif x > 0.5:
+                        cell.number_format = "0.0000"
+                    elif x > 0.005:
+                        cell.number_format = "0.00000"
+                    else:
+                        cell.number_format = "0.000E+00"
+                    if f_highlight:
+                        h = f_highlight(cell.value)
+                        if h in font_colors:
+                            cell.font = font_colors[h]
+                elif isinstance(cell.value, str) and len(cell.value) > 70:
+                    cell.alignment = left_shrink
+                else:
+                    cell.alignment = left
+                    if f_highlight:
+                        h = f_highlight(cell.value)
+                        if h in font_colors:
+                            cell.font = font_colors[h]
+
+        if sheet_mask is not None:
+            for i in range(1, n_rows + 1):
+                for j, (cell, cell_mask) in enumerate(zip(sheet[i], sheet_mask[i])):
+                    if j > n_cols:
+                        break
+                    if cell_mask.value not in (1, -1):
+                        continue
+                    cell.fill = mask_low if cell_mask.value < 0 else mask_high
+
+    if save:
+        workbook.save(filename_or_writer)
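
Finally, a sketch of how this styling pass would be applied to a workbook produced by the agg command (file and sheet names invented; the mask must match the sheet it annotates, with values in {-1, 0, +1} as the docstring states):

    import pandas
    from onnx_diagnostic.helpers._log_helper import apply_excel_style

    mask = pandas.DataFrame([[0, 1], [-1, 0]])  # +1 unexpectedly high, -1 unexpectedly low
    apply_excel_style("agg.xlsx", time_mask_view={"speedup": mask}, verbose=0)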