holobench 1.41.0__py3-none-any.whl → 1.43.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bencher/__init__.py +20 -2
- bencher/bench_cfg.py +262 -54
- bencher/bench_report.py +2 -2
- bencher/bench_runner.py +96 -10
- bencher/bencher.py +421 -89
- bencher/class_enum.py +70 -7
- bencher/example/example_dataframe.py +2 -2
- bencher/example/example_levels.py +17 -173
- bencher/example/example_pareto.py +107 -31
- bencher/example/example_rerun2.py +1 -1
- bencher/example/example_simple_bool.py +2 -2
- bencher/example/example_simple_float2d.py +6 -1
- bencher/example/example_video.py +2 -0
- bencher/example/experimental/example_hvplot_explorer.py +2 -2
- bencher/example/inputs_0D/example_0_in_1_out.py +25 -15
- bencher/example/inputs_0D/example_0_in_2_out.py +12 -3
- bencher/example/inputs_0_float/example_0_cat_in_2_out.py +88 -0
- bencher/example/inputs_0_float/example_1_cat_in_2_out.py +98 -0
- bencher/example/inputs_0_float/example_2_cat_in_2_out.py +107 -0
- bencher/example/inputs_0_float/example_3_cat_in_2_out.py +111 -0
- bencher/example/inputs_1D/example1d_common.py +48 -12
- bencher/example/inputs_1D/example_0_float_1_cat.py +33 -0
- bencher/example/inputs_1D/example_1_cat_in_2_out_repeats.py +68 -0
- bencher/example/inputs_1D/example_1_float_2_cat_repeats.py +3 -0
- bencher/example/inputs_1D/example_1_int_in_1_out.py +98 -0
- bencher/example/inputs_1D/example_1_int_in_2_out.py +101 -0
- bencher/example/inputs_1D/example_1_int_in_2_out_repeats.py +99 -0
- bencher/example/inputs_1_float/example_1_float_0_cat_in_2_out.py +117 -0
- bencher/example/inputs_1_float/example_1_float_1_cat_in_2_out.py +124 -0
- bencher/example/inputs_1_float/example_1_float_2_cat_in_2_out.py +132 -0
- bencher/example/inputs_1_float/example_1_float_3_cat_in_2_out.py +140 -0
- bencher/example/inputs_2D/example_2_cat_in_4_out_repeats.py +104 -0
- bencher/example/inputs_2_float/example_2_float_0_cat_in_2_out.py +98 -0
- bencher/example/inputs_2_float/example_2_float_1_cat_in_2_out.py +112 -0
- bencher/example/inputs_2_float/example_2_float_2_cat_in_2_out.py +122 -0
- bencher/example/inputs_2_float/example_2_float_3_cat_in_2_out.py +138 -0
- bencher/example/inputs_3_float/example_3_float_0_cat_in_2_out.py +111 -0
- bencher/example/inputs_3_float/example_3_float_1_cat_in_2_out.py +117 -0
- bencher/example/inputs_3_float/example_3_float_2_cat_in_2_out.py +124 -0
- bencher/example/inputs_3_float/example_3_float_3_cat_in_2_out.py +129 -0
- bencher/example/meta/generate_examples.py +118 -7
- bencher/example/meta/generate_meta.py +88 -40
- bencher/job.py +174 -9
- bencher/plotting/plot_filter.py +52 -17
- bencher/results/bench_result.py +117 -25
- bencher/results/bench_result_base.py +117 -8
- bencher/results/dataset_result.py +6 -200
- bencher/results/explorer_result.py +23 -0
- bencher/results/{hvplot_result.py → histogram_result.py} +3 -18
- bencher/results/holoview_results/__init__.py +0 -0
- bencher/results/holoview_results/bar_result.py +79 -0
- bencher/results/holoview_results/curve_result.py +110 -0
- bencher/results/holoview_results/distribution_result/__init__.py +0 -0
- bencher/results/holoview_results/distribution_result/box_whisker_result.py +73 -0
- bencher/results/holoview_results/distribution_result/distribution_result.py +109 -0
- bencher/results/holoview_results/distribution_result/scatter_jitter_result.py +92 -0
- bencher/results/holoview_results/distribution_result/violin_result.py +70 -0
- bencher/results/holoview_results/heatmap_result.py +319 -0
- bencher/results/holoview_results/holoview_result.py +346 -0
- bencher/results/holoview_results/line_result.py +240 -0
- bencher/results/holoview_results/scatter_result.py +107 -0
- bencher/results/holoview_results/surface_result.py +158 -0
- bencher/results/holoview_results/table_result.py +14 -0
- bencher/results/holoview_results/tabulator_result.py +20 -0
- bencher/results/optuna_result.py +30 -115
- bencher/results/video_controls.py +38 -0
- bencher/results/video_result.py +39 -36
- bencher/results/video_summary.py +2 -2
- bencher/results/{plotly_result.py → volume_result.py} +29 -8
- bencher/utils.py +175 -26
- bencher/variables/inputs.py +122 -15
- bencher/video_writer.py +2 -1
- bencher/worker_job.py +31 -3
- {holobench-1.41.0.dist-info → holobench-1.43.0.dist-info}/METADATA +24 -24
- holobench-1.43.0.dist-info/RECORD +147 -0
- bencher/example/example_levels2.py +0 -37
- bencher/example/inputs_1D/example_1_in_1_out.py +0 -62
- bencher/example/inputs_1D/example_1_in_2_out.py +0 -63
- bencher/example/inputs_1D/example_1_in_2_out_repeats.py +0 -61
- bencher/results/holoview_result.py +0 -796
- bencher/results/panel_result.py +0 -41
- holobench-1.41.0.dist-info/RECORD +0 -114
- {holobench-1.41.0.dist-info → holobench-1.43.0.dist-info}/WHEEL +0 -0
- {holobench-1.41.0.dist-info → holobench-1.43.0.dist-info}/licenses/LICENSE +0 -0
bencher/plotting/plot_filter.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
from typing import Optional
|
3
3
|
from dataclasses import dataclass
|
4
4
|
from bencher.plotting.plt_cnt_cfg import PltCntCfg
|
5
|
+
import logging
|
5
6
|
import panel as pn
|
6
7
|
|
7
8
|
|
@@ -43,9 +44,20 @@ class VarRange:
|
|
43
44
|
|
44
45
|
return lower_match and upper_match
|
45
46
|
|
46
|
-
def matches_info(self, val: int, name: str) -> tuple[bool, str]:
    """Test ``val`` against this range and produce a readable summary of the outcome.

    Args:
        val (int): The count to test against the range bounds.
        name (str): Label for the value; it leads the summary string.

    Returns:
        tuple[bool, str]: ``(matched, summary)`` where ``matched`` is the result of
        ``self.matches(val)`` and ``summary`` is a tab-separated line showing the
        label, the bounds, the value and the match result.
    """
    is_match = self.matches(val)
    summary = f"{name}\t{self.lower_bound}>= {val} <={self.upper_bound} is {is_match}"
    return is_match, summary
|
50
62
|
|
51
63
|
def __str__(self) -> str:
|
@@ -65,13 +77,21 @@ class PlotFilter:
|
|
65
77
|
input_range: VarRange = VarRange(1, None)
|
66
78
|
|
67
79
|
def matches_result(
    self, plt_cnt_cfg: PltCntCfg, plot_name: str, override: bool
) -> PlotMatchesResult:
    """Evaluate whether a result's data signature fits what this plot can display.

    Args:
        plt_cnt_cfg (PltCntCfg): Counts of the different kinds of variables in the result.
        plot_name (str): Name of the plot being evaluated.
        override (bool): Forwarded to ``PlotMatchesResult``, where a truthy value
            forces the overall match to True.

    Returns:
        PlotMatchesResult: The per-property match results for this filter.
    """
    return PlotMatchesResult(
        plot_filter=self,
        plt_cnt_cfg=plt_cnt_cfg,
        plot_name=plot_name,
        override=override,
    )
|
72
93
|
|
73
94
|
|
74
|
-
# @dataclass
|
75
95
|
class PlotMatchesResult:
|
76
96
|
"""Stores information about which properties match the requirements of a particular plotter"""
|
77
97
|
|
@@ -80,12 +100,20 @@ class PlotMatchesResult:
|
|
80
100
|
plot_filter: PlotFilter,
|
81
101
|
plt_cnt_cfg: PltCntCfg,
|
82
102
|
plot_name: str,
|
83
|
-
override: bool
|
84
|
-
):
|
85
|
-
|
86
|
-
matches = []
|
103
|
+
override: bool,
|
104
|
+
) -> None:
|
105
|
+
"""Initialize a PlotMatchesResult with filter matching information.
|
87
106
|
|
88
|
-
|
107
|
+
Args:
|
108
|
+
plot_filter (PlotFilter): The filter defining acceptable ranges for plot properties
|
109
|
+
plt_cnt_cfg (PltCntCfg): Configuration containing counts of different plot elements
|
110
|
+
plot_name (str): Name of the plot being checked
|
111
|
+
override (bool): Whether to override filter matching rules
|
112
|
+
"""
|
113
|
+
match_info: list[str] = []
|
114
|
+
matches: list[bool] = []
|
115
|
+
|
116
|
+
match_candidates: list[tuple[VarRange, int, str]] = [
|
89
117
|
(plot_filter.float_range, plt_cnt_cfg.float_cnt, "float"),
|
90
118
|
(plot_filter.cat_range, plt_cnt_cfg.cat_cnt, "cat"),
|
91
119
|
(plot_filter.vector_len, plt_cnt_cfg.vector_len, "vec"),
|
@@ -99,7 +127,7 @@ class PlotMatchesResult:
|
|
99
127
|
match, info = m.matches_info(cnt, name)
|
100
128
|
matches.append(match)
|
101
129
|
if not match:
|
102
|
-
match_info.append(info)
|
130
|
+
match_info.append(f"\t{info}")
|
103
131
|
if override:
|
104
132
|
match_info.append(f"override: {override}")
|
105
133
|
self.overall = True
|
@@ -107,15 +135,22 @@ class PlotMatchesResult:
|
|
107
135
|
self.overall = all(matches)
|
108
136
|
|
109
137
|
match_info.insert(0, f"plot {plot_name} matches: {self.overall}")
|
110
|
-
self.matches_info = "\n".join(match_info).strip()
|
111
|
-
self.plt_cnt_cfg = plt_cnt_cfg
|
138
|
+
self.matches_info: str = "\n".join(match_info).strip()
|
139
|
+
self.plt_cnt_cfg: PltCntCfg = plt_cnt_cfg
|
112
140
|
|
113
|
-
if self.plt_cnt_cfg.print_debug:
|
114
|
-
|
115
|
-
if not self.overall:
|
116
|
-
print(self.matches_info)
|
141
|
+
# if self.plt_cnt_cfg.print_debug:
|
142
|
+
logging.info(self.matches_info)
|
117
143
|
|
118
144
|
def to_panel(self, **kwargs) -> Optional[pn.pane.Markdown]:
    """Render the match report as a Markdown pane, but only in debug mode.

    Args:
        **kwargs: Forwarded unchanged to ``pn.pane.Markdown``.

    Returns:
        Optional[pn.pane.Markdown]: A pane displaying ``self.matches_info`` when
        ``self.plt_cnt_cfg.print_debug`` is truthy; otherwise None.
    """
    if not self.plt_cnt_cfg.print_debug:
        return None
    return pn.pane.Markdown(self.matches_info, **kwargs)
|
bencher/results/bench_result.py
CHANGED
@@ -1,49 +1,129 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import List
|
2
|
+
from typing import List, Optional, Any
|
3
3
|
import panel as pn
|
4
|
+
from param import Parameter
|
4
5
|
|
5
6
|
from bencher.results.bench_result_base import EmptyContainer
|
6
7
|
from bencher.results.video_summary import VideoSummaryResult
|
7
|
-
from bencher.results.
|
8
|
-
from bencher.results.
|
9
|
-
from bencher.results.holoview_result import HoloviewResult
|
10
|
-
|
8
|
+
from bencher.results.video_result import VideoResult
|
9
|
+
from bencher.results.volume_result import VolumeResult
|
10
|
+
from bencher.results.holoview_results.holoview_result import HoloviewResult
|
11
|
+
|
12
|
+
# Updated imports for distribution result classes
|
13
|
+
from bencher.results.holoview_results.distribution_result.box_whisker_result import BoxWhiskerResult
|
14
|
+
from bencher.results.holoview_results.distribution_result.violin_result import ViolinResult
|
15
|
+
from bencher.results.holoview_results.scatter_result import ScatterResult
|
16
|
+
from bencher.results.holoview_results.distribution_result.scatter_jitter_result import (
|
17
|
+
ScatterJitterResult,
|
18
|
+
)
|
19
|
+
from bencher.results.holoview_results.bar_result import BarResult
|
20
|
+
from bencher.results.holoview_results.line_result import LineResult
|
21
|
+
from bencher.results.holoview_results.curve_result import CurveResult
|
22
|
+
from bencher.results.holoview_results.heatmap_result import HeatmapResult
|
23
|
+
from bencher.results.holoview_results.surface_result import SurfaceResult
|
24
|
+
from bencher.results.histogram_result import HistogramResult
|
25
|
+
from bencher.results.optuna_result import OptunaResult
|
11
26
|
from bencher.results.dataset_result import DataSetResult
|
12
27
|
from bencher.utils import listify
|
13
28
|
|
14
29
|
|
15
|
-
class BenchResult(
|
30
|
+
class BenchResult(
|
31
|
+
VolumeResult,
|
32
|
+
BoxWhiskerResult,
|
33
|
+
ViolinResult,
|
34
|
+
ScatterJitterResult,
|
35
|
+
ScatterResult,
|
36
|
+
LineResult,
|
37
|
+
BarResult,
|
38
|
+
HeatmapResult,
|
39
|
+
CurveResult,
|
40
|
+
SurfaceResult,
|
41
|
+
HoloviewResult,
|
42
|
+
HistogramResult,
|
43
|
+
VideoSummaryResult,
|
44
|
+
DataSetResult,
|
45
|
+
OptunaResult,
|
46
|
+
): # noqa pylint: disable=too-many-ancestors
|
16
47
|
"""Contains the results of the benchmark and has methods to cast the results to various datatypes and graphical representations"""
|
17
48
|
|
18
49
|
def __init__(self, bench_cfg) -> None:
    """Build a BenchResult from a benchmark configuration.

    Only the ``VolumeResult`` and ``HoloviewResult`` initialisers are invoked
    explicitly, in that order; ``DataSetResult.__init__`` is intentionally not called.

    Args:
        bench_cfg: The benchmark configuration holding settings and result data.
    """
    for base_cls in (VolumeResult, HoloviewResult):
        base_cls.__init__(self, bench_cfg)
|
22
58
|
|
59
|
+
@classmethod
def from_existing(cls, original: BenchResult) -> BenchResult:
    """Create a new instance of ``cls`` that shares the data of ``original``.

    The new instance is constructed from ``original.bench_cfg``. It then receives the
    dataset, plot-count configuration and dataset list of ``original``, shared by
    reference and not copied. This is the same state that ``to()`` carries across
    when converting between result types.

    Args:
        original (BenchResult): The result whose data should be reused.

    Returns:
        BenchResult: A new instance of ``cls`` backed by the same data as ``original``.
    """
    new_instance = cls(original.bench_cfg)
    new_instance.ds = original.ds
    # Redundant with the constructor argument, but keeps the shared config explicit.
    new_instance.bench_cfg = original.bench_cfg
    new_instance.plt_cnt_cfg = original.plt_cnt_cfg
    # `to()` also transfers dataset_list; without it the new instance would start empty.
    new_instance.dataset_list = original.dataset_list
    return new_instance
|
66
|
+
|
67
|
+
def to(
    self,
    result_type: type[BenchResult],
    result_var: Optional[Parameter] = None,
    override: bool = True,
    **kwargs: Any,
) -> Any:
    """Convert this result to another result type and plot it with that type's plotter.

    A fresh ``result_type`` instance is constructed from ``self.bench_cfg``. It is
    given this result's dataset, plot-count configuration and dataset list, shared by
    reference and not copied. Its ``to_plot`` method is then called.

    Args:
        result_type (type[BenchResult]): The result class to convert to (e.g. ``BarResult``).
        result_var (Optional[Parameter]): Result variable to plot; forwarded to ``to_plot``.
        override (bool): Forwarded to ``to_plot``. Defaults to True here, whereas
            ``to_auto`` defaults to False to hide unsupported plots.
        **kwargs: Additional keyword arguments forwarded to ``to_plot``.

    Returns:
        Any: Whatever ``result_type.to_plot`` returns for the converted instance.
    """
    result_instance = result_type(self.bench_cfg)
    result_instance.ds = self.ds
    result_instance.plt_cnt_cfg = self.plt_cnt_cfg
    result_instance.dataset_list = self.dataset_list
    return result_instance.to_plot(result_var=result_var, override=override, **kwargs)
|
84
|
+
|
23
85
|
@staticmethod
def default_plot_callbacks() -> List[callable]:
    """Return the plotting callbacks that ``to_auto`` uses when no plot list is given.

    Disabled on purpose:
      - ``VideoSummaryResult.to_video_summary``: too expensive to enable by default.
      - ``ViolinResult`` and ``ScatterJitterResult`` plots.
      - ``PanelResult.to_video``.

    Returns:
        List[callable]: The callbacks, in the order they are applied.
    """
    holoview_plots = [
        BarResult.to_plot,
        BoxWhiskerResult.to_plot,
        CurveResult.to_plot,
        LineResult.to_plot,
        HeatmapResult.to_plot,
    ]
    other_plots = [
        HistogramResult.to_plot,
        VolumeResult.to_plot,
        VideoResult.to_panes,
    ]
    return holoview_plots + other_plots
|
37
109
|
|
38
110
|
@staticmethod
def plotly_callbacks() -> List[callable]:
    """Return the Plotly-based plotting callbacks.

    Returns:
        List[callable]: The surface callback followed by the volume callback.
    """
    surface_cb = SurfaceResult.to_surface
    volume_cb = VolumeResult.to_volume
    return [surface_cb, volume_cb]
|
41
118
|
|
42
119
|
def plot(self) -> pn.panel:
|
43
|
-
"""Plots the benchresult using the plot callbacks defined by the bench run
|
120
|
+
"""Plots the benchresult using the plot callbacks defined by the bench run.
|
121
|
+
|
122
|
+
This method uses the plot_callbacks defined in the bench_cfg to generate
|
123
|
+
plots for the benchmark results.
|
44
124
|
|
45
125
|
Returns:
|
46
|
-
pn.panel: A panel representation of the results
|
126
|
+
pn.panel: A panel representation of the results, or None if no plot_callbacks defined
|
47
127
|
"""
|
48
128
|
if self.bench_cfg.plot_callbacks is not None:
|
49
129
|
return pn.Column(*[cb(self) for cb in self.bench_cfg.plot_callbacks])
|
@@ -54,8 +134,21 @@ class BenchResult(PlotlyResult, HoloviewResult, HvplotResult, VideoSummaryResult
|
|
54
134
|
plot_list: List[callable] = None,
|
55
135
|
remove_plots: List[callable] = None,
|
56
136
|
default_container=pn.Column,
|
137
|
+
override: bool = False, # false so that plots that are not supported are not shown
|
57
138
|
**kwargs,
|
58
139
|
) -> List[pn.panel]:
|
140
|
+
"""Automatically generate plots based on the provided plot callbacks.
|
141
|
+
|
142
|
+
Args:
|
143
|
+
plot_list (List[callable], optional): List of plot callback functions to use. Defaults to None.
|
144
|
+
remove_plots (List[callable], optional): List of plot callback functions to exclude. Defaults to None.
|
145
|
+
default_container (type, optional): Default container type for the plots. Defaults to pn.Column.
|
146
|
+
override (bool, optional): Whether to override unsupported plots. Defaults to False.
|
147
|
+
**kwargs: Additional keyword arguments for plot configuration.
|
148
|
+
|
149
|
+
Returns:
|
150
|
+
List[pn.panel]: A list of panel objects containing the generated plots.
|
151
|
+
"""
|
59
152
|
self.plt_cnt_cfg.print_debug = False
|
60
153
|
plot_list = listify(plot_list)
|
61
154
|
remove_plots = listify(remove_plots)
|
@@ -73,7 +166,7 @@ class BenchResult(PlotlyResult, HoloviewResult, HvplotResult, VideoSummaryResult
|
|
73
166
|
if self.plt_cnt_cfg.print_debug:
|
74
167
|
print(f"checking: {plot_callback.__name__}")
|
75
168
|
# the callbacks are passed from the static class definition, so self needs to be passed before the plotting callback can be called
|
76
|
-
row.append(plot_callback(self, **kwargs))
|
169
|
+
row.append(plot_callback(self, override=override, **kwargs))
|
77
170
|
|
78
171
|
self.plt_cnt_cfg.print_debug = True
|
79
172
|
if len(row.pane) == 0:
|
@@ -83,15 +176,14 @@ class BenchResult(PlotlyResult, HoloviewResult, HvplotResult, VideoSummaryResult
|
|
83
176
|
return row.pane
|
84
177
|
|
85
178
|
def to_auto_plots(self, **kwargs) -> pn.panel:
|
86
|
-
"""Given the dataset result of a benchmark run, automatically
|
179
|
+
"""Given the dataset result of a benchmark run, automatically deduce how to plot the data based on the types of variables that were sampled.
|
87
180
|
|
88
181
|
Args:
|
89
|
-
|
182
|
+
**kwargs: Additional keyword arguments for plot configuration.
|
90
183
|
|
91
184
|
Returns:
|
92
|
-
pn.
|
185
|
+
pn.panel: A panel containing plot results.
|
93
186
|
"""
|
94
|
-
|
95
187
|
plot_cols = pn.Column()
|
96
188
|
plot_cols.append(self.to_sweep_summary(name="Plots View"))
|
97
189
|
plot_cols.append(self.to_auto(**kwargs))
|
@@ -6,6 +6,8 @@ from param import Parameter
|
|
6
6
|
import holoviews as hv
|
7
7
|
from functools import partial
|
8
8
|
import panel as pn
|
9
|
+
import numpy as np
|
10
|
+
from textwrap import wrap
|
9
11
|
|
10
12
|
from bencher.utils import int_to_col, color_tuple_to_css, callable_name
|
11
13
|
|
@@ -14,7 +16,6 @@ from bencher.variables.inputs import with_level
|
|
14
16
|
|
15
17
|
from bencher.variables.results import OptDir
|
16
18
|
from copy import deepcopy
|
17
|
-
from bencher.results.optuna_result import OptunaResult
|
18
19
|
from bencher.variables.results import ResultVar
|
19
20
|
from bencher.plotting.plot_filter import VarRange, PlotFilter
|
20
21
|
from bencher.utils import listify
|
@@ -25,6 +26,13 @@ from bencher.results.composable_container.composable_container_panel import (
|
|
25
26
|
ComposableContainerPanel,
|
26
27
|
)
|
27
28
|
|
29
|
+
from collections import defaultdict
|
30
|
+
|
31
|
+
import pandas as pd
|
32
|
+
|
33
|
+
from bencher.bench_cfg import BenchCfg
|
34
|
+
from bencher.plotting.plt_cnt_cfg import PltCntCfg
|
35
|
+
|
28
36
|
# todo add plugins
|
29
37
|
# https://gist.github.com/dorneanu/cce1cd6711969d581873a88e0257e312
|
30
38
|
# https://kaleidoescape.github.io/decorated-plugins/
|
@@ -52,7 +60,87 @@ class EmptyContainer:
|
|
52
60
|
return self.pane if len(self.pane) > 0 else None
|
53
61
|
|
54
62
|
|
55
|
-
|
63
|
+
def convert_dataset_bool_dims_to_str(dataset: xr.Dataset) -> xr.Dataset:
    """Convert any boolean coordinates of a dataset to strings.

    HoloViews does not load boolean coordinates properly, so each boolean
    coordinate's values are replaced with their ``str`` form ("True"/"False").

    Args:
        dataset (xr.Dataset): Dataset whose coordinates may include booleans.

    Returns:
        xr.Dataset: The result of ``dataset.assign_coords`` with the boolean
        coordinates stringified. If the dataset has no boolean coordinates, the
        input object itself is returned.
    """
    bool_coords = {
        name: [str(value) for value in dataset.coords[name].values]
        for name in dataset.coords
        if dataset.coords[name].dtype == bool
    }
    if bool_coords:
        return dataset.assign_coords(bool_coords)
    return dataset
|
80
|
+
|
81
|
+
|
82
|
+
class BenchResultBase:
|
83
|
+
def __init__(self, bench_cfg: BenchCfg) -> None:
    """Set up empty result containers for the given benchmark configuration.

    Args:
        bench_cfg (BenchCfg): The benchmark configuration. It is stored on the
            instance, and its ``result_hmaps`` attribute is referenced directly.
    """
    self.bench_cfg = bench_cfg
    # Empty dataset; filled in after the benchmark runs.
    self.ds = xr.Dataset()
    self.object_index = []
    self.hmaps = defaultdict(dict)
    self.result_hmaps = bench_cfg.result_hmaps
    self.studies = []
    # Default plot counts; post_setup() regenerates these from bench_cfg.
    self.plt_cnt_cfg = PltCntCfg()
    self.plot_inputs = []
    self.dataset_list = []
|
100
|
+
|
101
|
+
def to_xarray(self) -> xr.Dataset:
    """Return the underlying results dataset (``self.ds``) itself, without copying."""
    dataset = self.ds
    return dataset
|
103
|
+
|
104
|
+
def setup_object_index(self):
    """Reset ``self.object_index`` to a new, empty list."""
    self.object_index = list()
|
106
|
+
|
107
|
+
def to_pandas(self, reset_index=True) -> pd.DataFrame:
    """Return the results dataset as a pandas DataFrame.

    Args:
        reset_index (bool, optional): When truthy (the default), the frame's index
            is moved into ordinary columns with ``DataFrame.reset_index()``.

    Returns:
        pd.DataFrame: ``self.to_xarray().to_dataframe()``, with its index reset
        if requested.
    """
    frame = self.to_xarray().to_dataframe()
    if not reset_index:
        return frame
    return frame.reset_index()
|
115
|
+
|
116
|
+
def wrap_long_time_labels(self, bench_cfg):
    """Shorten and wrap the ``over_time`` coordinate labels of ``self.ds`` for plotting.

    Nothing happens unless ``bench_cfg.over_time`` is truthy. Otherwise:

    * Datetime ``over_time`` values are reformatted as ``"%d-%m-%y %H-%M-%S"``
      strings, because plotly fails to plot the default long time representation.
    * If ``bench_cfg.time_event`` is not None, each label is wrapped onto lines of
      at most 20 characters so that long event labels stay readable.

    ``self.ds`` is modified in place; ``bench_cfg`` is not changed.

    Args:
        bench_cfg (BenchCfg): Configuration providing ``over_time`` and ``time_event``.

    Returns:
        BenchCfg: The same ``bench_cfg`` object that was passed in.
    """
    if not bench_cfg.over_time:
        return bench_cfg

    over_time = self.ds.coords["over_time"]
    # `dtype == np.datetime64` compares against the unit-less generic datetime dtype and
    # is False for unit-qualified dtypes such as datetime64[ns], so it never matched.
    # np.issubdtype matches datetime64 of any unit.
    if np.issubdtype(over_time.dtype, np.datetime64):
        self.ds.coords["over_time"] = [
            pd.to_datetime(t).strftime("%d-%m-%y %H-%M-%S") for t in over_time.values
        ]

    if bench_cfg.time_event is not None:
        # Re-read the coordinate: it may have just been replaced above.
        self.ds.coords["over_time"] = [
            "\n".join(wrap(t, 20)) for t in self.ds.coords["over_time"].values
        ]
    return bench_cfg
|
138
|
+
|
139
|
+
def post_setup(self):
    """Finalise the result once data is loaded.

    Regenerates the plot counts from the config, shortens and wraps time labels,
    and converts boolean coordinates of ``self.ds`` to strings.
    """
    cfg = self.bench_cfg
    self.plt_cnt_cfg = PltCntCfg.generate_plt_cnt_cfg(cfg)
    self.bench_cfg = self.wrap_long_time_labels(cfg)
    self.ds = convert_dataset_bool_dims_to_str(self.ds)
|
143
|
+
|
56
144
|
def result_samples(self) -> int:
|
57
145
|
"""The number of samples in the results dataframe"""
|
58
146
|
return self.ds.count()
|
@@ -80,7 +168,7 @@ class BenchResultBase(OptunaResult):
|
|
80
168
|
def to_dataset(
|
81
169
|
self,
|
82
170
|
reduce: ReduceType = ReduceType.AUTO,
|
83
|
-
result_var: ResultVar = None,
|
171
|
+
result_var: ResultVar | str = None,
|
84
172
|
level: int = None,
|
85
173
|
) -> xr.Dataset:
|
86
174
|
"""Generate a summarised xarray dataset.
|
@@ -97,7 +185,15 @@ class BenchResultBase(OptunaResult):
|
|
97
185
|
ds_out = self.ds.copy()
|
98
186
|
|
99
187
|
if result_var is not None:
|
100
|
-
|
188
|
+
if isinstance(result_var, Parameter):
|
189
|
+
var_name = result_var.name
|
190
|
+
elif isinstance(result_var, str):
|
191
|
+
var_name = result_var
|
192
|
+
else:
|
193
|
+
raise TypeError(
|
194
|
+
f"Unsupported type for result_var: {type(result_var)}. Expected Parameter or str."
|
195
|
+
)
|
196
|
+
ds_out = ds_out[var_name].to_dataset(name=var_name)
|
101
197
|
|
102
198
|
def rename_ds(dataset: xr.Dataset, suffix: str):
|
103
199
|
# var_name =
|
@@ -219,9 +315,6 @@ class BenchResultBase(OptunaResult):
|
|
219
315
|
def describe_sweep(self):
    """Return the sweep description produced by ``self.bench_cfg.describe_sweep()``."""
    cfg = self.bench_cfg
    return cfg.describe_sweep()
|
221
317
|
|
222
|
-
def get_best_holomap(self, name: str = None):
|
223
|
-
return self.get_hmap(name)[self.get_best_trial_params(True)]
|
224
|
-
|
225
318
|
def get_hmap(self, name: str = None):
|
226
319
|
try:
|
227
320
|
if name is None:
|
@@ -362,7 +455,7 @@ class BenchResultBase(OptunaResult):
|
|
362
455
|
pane_collection: pn.pane = None,
|
363
456
|
override=False,
|
364
457
|
**kwargs,
|
365
|
-
):
|
458
|
+
) -> Optional[pn.panel]:
|
366
459
|
plot_filter = PlotFilter(
|
367
460
|
float_range=float_range,
|
368
461
|
cat_range=cat_range,
|
@@ -558,3 +651,19 @@ class BenchResultBase(OptunaResult):
|
|
558
651
|
|
559
652
|
def to_description(self, width: int = 800) -> pn.pane.Markdown:
    """Return the benchmark description built by ``bench_cfg.to_description``.

    Args:
        width (int, optional): Passed positionally to ``bench_cfg.to_description``.
            Defaults to 800.
    """
    describe = self.bench_cfg.to_description
    return describe(width)
|
654
|
+
|
655
|
+
def set_plot_size(self, **kwargs) -> dict:
    """Fill in default ``width``/``height`` plot kwargs from the bench config.

    For each of ``width`` and ``height`` that the caller did not pass:

    * ``bench_cfg.plot_width``/``bench_cfg.plot_height`` is used if it is not None;
    * otherwise ``bench_cfg.plot_size`` is used if it is not None;
    * otherwise the key is left absent.

    Values passed explicitly by the caller are never overridden.

    Returns:
        dict: The keyword arguments with any defaults added.
    """
    for key, specific_attr in (("width", "plot_width"), ("height", "plot_height")):
        if key in kwargs:
            continue
        cfg = self.bench_cfg
        general = cfg.plot_size
        specific = getattr(cfg, specific_attr)
        chosen = specific if specific is not None else general
        if chosen is not None:
            kwargs[key] = chosen
    return kwargs
|