onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/__init__.py
@@ -0,0 +1 @@
+ from .helper import flatten_object, max_diff, string_diff, string_sig, string_type
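For orientation, a minimal sketch of what this re-export enables; the exact signatures (for instance the `with_shape` flag) are assumptions based on the names above, not shown in this diff:

```python
# Hypothetical usage sketch: the names come from the re-export above,
# but their exact signatures are assumed, not shown in this diff.
import torch

from onnx_diagnostic.helpers import max_diff, string_type

inputs = {"x": torch.randn(2, 3), "mask": torch.ones(2, 3, dtype=torch.bool)}
print(string_type(inputs, with_shape=True))  # compact summary of a nested structure
print(max_diff(inputs["x"], inputs["x"]))    # discrepancy report, ~0 for identical tensors
```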
onnx_diagnostic/helpers/_log_helper.py
@@ -0,0 +1,463 @@
+ import datetime
+ import glob
+ import io
+ import os
+ import zipfile
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+ import numpy as np
+ import pandas
+
+ BUCKET_SCALES_VALUES = np.array(
+     [-np.inf, -20, -10, -5, -2, 0, 2, 5, 10, 20, 100, 200, 300, 400, np.inf], dtype=float
+ )
+
+
+ BUCKET_SCALES = BUCKET_SCALES_VALUES / 100 + 1
+
+
+ def mann_kendall(series: Sequence[float], threshold: float = 0.5):
+     """
+     Computes the Mann-Kendall trend test.
+
+     :param series: series
+     :param threshold: 1.96 is the usual value, 0.5 means that a short timeseries
+         such as ``(0, 1, 2, 3, 4)`` has a significant trend
+     :return: trend (-1, 0, +1), test value
+
+     .. math::
+
+         S = \\sum_{i=1}^{n}\\sum_{j=i+1}^{n} sign(x_j - x_i)
+
+     where the function *sign* is:
+
+     .. math::
+
+         sign(x) = \\left\\{\\begin{array}{rl} -1 & if \\; x < 0 \\\\
+         0 & if \\; x = 0 \\\\ +1 & otherwise \\end{array}\\right.
+
+     And:
+
+     .. math::
+
+         Var(S) = \\frac{n(n-1)(2n+5) - \\sum_t t(t-1)(2t+5)}{18}
+     """
+     aseries = np.asarray(series)
+     stat = 0
+     n = len(aseries)
+     var = n * (n - 1) * (2 * n + 5)
+     for i in range(n - 1):
+         stat += np.sign(aseries[i + 1 :] - aseries[i]).sum()
+     var = var**0.5
+     # continuity correction: S -> S - sign(S)
+     test = (stat + (1 if stat < 0 else (0 if stat == 0 else -1))) / var
+     trend = np.sign(test) if np.abs(test) > threshold else 0
+     return trend, test
+
+
+ def breaking_last_point(series: Sequence[float], threshold: float = 1.2):
+     """
+     Assuming a timeseries is constant, checks that the last value
+     is not an outlier.
+
+     :param series: series
+     :param threshold: number of standard deviations beyond which
+         the last point is considered an outlier
+     :return: significant change (-1, 0, +1), test value
+     """
+     signal = np.asarray(series)
+     if not np.issubdtype(signal.dtype, np.number):
+         return 0, np.nan
+     assert len(signal.shape) == 1, f"Unexpected signal shape={signal.shape}, signal={signal}"
+     if signal.shape[0] <= 2:
+         return 0, 0
+
+     has_value = ~(np.isnan(signal).all()) and ~(np.isinf(signal).all())
+     if np.isnan(signal[-1]) or np.isinf(signal[-1]):
+         return (-1, np.inf) if has_value else (0, 0)
+
+     try:
+         m = np.mean(signal[:-1])
+     except (TypeError, ValueError):
+         # Not a numerical type
+         return 0, np.nan
+
+     if np.isnan(m) or np.isinf(m):
+         return (1, np.inf) if np.isinf(signal[-2]) or np.isnan(signal[-2]) else (0, 0)
+     v = np.std(signal[:-1])
+     if v == 0:
+         test = signal[-1] - m
+         assert not np.isnan(
+             test
+         ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+         trend = np.sign(test)
+         return trend, trend
+     test = (signal[-1] - m) / v
+     assert not np.isnan(
+         test
+     ), f"Unexpected test value, test={test}, signal={signal}, m={m}, v={v}"
+     trend = np.sign(test) if np.abs(test) > threshold else 0
+     return trend, test
+
+
+ def filter_data(
+     df: pandas.DataFrame,
+     filter_in: Optional[str] = None,
+     filter_out: Optional[str] = None,
+     verbose: int = 0,
+ ) -> pandas.DataFrame:
+     """
+     Filters rows in and out of a dataframe.
+
+     Arguments `filter_in` and `filter_out` follow the syntax
+     ``<column1>:<fmt1>//<column2>:<fmt2>``.
+
+     The format is the following:
+
+     * a value or a set of values separated by ``;``
+     """
+     if not filter_in and not filter_out:
+         return df
+
+     def _f(fmt):
+         cond = {}
+         if isinstance(fmt, str):
+             cols = fmt.split("//")
+             for c in cols:
+                 assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}, cols={cols!r}"
+                 spl = c.split(":")
+                 assert (
+                     len(spl) == 2
+                 ), f"Unexpected value {c!r} in fmt={fmt!r}, spl={spl}, cols={cols}"
+                 name, fil = spl
+                 cond[name] = set(fil.split(";"))
+         return cond
+
+     if filter_in:
+         cond = _f(filter_in)
+         assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_in!r}"
+         for k, v in cond.items():
+             if k not in df.columns:
+                 continue
+             if verbose:
+                 print(
+                     f"[_filter_data] filter in column {k!r}, "
+                     f"values {v!r} among {set(df[k].astype(str))}"
+                 )
+             df = df[df[k].astype(str).isin(v)]
+
+     if filter_out:
+         cond = _f(filter_out)
+         assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_out!r}"
+         for k, v in cond.items():
+             if k not in df.columns:
+                 continue
+             if verbose:
+                 print(
+                     f"[_filter_data] filter out column {k!r}, "
+                     f"values {v!r} among {set(df[k].astype(str))}"
+                 )
+             df = df[~df[k].astype(str).isin(v)]
+     return df
+
+
+ def enumerate_csv_files(
+     data: Union[
+         pandas.DataFrame, List[Union[str, Tuple[str, str]]], str, Tuple[str, str, str, str]
+     ],
+     verbose: int = 0,
+     filtering: Optional[Callable[[str], bool]] = None,
+ ) -> Iterator[Union[pandas.DataFrame, str, Tuple[str, str, str, str]]]:
+     """
+     Enumerates the files considered for the aggregation.
+     Only csv files are considered.
+     If a zip file is given, the function digs into the zip file and
+     loops over the csv candidates.
+
+     :param data: dataframe with the raw data or a file or list of files
+     :param verbose: verbosity
+     :param filtering: function to filter in or out files in zip files,
+         must return true to keep the file, false to skip it
+     :return: a generator yielding tuples with the filename, date, full path and zip file
+
+     data can contain:
+
+     * a dataframe
+     * a string for a filename, zip or csv
+     * a list of strings
+     * a tuple
+     """
+     if not isinstance(data, list):
+         data = [data]
+     for itn, filename in enumerate(data):
+         if isinstance(filename, pandas.DataFrame):
+             if verbose:
+                 print(f"[enumerate_csv_files] data[{itn}] is a dataframe")
+             yield filename
+             continue
+
+         if isinstance(filename, tuple):
+             # A file in a zipfile
+             if verbose:
+                 print(f"[enumerate_csv_files] data[{itn}] is {filename!r}")
+             yield filename
+             continue
+
+         if os.path.exists(filename):
+             ext = os.path.splitext(filename)[-1]
+             if ext == ".csv":
+                 if verbose:
+                     print(f"[enumerate_csv_files] data[{itn}] is a csv file: {filename!r}")
+                 dt = datetime.datetime.fromtimestamp(os.stat(filename).st_mtime)
+                 du = dt.strftime("%Y-%m-%d %H:%M:%S")
+                 yield (os.path.split(filename)[-1], du, filename, "")
+                 continue
+
+             if ext == ".zip":
+                 if verbose:
+                     print(f"[enumerate_csv_files] data[{itn}] is a zip file: {filename!r}")
+                 zf = zipfile.ZipFile(filename, "r")
+                 for ii, info in enumerate(zf.infolist()):
+                     name = info.filename
+                     if filtering is None:
+                         ext = os.path.splitext(name)[-1]
+                         if ext != ".csv":
+                             continue
+                     elif not filtering(name):
+                         continue
+                     if verbose:
+                         print(
+                             f"[enumerate_csv_files] data[{itn}][{ii}] is a csv file: {name!r}"
+                         )
+                     # We check the first line looks like a csv header.
+                     with zf.open(name) as zzf:
+                         first_line = zzf.readline()
+                     if b"," not in first_line:
+                         continue
+                     yield (
+                         os.path.split(name)[-1],
+                         "%04d-%02d-%02d %02d:%02d:%02d" % info.date_time,
+                         name,
+                         filename,
+                     )
+                 zf.close()
+                 continue
+
+             raise AssertionError(f"Unexpected format {filename!r}, cannot read it.")
+
+         # filename is a pattern.
+         found = glob.glob(filename)
+         if verbose and not found:
+             print(f"[enumerate_csv_files] unable to find file in {filename!r}")
+         for ii, f in enumerate(found):
+             if verbose:
+                 print(f"[enumerate_csv_files] data[{itn}][{ii}] {f!r} from {filename!r}")
+             yield from enumerate_csv_files(f, verbose=verbose, filtering=filtering)
+
+
+ def open_dataframe(
+     data: Union[str, Tuple[str, str, str, str], pandas.DataFrame],
+ ) -> pandas.DataFrame:
+     """
+     Opens a file yielded by function
+     :func:`onnx_diagnostic.helpers._log_helper.enumerate_csv_files`.
+
+     :param data: a dataframe, a filename, or a tuple indicating the file
+         is coming from a zip file
+     :return: a dataframe
+     """
+     if isinstance(data, pandas.DataFrame):
+         return data
+     if isinstance(data, str):
+         df = pandas.read_csv(data, low_memory=False)
+         df["RAWFILENAME"] = data
+         return df
+     if isinstance(data, tuple):
+         if not data[-1]:
+             df = pandas.read_csv(data[2], low_memory=False)
+             df["RAWFILENAME"] = data[2]
+             return df
+         zf = zipfile.ZipFile(data[-1])
+         with zf.open(data[2]) as f:
+             df = pandas.read_csv(f, low_memory=False)
+             df["RAWFILENAME"] = f"{data[-1]}/{data[2]}"
+         zf.close()
+         return df
+
+     raise ValueError(f"Unexpected value for data: {data!r}")
+
+
+ def align_dataframe_with(
+     df: pandas.DataFrame, baseline: pandas.DataFrame, fill_value: float = 0
+ ) -> Optional[pandas.DataFrame]:
+     """
+     Modifies the first dataframe *df* to get the exact same number of columns
+     and rows as *baseline*. They must share the same levels on both axes.
+     Empty cells are filled with *fill_value*. Only the numerical columns are
+     kept. The function returns None if the output is empty.
+     """
+     df = df.select_dtypes(include="number")
+     if df.shape[1] == 0:
+         return None
+     bool_cols = list(df.select_dtypes(include="bool").columns)
+     if bool_cols:
+         df[bool_cols] = df[bool_cols].astype(int)
+     assert (
+         df.columns.names == baseline.columns.names or df.index.names == baseline.index.names
+     ), (
+         f"Levels mismatch, expected index.names={baseline.index.names}, "
+         f"expected columns.names={baseline.columns.names}, "
+         f"got index.names={df.index.names}, "
+         f"got columns.names={df.columns.names}"
+     )
+     dtypes = set(df[c].dtype for c in df.columns)
+     assert all(np.issubdtype(dt, np.number) for dt in dtypes), (
+         f"All columns in the first dataframe are expected to share "
+         f"the same type or be at least numerical but got {dtypes}\n{df}"
+     )
+     common_index = df.index.intersection(baseline.index)
+     cp = pandas.DataFrame(float(fill_value), index=baseline.index, columns=baseline.columns)
+     for c in df.columns:
+         if c not in cp.columns or not np.issubdtype(df[c].dtype, np.number):
+             continue
+         cp.loc[common_index, c] = df.loc[common_index, c].astype(cp[c].dtype)
+     return cp
+
+
+ def apply_excel_style(
+     filename_or_writer: Any,
+     f_highlights: Optional[  # type: ignore[name-defined]
+         Dict[str, Callable[[Any], "CubeViewDef.HighLightKind"]]  # noqa: F821
+     ] = None,
+     time_mask_view: Optional[Dict[str, pandas.DataFrame]] = None,
+     verbose: int = 0,
+ ):
+     """
+     Applies styles on all sheets in a file unless the sheet is too big.
+
+     :param filename_or_writer: filename or writer, modified inplace
+     :param f_highlights: color functions to apply, one per sheet
+     :param time_mask_view: if specified, it contains dataframes with the same shape
+         and values in {-1, 0, +1} which indicate if a value is unexpectedly lower (-1)
+         or higher (+1), the background color is changed accordingly
+     :param verbose: shows a progress loop
+     """
+     from openpyxl import load_workbook
+     from openpyxl.styles import Alignment
+     from openpyxl.utils import get_column_letter
+     from openpyxl.styles import Font, PatternFill
+     from .log_helper import CubeViewDef
+
+     if isinstance(filename_or_writer, str):
+         workbook = load_workbook(filename_or_writer)
+         save = True
+     else:
+         workbook = filename_or_writer.book
+         save = False
+
+     mask_low = PatternFill(fgColor="AAAAF0", fill_type="solid")
+     mask_high = PatternFill(fgColor="F0AAAA", fill_type="solid")
+
+     left = Alignment(horizontal="left")
+     left_shrink = Alignment(horizontal="left", shrink_to_fit=True)
+     right = Alignment(horizontal="right")
+     font_colors = {
+         CubeViewDef.HighLightKind.GREEN: Font(color="00AA00"),
+         CubeViewDef.HighLightKind.RED: Font(color="FF0000"),
+     }
+     if verbose:
+         from tqdm import tqdm
+
+         sheet_names = tqdm(list(workbook.sheetnames))
+     else:
+         sheet_names = workbook.sheetnames
+     for name in sheet_names:
+         if time_mask_view and name in time_mask_view:
+             mask = time_mask_view[name]
+             with pandas.ExcelWriter(io.BytesIO(), engine="openpyxl") as mask_writer:
+                 mask.to_excel(mask_writer, sheet_name=name)
+                 sheet_mask = mask_writer.sheets[name]
+         else:
+             sheet_mask = None
+
+         f_highlight = f_highlights.get(name, None) if f_highlights else None
+         sheet = workbook[name]
+         n_rows = sheet.max_row
+         n_cols = sheet.max_column
+         if n_rows * n_cols > 2**16 or n_rows > 2**13:
+             # Too big.
+             continue
+         co: Dict[int, int] = {}
+         sizes: Dict[int, int] = {}
+         cols = set()
+         for i in range(1, n_rows + 1):
+             for j, cell in enumerate(sheet[i]):
+                 if j > n_cols:
+                     break
+                 cols.add(cell.column)
+                 if isinstance(cell.value, float):
+                     co[j] = co.get(j, 0) + 1
+                 elif isinstance(cell.value, str):
+                     sizes[cell.column] = max(sizes.get(cell.column, 0), len(cell.value))
+
+         for k, v in sizes.items():
+             c = get_column_letter(k)
+             sheet.column_dimensions[c].width = min(max(8, v), 30)
+         for k in cols:
+             if k not in sizes:
+                 c = get_column_letter(k)
+                 sheet.column_dimensions[c].width = 15
+
+         for i in range(1, n_rows + 1):
+             for j, cell in enumerate(sheet[i]):
+                 if j > n_cols:
+                     break
+                 if isinstance(cell.value, pandas.Timestamp):
+                     cell.alignment = right
+                     dt = cell.value.to_pydatetime()
+                     cell.value = dt
+                     cell.number_format = (
+                         "YYYY-MM-DD"
+                         if (
+                             dt.hour == 0
+                             and dt.minute == 0
+                             and dt.second == 0
+                             and dt.microsecond == 0
+                         )
+                         else "YYYY-MM-DD 00:00:00"
+                     )
+                 elif isinstance(cell.value, (float, int)):
+                     cell.alignment = right
+                     x = abs(cell.value)
+                     if int(x) == x:
+                         cell.number_format = "0"
+                     elif x > 5000:
+                         cell.number_format = "# ##0"
+                     elif x >= 500:
+                         cell.number_format = "0.0"
+                     elif x >= 50:
+                         cell.number_format = "0.00"
+                     elif x >= 5:
+                         cell.number_format = "0.000"
+                     elif x > 0.5:
+                         cell.number_format = "0.0000"
+                     elif x > 0.005:
+                         cell.number_format = "0.00000"
+                     else:
+                         cell.number_format = "0.000E+00"
+                     if f_highlight:
+                         h = f_highlight(cell.value)
+                         if h in font_colors:
+                             cell.font = font_colors[h]
+                 elif isinstance(cell.value, str) and len(cell.value) > 70:
+                     cell.alignment = left_shrink
+                 else:
+                     cell.alignment = left
+                     if f_highlight:
+                         h = f_highlight(cell.value)
+                         if h in font_colors:
+                             cell.font = font_colors[h]
+
+         if sheet_mask is not None:
+             for i in range(1, n_rows + 1):
+                 for j, (cell, cell_mask) in enumerate(zip(sheet[i], sheet_mask[i])):
+                     if j > n_cols:
+                         break
+                     if cell_mask.value not in (1, -1):
+                         continue
+                     cell.fill = mask_low if cell_mask.value < 0 else mask_high
+
+     if save:
+         workbook.save(filename_or_writer)
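A short usage sketch for the helpers in this module (synthetic data; the expected outputs follow from the code above, and `report.xlsx` is an arbitrary file name):

```python
# Usage sketch for _log_helper with synthetic data.
import pandas

from onnx_diagnostic.helpers._log_helper import (
    apply_excel_style,
    breaking_last_point,
    filter_data,
    mann_kendall,
)

# Short increasing series: S = 10, sqrt(n(n-1)(2n+5)) ~ 17.3, test ~ 0.52 > 0.5 -> trend +1.
print(mann_kendall([0, 1, 2, 3, 4]))

# The last point sits far above an otherwise stable series -> (+1, large test value).
print(breaking_last_point([10.0, 10.1, 9.9, 10.0, 25.0]))

# Keep only the rows whose column "exporter" equals "onnx".
df = pandas.DataFrame({"exporter": ["onnx", "torch", "onnx"], "t": [1.0, 2.0, 3.0]})
print(filter_data(df, filter_in="exporter:onnx"))

# Style an Excel export in place (requires openpyxl).
with pandas.ExcelWriter("report.xlsx", engine="openpyxl") as writer:
    df.to_excel(writer, sheet_name="raw")
    apply_excel_style(writer)
```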
onnx_diagnostic/helpers/args_helper.py
@@ -0,0 +1,132 @@
+ import subprocess
+ from argparse import ArgumentParser, Namespace
+ from typing import Dict, List, Optional, Tuple, Union
+
+
+ def check_cuda_availability():
+     """
+     Checks if CUDA is available. Uses torch if it is installed,
+     otherwise falls back to calling ``nvidia-smi``.
+     """
+     try:
+         import torch
+
+         return torch.cuda.device_count() > 0
+     except ImportError:
+         pass
+     try:
+         result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
+         return result.returncode == 0
+     except FileNotFoundError:
+         return False
+
+
+ def get_parsed_args(
+     name: str,
+     scenarios: Optional[Dict[str, str]] = None,
+     description: Optional[str] = None,
+     epilog: Optional[str] = None,
+     number: int = 10,
+     repeat: int = 10,
+     warmup: int = 5,
+     sleep: float = 0.1,
+     tries: int = 2,
+     expose: Optional[str] = None,
+     new_args: Optional[List[str]] = None,
+     **kwargs: Dict[str, Tuple[Union[int, str, float], str]],
+ ) -> Namespace:
+     """
+     Returns parsed arguments for examples in this package.
+
+     :param name: script name
+     :param scenarios: list of available scenarios
+     :param description: parser description
+     :param epilog: text at the end of the parser
+     :param number: default value for the number parameter
+     :param repeat: default value for the repeat parameter
+     :param warmup: default value for the warmup parameter
+     :param sleep: default value for the sleep parameter
+     :param tries: default value for the tries parameter
+     :param expose: if None or empty, exposes all the standard parameters,
+         otherwise a comma-separated list of the parameters to expose
+     :param new_args: args to consider or None to take `sys.args`
+     :param kwargs: additional parameters,
+         example: `n_trees=(10, "number of trees to train")`
+     :return: parser
+     """
+     if description is None:
+         description = f"Available options for {name}.py."
+     if epilog is None:
+         epilog = ""
+     parser = ArgumentParser(prog=name, description=description, epilog=epilog)
+     # An empty set means every standard parameter is exposed.
+     to_publish = set(expose.split(",")) if expose else set()
+     if scenarios is not None:
+         rows = ", ".join(f"{k}: {v}" for k, v in scenarios.items())
+         parser.add_argument("-s", "--scenario", help=f"Available scenarios: {rows}.")
+     if not to_publish or "number" in to_publish:
+         parser.add_argument(
+             "-n",
+             "--number",
+             help=f"number of executions to measure, default is {number}",
+             type=int,
+             default=number,
+         )
+     if not to_publish or "repeat" in to_publish:
+         parser.add_argument(
+             "-r",
+             "--repeat",
+             help=f"number of times to repeat the measure, default is {repeat}",
+             type=int,
+             default=repeat,
+         )
+     if not to_publish or "warmup" in to_publish:
+         parser.add_argument(
+             "-w",
+             "--warmup",
+             help=f"number of executions before measuring, default is {warmup}",
+             type=int,
+             default=warmup,
+         )
+     if not to_publish or "sleep" in to_publish:
+         parser.add_argument(
+             "-S",
+             "--sleep",
+             help=f"sleeping time between two configurations, default is {sleep}",
+             type=float,
+             default=sleep,
+         )
+     if not to_publish or "tries" in to_publish:
+         parser.add_argument(
+             "-t",
+             "--tries",
+             help=f"number of tries for each configuration, default is {tries}",
+             type=int,
+             default=tries,
+         )
+     for k, v in kwargs.items():
+         parser.add_argument(
+             f"--{k}",
+             help=f"{v[1]}, default is {v[0]}",
+             type=type(v[0]),
+             default=v[0],
+         )
+
+     res = parser.parse_args(args=new_args)
+     update: Dict[str, Union[int, float]] = {}
+     for k, v in res.__dict__.items():
+         if not isinstance(v, str):
+             # argparse already produced a typed value, converting a float
+             # default such as sleep=0.1 to int would truncate it
+             continue
+         try:
+             update[k] = int(v)
+             continue
+         except (ValueError, TypeError):
+             pass
+         try:
+             update[k] = float(v)
+             continue
+         except (ValueError, TypeError):
+             pass
+     if update:
+         res.__dict__.update(update)
+     return res
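Finally, a hedged sketch of how a benchmark script might call these helpers; the script name, scenarios, and option values are made up for illustration:

```python
# Hypothetical benchmark script built on args_helper; values are illustrative only.
from onnx_diagnostic.helpers.args_helper import check_cuda_availability, get_parsed_args

args = get_parsed_args(
    "dummy_bench",
    scenarios={"small": "tiny model", "large": "full model"},
    expose="number,repeat,warmup",             # hide --sleep and --tries
    n_trees=(10, "number of trees to train"),  # extra option declared through kwargs
    new_args=["--number", "5", "--n_trees", "20"],
)
print(args.scenario, args.number, args.repeat, args.warmup, args.n_trees)
print("CUDA available:", check_cuda_availability())
```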