holobench-1.3.6-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bencher/__init__.py +41 -0
- bencher/bench_cfg.py +462 -0
- bencher/bench_plot_server.py +100 -0
- bencher/bench_report.py +268 -0
- bencher/bench_runner.py +136 -0
- bencher/bencher.py +805 -0
- bencher/caching.py +51 -0
- bencher/example/__init__.py +0 -0
- bencher/example/benchmark_data.py +200 -0
- bencher/example/example_all.py +45 -0
- bencher/example/example_categorical.py +99 -0
- bencher/example/example_custom_sweep.py +59 -0
- bencher/example/example_docs.py +34 -0
- bencher/example/example_float3D.py +101 -0
- bencher/example/example_float_cat.py +98 -0
- bencher/example/example_floats.py +89 -0
- bencher/example/example_floats2D.py +93 -0
- bencher/example/example_holosweep.py +104 -0
- bencher/example/example_holosweep_objects.py +111 -0
- bencher/example/example_holosweep_tap.py +144 -0
- bencher/example/example_image.py +82 -0
- bencher/example/example_levels.py +181 -0
- bencher/example/example_pareto.py +53 -0
- bencher/example/example_sample_cache.py +85 -0
- bencher/example/example_sample_cache_context.py +116 -0
- bencher/example/example_simple.py +134 -0
- bencher/example/example_simple_bool.py +34 -0
- bencher/example/example_simple_cat.py +47 -0
- bencher/example/example_simple_float.py +38 -0
- bencher/example/example_strings.py +46 -0
- bencher/example/example_time_event.py +62 -0
- bencher/example/example_video.py +124 -0
- bencher/example/example_workflow.py +189 -0
- bencher/example/experimental/example_bokeh_plotly.py +38 -0
- bencher/example/experimental/example_hover_ex.py +45 -0
- bencher/example/experimental/example_hvplot_explorer.py +39 -0
- bencher/example/experimental/example_interactive.py +75 -0
- bencher/example/experimental/example_streamnd.py +49 -0
- bencher/example/experimental/example_streams.py +36 -0
- bencher/example/experimental/example_template.py +40 -0
- bencher/example/experimental/example_updates.py +84 -0
- bencher/example/experimental/example_vector.py +84 -0
- bencher/example/meta/example_meta.py +171 -0
- bencher/example/meta/example_meta_cat.py +25 -0
- bencher/example/meta/example_meta_float.py +23 -0
- bencher/example/meta/example_meta_levels.py +26 -0
- bencher/example/optuna/example_optuna.py +78 -0
- bencher/example/shelved/example_float2D_scatter.py +109 -0
- bencher/example/shelved/example_float3D_cone.py +96 -0
- bencher/example/shelved/example_kwargs.py +63 -0
- bencher/job.py +184 -0
- bencher/optuna_conversions.py +168 -0
- bencher/plotting/__init__.py +0 -0
- bencher/plotting/plot_filter.py +110 -0
- bencher/plotting/plt_cnt_cfg.py +74 -0
- bencher/results/__init__.py +0 -0
- bencher/results/bench_result.py +80 -0
- bencher/results/bench_result_base.py +405 -0
- bencher/results/float_formatter.py +44 -0
- bencher/results/holoview_result.py +592 -0
- bencher/results/optuna_result.py +354 -0
- bencher/results/panel_result.py +113 -0
- bencher/results/plotly_result.py +65 -0
- bencher/utils.py +148 -0
- bencher/variables/inputs.py +193 -0
- bencher/variables/parametrised_sweep.py +206 -0
- bencher/variables/results.py +176 -0
- bencher/variables/sweep_base.py +167 -0
- bencher/variables/time.py +74 -0
- bencher/video_writer.py +30 -0
- bencher/worker_job.py +40 -0
- holobench-1.3.6.dist-info/METADATA +85 -0
- holobench-1.3.6.dist-info/RECORD +74 -0
- holobench-1.3.6.dist-info/WHEEL +5 -0
bencher/results/optuna_result.py
ADDED
@@ -0,0 +1,354 @@
from __future__ import annotations
from typing import List
from copy import deepcopy

import numpy as np
import optuna
import panel as pn
from collections import defaultdict
from textwrap import wrap

import pandas as pd
import xarray as xr


from optuna.visualization import (
    plot_param_importances,
    plot_pareto_front,
)
from bencher.utils import hmap_canonical_input
from bencher.variables.time import TimeSnapshot, TimeEvent
from bencher.bench_cfg import BenchCfg
from bencher.plotting.plt_cnt_cfg import PltCntCfg


# from bencher.results.bench_result_base import BenchResultBase
from bencher.optuna_conversions import (
    sweep_var_to_optuna_dist,
    summarise_trial,
    param_importance,
    optuna_grid_search,
    summarise_optuna_study,
    sweep_var_to_suggest,
)


def convert_dataset_bool_dims_to_str(dataset: xr.Dataset) -> xr.Dataset:
    """Given a dataarray that contains boolean coordinates, convert them to strings so that holoviews loads the data properly

    Args:
        dataarray (xr.DataArray): dataarray with boolean coordinates

    Returns:
        xr.DataArray: dataarray with boolean coordinates converted to strings
    """
    bool_coords = {}
    for c in dataset.coords:
        if dataset.coords[c].dtype == bool:
            bool_coords[c] = [str(vals) for vals in dataset.coords[c].values]

    if len(bool_coords) > 0:
        return dataset.assign_coords(bool_coords)
    return dataset


class OptunaResult:
    def __init__(self, bench_cfg: BenchCfg) -> None:
        self.bench_cfg = bench_cfg
        # self.wrap_long_time_labels(bench_cfg)  # todo remove
        self.ds = xr.Dataset()
        self.object_index = []
        self.hmaps = defaultdict(dict)
        self.result_hmaps = bench_cfg.result_hmaps
        self.studies = []
        self.plt_cnt_cfg = PltCntCfg()
        self.plot_inputs = []

        # self.width=600/
        # self.height=600

        # bench_res.objects.append(rv)
        # bench_res.reference_index = len(bench_res.objects)

    def post_setup(self):
        self.plt_cnt_cfg = PltCntCfg.generate_plt_cnt_cfg(self.bench_cfg)
        self.bench_cfg = self.wrap_long_time_labels(self.bench_cfg)
        self.ds = convert_dataset_bool_dims_to_str(self.ds)

    def to_xarray(self) -> xr.Dataset:
        return self.ds

    def setup_object_index(self):
        self.object_index = []

    def to_pandas(self, reset_index=True) -> pd.DataFrame:
        """Get the xarray results as a pandas dataframe

        Returns:
            pd.DataFrame: The xarray results array as a pandas dataframe
        """
        ds = self.to_xarray().to_dataframe()
        if reset_index:
            return ds.reset_index()
        return ds

    def wrap_long_time_labels(self, bench_cfg):
        """Takes a benchCfg and wraps any index labels that are too long to be plotted easily

        Args:
            bench_cfg (BenchCfg):

        Returns:
            BenchCfg: updated config with wrapped labels
        """
        if bench_cfg.over_time:
            if self.ds.coords["over_time"].dtype == np.datetime64:
                # plotly catastrophically fails to plot anything with the default long string representation of time, so convert to a shorter time representation
                self.ds.coords["over_time"] = [
                    pd.to_datetime(t).strftime("%d-%m-%y %H-%M-%S")
                    for t in self.ds.coords["over_time"].values
                ]
            # wrap very long time event labels because otherwise the graphs are unreadable
            if bench_cfg.time_event is not None:
                self.ds.coords["over_time"] = [
                    "\n".join(wrap(t, 20)) for t in self.ds.coords["over_time"].values
                ]
        return bench_cfg

    def to_optuna_plots(self) -> List[pn.pane.panel]:
        """Create an optuna summary from the benchmark results

        Returns:
            List[pn.pane.panel]: A list of optuna plots summarising the benchmark process
        """

        return self.collect_optuna_plots()

    def to_optuna_from_sweep(self, bench, n_trials=30):
        optu = self.to_optuna_from_results(
            bench.worker, n_trials=n_trials, extra_results=bench.results
        )
        return summarise_optuna_study(optu)

    def to_optuna_from_results(
        self,
        worker,
        n_trials=100,
        extra_results: List[OptunaResult] = None,
        sampler=optuna.samplers.TPESampler(),
    ):
        directions = []
        for rv in self.bench_cfg.optuna_targets(True):
            directions.append(rv.direction)

        study = optuna.create_study(
            sampler=sampler, directions=directions, study_name=self.bench_cfg.title
        )

        # add already calculated results
        results_list = extra_results if extra_results is not None else [self]
        for res in results_list:
            if len(res.ds.sizes) > 0:
                study.add_trials(res.bench_results_to_optuna_trials(True))

        def wrapped(trial) -> tuple:
            kwargs = {}
            for iv in self.bench_cfg.input_vars:
                kwargs[iv.name] = sweep_var_to_suggest(iv, trial)
            result = worker(**kwargs)
            output = []
            for rv in self.bench_cfg.result_vars:
                output.append(result[rv.name])
            return tuple(output)

        study.optimize(wrapped, n_trials=n_trials)
        return study

    def bench_results_to_optuna_trials(self, include_meta: bool = True) -> optuna.Study:
        """Convert an xarray dataset to an optuna study so optuna can further optimise or plot the statespace

        Args:
            bench_cfg (BenchCfg): benchmark config to convert

        Returns:
            optuna.Study: optuna description of the study
        """
        if include_meta:
            df = self.to_pandas()
            all_vars = []
            for v in self.bench_cfg.all_vars:
                if type(v) != TimeEvent:
                    all_vars.append(v)

            print("All vars", all_vars)
        else:
            all_vars = self.bench_cfg.input_vars
            # df = self.ds.
            # if "repeat" in self.
            # if self.bench_cfg.repeats>1:
            #     df = self.bench_cfg.ds.mean("repeat").to_dataframe().reset_index()
            # else:
            df = self.to_pandas().reset_index()
        # df = self.bench_cfg.ds.mean("repeat").to_dataframe.reset_index()
        # self.bench_cfg.all_vars
        # del self.bench_cfg.meta_vars[1]

        trials = []
        distributions = {}
        for i in all_vars:
            distributions[i.name] = sweep_var_to_optuna_dist(i)

        for row in df.iterrows():
            params = {}
            values = []
            for i in all_vars:
                if type(i) == TimeSnapshot:
                    if type(row[1][i.name]) == np.datetime64:
                        params[i.name] = row[1][i.name].timestamp()
                else:
                    params[i.name] = row[1][i.name]

            for r in self.bench_cfg.optuna_targets():
                values.append(row[1][r])

            trials.append(
                optuna.trial.create_trial(
                    params=params,
                    distributions=distributions,
                    values=values,
                )
            )
        return trials

    def bench_result_to_study(self, include_meta: bool) -> optuna.Study:
        trials = self.bench_results_to_optuna_trials(include_meta)
        study = optuna_grid_search(self.bench_cfg)
        optuna.logging.set_verbosity(optuna.logging.CRITICAL)
        import warnings

        # /usr/local/lib/python3.10/dist-packages/optuna/samplers/_grid.py:224: UserWarning: over_time contains a value with the type of <class 'pandas._libs.tslibs.timestamps.Timestamp'>, which is not supported by `GridSampler`. Please make sure a value is `str`, `int`, `float`, `bool` or `None` for persistent storage.

        # this is not disabling the warning
        warnings.filterwarnings(action="ignore", category=UserWarning)
        # remove the optuna gridsearch warning as we are not using their gridsearch because it has inexplicably bad performance: a loop of 400 iterations starts at ~100ms per iteration and grows past 1000ms after roughly 250 iterations
        study.add_trials(trials)
        return study

    def get_best_trial_params(self, canonical=False):
        studies = self.bench_result_to_study(True)
        out = studies.best_trials[0].params
        if canonical:
            return hmap_canonical_input(out)
        return out

    def get_pareto_front_params(self):
        return [p.params for p in self.studies[0].trials]

    def collect_optuna_plots(self) -> List[pn.pane.panel]:
        """Use optuna to plot various summaries of the optimisation

        Args:
            study (optuna.Study): The study to plot
            bench_cfg (BenchCfg): Benchmark config with options used to generate the study

        Returns:
            List[pn.pane.Pane]: A list of plots
        """

        self.studies = [self.bench_result_to_study(True)]
        titles = ["# Analysis"]
        if self.bench_cfg.repeats > 1:
            self.studies.append(self.bench_result_to_study(False))
            titles = [
                "# Parameter Importance With Repeats",
                "# Parameter Importance Without Repeats",
            ]

        study_repeats_pane = pn.Row()
        for study, title in zip(self.studies, titles):
            study_pane = pn.Column()
            target_names = self.bench_cfg.optuna_targets()
            param_str = []

            study_pane.append(pn.pane.Markdown(title))

            if len(target_names) > 1:
                if len(target_names) <= 3:
                    study_pane.append(
                        plot_pareto_front(
                            study, target_names=target_names, include_dominated_trials=False
                        )
                    )
                else:
                    print("plotting pareto front of first 3 result variables")
                    study_pane.append(
                        plot_pareto_front(
                            study,
                            targets=lambda t: (t.values[0], t.values[1], t.values[2]),
                            target_names=target_names[:3],
                            include_dominated_trials=False,
                        )
                    )

                study_pane.append(param_importance(self.bench_cfg, study))
                param_str.append(
                    f"  Number of trials on the Pareto front: {len(study.best_trials)}"
                )
                for t in study.best_trials:
                    param_str.extend(summarise_trial(t, self.bench_cfg))

            else:
                # cols.append(plot_optimization_history(study))  # TODO, maybe more clever when this is plotted?

                # If there is only 1 parameter then there is no point in plotting relative importance. It is only worth plotting if there are multiple repeats of the same value so that you can compare the parameter to the repeat and get a sense of how much chance affects the results
                # if bench_cfg.repeats > 1 and len(bench_cfg.input_vars) > 1:  # old code, not sure if its right
                if len(self.bench_cfg.input_vars) > 1:
                    study_pane.append(plot_param_importances(study, target_name=target_names[0]))

                param_str.extend(summarise_trial(study.best_trial, self.bench_cfg))

            kwargs = {"height": 500, "scroll": True} if len(param_str) > 30 else {}

            param_str = "\n".join(param_str)
            study_pane.append(
                pn.Row(pn.pane.Markdown(f"## Best Parameters\n```text\n{param_str}"), **kwargs),
            )

            study_repeats_pane.append(study_pane)

        return study_repeats_pane

    # def extract_study_to_dataset(study: optuna.Study, bench_cfg: BenchCfg) -> BenchCfg:
    #     """Extract an optuna study into an xarray dataset for easy plotting

    #     Args:
    #         study (optuna.Study): The result of a gridsearch
    #         bench_cfg (BenchCfg): Options for the grid search

    #     Returns:
    #         BenchCfg: An updated config with the results included
    #     """
    #     for t in study.trials:
    #         for it, rv in enumerate(bench_cfg.result_vars):
    #             bench_cfg.ds[rv.name].loc[t.params] = t.values[it]
    #     return bench_cfg

    def deep(self) -> OptunaResult:  # pragma: no cover
        """Return a deep copy of these results"""
        return deepcopy(self)

    def set_plot_size(self, **kwargs) -> dict:
        if "width" not in kwargs:
            if self.bench_cfg.plot_size is not None:
                kwargs["width"] = self.bench_cfg.plot_size
            # specific width overrides general size
            if self.bench_cfg.plot_width is not None:
                kwargs["width"] = self.bench_cfg.plot_width

        if "height" not in kwargs:
            if self.bench_cfg.plot_size is not None:
                kwargs["height"] = self.bench_cfg.plot_size
            # specific height overrides general size
            if self.bench_cfg.plot_height is not None:
                kwargs["height"] = self.bench_cfg.plot_height
        return kwargs
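The only piece of this file that runs standalone is `convert_dataset_bool_dims_to_str`; the `OptunaResult` methods need a populated sweep. A minimal sketch, assuming this hunk is `bencher/results/optuna_result.py` (as the +354 entry in the file list suggests) and that `res` is a hypothetical result object produced by a completed sweep:

```python
import numpy as np
import xarray as xr

from bencher.results.optuna_result import convert_dataset_bool_dims_to_str

# Boolean coordinates are rewritten as strings so holoviews can index them.
ds = xr.Dataset(
    {"score": ("enabled", np.array([0.2, 0.9]))},
    coords={"enabled": [False, True]},
)
print(convert_dataset_bool_dims_to_str(ds).coords["enabled"].values)  # ['False' 'True']

# With a populated result object (res is hypothetical here), the sweep can be
# re-read as an optuna study and summarised using the methods defined above:
# study = res.bench_result_to_study(include_meta=True)
# best = res.get_best_trial_params()   # params of the first Pareto-optimal trial
# panes = res.to_optuna_plots()        # panel objects summarising the optimisation
```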
bencher/results/panel_result.py
ADDED
@@ -0,0 +1,113 @@
from typing import Optional, Any
from pathlib import Path
from functools import partial
import panel as pn
import xarray as xr
from param import Parameter
from bencher.results.bench_result_base import BenchResultBase, ReduceType
from bencher.variables.results import (
    ResultVideo,
    ResultContainer,
    ResultReference,
    PANEL_TYPES,
)


class PanelResult(BenchResultBase):
    def to_video(self, **kwargs):
        return self.map_plots(partial(self.to_video_multi, **kwargs))

    def to_video_multi(self, result_var: Parameter, **kwargs) -> Optional[pn.Column]:
        if isinstance(result_var, (ResultVideo, ResultContainer)):
            vid_p = []

            xr_dataset = self.to_hv_dataset(ReduceType.SQUEEZE)

            def to_video_da(da, **kwargs):
                if da is not None and Path(da).exists():
                    vid = pn.pane.Video(da, autoplay=True, **kwargs)
                    vid.loop = True
                    vid_p.append(vid)
                    return vid
                return pn.pane.Markdown(f"video does not exist {da}")

            plot_callback = partial(self.ds_to_container, container=partial(to_video_da, **kwargs))

            panes = self.to_panes_multi_panel(
                xr_dataset, result_var, plot_callback=plot_callback, target_dimension=0
            )

            def play_vid(_):  # pragma: no cover
                for r in vid_p:
                    r.paused = False
                    r.loop = False

            def pause_vid(_):  # pragma: no cover
                for r in vid_p:
                    r.paused = True

            def reset_vid(_):  # pragma: no cover
                for r in vid_p:
                    r.paused = False
                    r.time = 0

            def loop_vid(_):  # pragma: no cover
                for r in vid_p:
                    r.paused = False
                    r.time = 0
                    r.loop = True

            button_names = ["Play Videos", "Pause Videos", "Loop Videos", "Reset Videos"]
            buttom_cb = [play_vid, pause_vid, reset_vid, loop_vid]
            buttons = pn.Row()

            for name, cb in zip(button_names, buttom_cb):
                button = pn.widgets.Button(name=name)
                pn.bind(cb, button, watch=True)
                buttons.append(button)

            return pn.Column(buttons, panes)
        return None

    def zero_dim_da_to_val(self, da_ds: xr.DataArray | xr.Dataset) -> Any:
        # todo this is really horrible, need to improve
        dim = None
        if isinstance(da_ds, xr.Dataset):
            dim = list(da_ds.keys())[0]
            da = da_ds[dim]
        else:
            da = da_ds

        for k in da.coords.keys():
            dim = k
            break
        if dim is None:
            return da_ds.values.squeeze().item()
        return da.expand_dims(dim).values[0]

    def ds_to_container(
        self, dataset: xr.Dataset, result_var: Parameter, container, **kwargs
    ) -> Any:
        val = self.zero_dim_da_to_val(dataset[result_var.name])
        if isinstance(result_var, ResultReference):
            ref = self.object_index[val]
            val = ref.obj
            if ref.container is not None:
                return ref.container(val, **kwargs)
        if container is not None:
            return container(val, styles={"background": "white"}, **kwargs)
        return val

    def to_panes(
        self, result_var: Parameter = None, target_dimension: int = 0, container=None, **kwargs
    ) -> Optional[pn.pane.panel]:
        if container is None:
            container = pn.pane.panel
        return self.map_plot_panes(
            partial(self.ds_to_container, container=container),
            hv_dataset=self.to_hv_dataset(ReduceType.SQUEEZE),
            target_dimension=target_dimension,
            result_var=result_var,
            result_types=PANEL_TYPES,
            **kwargs,
        )
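The play/pause wiring in `to_video_multi` relies on `pn.bind(callback, button, watch=True)`, which registers the callback as a side effect of the button's value changing. A minimal standalone sketch of that pattern; the video URLs are placeholders, not part of the package:

```python
import panel as pn

# Placeholder sources; any Video pane objects would work here.
videos = [
    pn.pane.Video("https://example.com/clip_a.webm"),
    pn.pane.Video("https://example.com/clip_b.webm"),
]

def pause_all(_):  # the bound argument is the button's value and is ignored
    for v in videos:
        v.paused = True

pause_button = pn.widgets.Button(name="Pause Videos")
pn.bind(pause_all, pause_button, watch=True)  # run pause_all on every click

layout = pn.Column(pn.Row(pause_button), *videos)
# layout.show()  # serve the layout to try the button
```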
bencher/results/plotly_result.py
ADDED
@@ -0,0 +1,65 @@
import panel as pn
import plotly.graph_objs as go
from typing import Optional
import xarray as xr

from param import Parameter

from bencher.plotting.plot_filter import VarRange
from bencher.results.bench_result_base import BenchResultBase, ReduceType
from bencher.variables.results import ResultVar


class PlotlyResult(BenchResultBase):
    def to_volume(self, result_var: Parameter = None, **kwargs):
        return self.filter(
            self.to_volume_da,
            float_range=VarRange(3, 3),
            cat_range=VarRange(-1, 0),
            reduce=ReduceType.REDUCE,
            target_dimension=3,
            result_var=result_var,
            result_types=(ResultVar),
            **kwargs,
        )

    def to_volume_da(
        self, dataset: xr.Dataset, result_var: Parameter, width=600, height=600
    ) -> Optional[pn.pane.Plotly]:
        """Given a benchCfg generate a 3D volume plot
        Returns:
            pn.pane.Plotly: A 3D volume plot in a Plotly pane
        """
        x = self.bench_cfg.input_vars[0]
        y = self.bench_cfg.input_vars[1]
        z = self.bench_cfg.input_vars[2]
        opacity = 0.1
        meandf = dataset[result_var.name].to_dataframe().reset_index()
        data = [
            go.Volume(
                x=meandf[x.name],
                y=meandf[y.name],
                z=meandf[z.name],
                value=meandf[result_var.name],
                isomin=meandf[result_var.name].min(),
                isomax=meandf[result_var.name].max(),
                opacity=opacity,
                surface_count=20,
            )
        ]

        layout = go.Layout(
            title=f"{result_var.name} vs ({x.name} vs {y.name} vs {z.name})",
            width=width,
            height=height,
            margin=dict(t=50, b=50, r=50, l=50),
            scene=dict(
                xaxis_title=f"{x.name} [{x.units}]",
                yaxis_title=f"{y.name} [{y.units}]",
                zaxis_title=f"{z.name} [{z.units}]",
            ),
        )

        fig = dict(data=data, layout=layout)

        return pn.pane.Plotly(fig, name="volume_plotly")
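`to_volume_da` feeds flattened grid columns into `go.Volume`; the same shape of call works outside bencher. A small self-contained sketch with synthetic data (the grid size, opacity, and surface_count values are illustrative):

```python
import numpy as np
import plotly.graph_objs as go

# Build a small 3D grid and flatten it into the 1D columns go.Volume expects.
xs, ys, zs = np.meshgrid(
    np.linspace(0, 1, 8), np.linspace(0, 1, 8), np.linspace(0, 1, 8), indexing="ij"
)
vals = np.sin(3 * xs) * np.cos(3 * ys) * zs

fig = go.Figure(
    data=go.Volume(
        x=xs.ravel(), y=ys.ravel(), z=zs.ravel(),
        value=vals.ravel(),
        isomin=float(vals.min()), isomax=float(vals.max()),
        opacity=0.1,       # low opacity so inner isosurfaces stay visible
        surface_count=20,  # number of isosurfaces drawn
    ),
    layout=go.Layout(title="volume demo", width=600, height=600),
)
fig.write_html("volume_demo.html")  # or fig.show() in a notebook
```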
bencher/utils.py
ADDED
@@ -0,0 +1,148 @@
from collections import namedtuple
import xarray as xr
from sortedcontainers import SortedDict
import hashlib
import re
import math
from colorsys import hsv_to_rgb
from pathlib import Path
from uuid import uuid4
from functools import partial
from typing import Callable, Any


def hmap_canonical_input(dic: dict) -> tuple:
    """From a dictionary of kwargs, return a hashable representation (tuple) that is always the same for the same inputs regardless of the order of the input arguments. e.g., {x=1,y=2} -> (1,2) and {y=2,x=1} -> (1,2). This is used so that keyword arguments can be hashed and converted to the tuple keys that are used for holomaps

    Args:
        dic (dict): dictionary with keyword arguments and values in any order

    Returns:
        tuple: values of the dictionary always in the same order and hashable
    """

    function_input = SortedDict(dic)
    return tuple(function_input.values())


def make_namedtuple(class_name: str, **fields) -> namedtuple:
    """Convenience method for making a named tuple

    Args:
        class_name (str): name of the named tuple

    Returns:
        namedtuple: a named tuple with the fields as values
    """
    return namedtuple(class_name, fields)(*fields.values())


def get_nearest_coords(dataset: xr.Dataset, collapse_list=False, **kwargs) -> dict:
    """Given an xarray dataset and kwargs of key value pairs of coordinate values, return a dictionary of the nearest coordinate name value pair that was found in the dataset

    Args:
        ds (xr.Dataset): dataset

    Returns:
        dict: nearest coordinate name value pair that matches the input coordinate name value pairs.
    """

    selection = dataset.sel(method="nearest", **kwargs)
    cd = selection.coords.to_dataset().to_dict()["coords"]
    cd2 = {}
    for k, v in cd.items():
        cd2[k] = v["data"]
        if collapse_list and isinstance(cd2[k], list):
            cd2[k] = cd2[k][0]  # select the first item in the list
    return cd2


def get_nearest_coords1D(val: Any, coords) -> Any:
    if isinstance(val, (int, float)):
        return min(coords, key=lambda x_: abs(x_ - val))
    return val


def hash_sha1(var: any) -> str:
    """A hash function that avoids the PYTHONHASHSEED 'feature' which returns a different hash value each time the program is run"""
    return hashlib.sha1(str(var).encode("ASCII")).hexdigest()


def capitalise_words(message: str):
    """Given a string of lowercase words, capitalise them

    Args:
        message (str): lower case string

    Returns:
        _type_: capitalised string
    """
    capitalized_message = " ".join([word.capitalize() for word in message.split(" ")])
    return capitalized_message


def un_camel(camel: str) -> str:
    """Given a CamelCase or snake_case string, return a string of space-separated capitalised words

    Args:
        camel (str): camelcase string

    Returns:
        str: uncamelcased string
    """

    return capitalise_words(re.sub("([a-z])([A-Z])", r"\g<1> \g<2>", camel.replace("_", " ")))


def int_to_col(int_val, sat=0.5, val=0.95, alpha=-1) -> tuple[float, float, float]:
    """Uses the golden angle to generate colors programmatically with minimum overlap between colors.
    https://martin.ankerl.com/2009/12/09/how-to-create-random-colors-programmatically/

    Args:
        int_val (_type_): index of an object you want to color, this is mapped to hue in HSV
        sat (float, optional): saturation in HSV. Defaults to 0.5.
        val (float, optional): value in HSV. Defaults to 0.95.
        alpha (int, optional): transparency. If -1 then only RGB is returned, if 0 or greater, RGBA is returned. Defaults to -1.

    Returns:
        tuple[float, float, float] | tuple[float, float, float, float]: either RGB or RGBA vector
    """
    golden_ratio_conjugate = (1 + math.sqrt(5)) / 2
    rgb = hsv_to_rgb(int_val * golden_ratio_conjugate, sat, val)
    if alpha >= 0:
        return (*rgb, alpha)
    return rgb


def lerp(value, input_low: float, input_high: float, output_low: float, output_high: float):
    input_low = float(input_low)
    return output_low + ((float(value) - input_low) / (float(input_high) - input_low)) * (
        float(output_high) - output_low
    )


def color_tuple_to_css(color: tuple[float, float, float]) -> str:
    return f"rgb{(color[0] * 255, color[1] * 255, color[2] * 255)}"


def gen_path(filename, folder, suffix):
    path = Path(f"cachedir/{folder}/{filename}/")
    path.mkdir(parents=True, exist_ok=True)
    return f"{path.absolute().as_posix()}/{filename}_{uuid4()}{suffix}"


def gen_video_path(video_name: str, extension: str = ".webm") -> str:
    return gen_path(video_name, "vid", extension)


def gen_image_path(image_name: str, filetype=".png") -> str:
    return gen_path(image_name, "img", filetype)


def callable_name(any_callable: Callable[..., Any]) -> str:
    if isinstance(any_callable, partial):
        return any_callable.func.__name__
    try:
        return any_callable.__name__
    except AttributeError:
        return str(any_callable)
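A few of the helpers above, exercised directly; the printed values follow from the code shown in this hunk, and the import path matches the file listed as `bencher/utils.py`:

```python
from bencher.utils import hmap_canonical_input, int_to_col, lerp, un_camel

# Keyword order does not matter: SortedDict orders the values by key name.
print(hmap_canonical_input(dict(y=2, x=1)))  # (1, 2)
print(hmap_canonical_input(dict(x=1, y=2)))  # (1, 2)

# Golden-angle colours: consecutive indices give well-separated hues.
print(int_to_col(0))           # RGB tuple
print(int_to_col(1, alpha=1))  # RGBA tuple because alpha >= 0

# Linear interpolation between ranges.
print(lerp(5, 0, 10, 0, 100))  # 50.0

# CamelCase / snake_case to spaced, capitalised words.
print(un_camel("OptunaResult"))  # "Optuna Result"
print(un_camel("plot_width"))    # "Plot Width"
```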