holobench 1.41.0__py3-none-any.whl → 1.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. bencher/__init__.py +20 -2
  2. bencher/bench_cfg.py +262 -54
  3. bencher/bench_report.py +2 -2
  4. bencher/bench_runner.py +96 -10
  5. bencher/bencher.py +421 -89
  6. bencher/class_enum.py +70 -7
  7. bencher/example/example_dataframe.py +2 -2
  8. bencher/example/example_levels.py +17 -173
  9. bencher/example/example_pareto.py +107 -31
  10. bencher/example/example_rerun2.py +1 -1
  11. bencher/example/example_simple_bool.py +2 -2
  12. bencher/example/example_simple_float2d.py +6 -1
  13. bencher/example/example_video.py +2 -0
  14. bencher/example/experimental/example_hvplot_explorer.py +2 -2
  15. bencher/example/inputs_0D/example_0_in_1_out.py +25 -15
  16. bencher/example/inputs_0D/example_0_in_2_out.py +12 -3
  17. bencher/example/inputs_0_float/example_0_cat_in_2_out.py +88 -0
  18. bencher/example/inputs_0_float/example_1_cat_in_2_out.py +98 -0
  19. bencher/example/inputs_0_float/example_2_cat_in_2_out.py +107 -0
  20. bencher/example/inputs_0_float/example_3_cat_in_2_out.py +111 -0
  21. bencher/example/inputs_1D/example1d_common.py +48 -12
  22. bencher/example/inputs_1D/example_0_float_1_cat.py +33 -0
  23. bencher/example/inputs_1D/example_1_cat_in_2_out_repeats.py +68 -0
  24. bencher/example/inputs_1D/example_1_float_2_cat_repeats.py +3 -0
  25. bencher/example/inputs_1D/example_1_int_in_1_out.py +98 -0
  26. bencher/example/inputs_1D/example_1_int_in_2_out.py +101 -0
  27. bencher/example/inputs_1D/example_1_int_in_2_out_repeats.py +99 -0
  28. bencher/example/inputs_1_float/example_1_float_0_cat_in_2_out.py +117 -0
  29. bencher/example/inputs_1_float/example_1_float_1_cat_in_2_out.py +124 -0
  30. bencher/example/inputs_1_float/example_1_float_2_cat_in_2_out.py +132 -0
  31. bencher/example/inputs_1_float/example_1_float_3_cat_in_2_out.py +140 -0
  32. bencher/example/inputs_2D/example_2_cat_in_4_out_repeats.py +104 -0
  33. bencher/example/inputs_2_float/example_2_float_0_cat_in_2_out.py +98 -0
  34. bencher/example/inputs_2_float/example_2_float_1_cat_in_2_out.py +112 -0
  35. bencher/example/inputs_2_float/example_2_float_2_cat_in_2_out.py +122 -0
  36. bencher/example/inputs_2_float/example_2_float_3_cat_in_2_out.py +138 -0
  37. bencher/example/inputs_3_float/example_3_float_0_cat_in_2_out.py +111 -0
  38. bencher/example/inputs_3_float/example_3_float_1_cat_in_2_out.py +117 -0
  39. bencher/example/inputs_3_float/example_3_float_2_cat_in_2_out.py +124 -0
  40. bencher/example/inputs_3_float/example_3_float_3_cat_in_2_out.py +129 -0
  41. bencher/example/meta/generate_examples.py +118 -7
  42. bencher/example/meta/generate_meta.py +88 -40
  43. bencher/job.py +174 -9
  44. bencher/plotting/plot_filter.py +52 -17
  45. bencher/results/bench_result.py +117 -25
  46. bencher/results/bench_result_base.py +117 -8
  47. bencher/results/dataset_result.py +6 -200
  48. bencher/results/explorer_result.py +23 -0
  49. bencher/results/{hvplot_result.py → histogram_result.py} +3 -18
  50. bencher/results/holoview_results/__init__.py +0 -0
  51. bencher/results/holoview_results/bar_result.py +79 -0
  52. bencher/results/holoview_results/curve_result.py +110 -0
  53. bencher/results/holoview_results/distribution_result/__init__.py +0 -0
  54. bencher/results/holoview_results/distribution_result/box_whisker_result.py +73 -0
  55. bencher/results/holoview_results/distribution_result/distribution_result.py +109 -0
  56. bencher/results/holoview_results/distribution_result/scatter_jitter_result.py +92 -0
  57. bencher/results/holoview_results/distribution_result/violin_result.py +70 -0
  58. bencher/results/holoview_results/heatmap_result.py +319 -0
  59. bencher/results/holoview_results/holoview_result.py +346 -0
  60. bencher/results/holoview_results/line_result.py +240 -0
  61. bencher/results/holoview_results/scatter_result.py +107 -0
  62. bencher/results/holoview_results/surface_result.py +158 -0
  63. bencher/results/holoview_results/table_result.py +14 -0
  64. bencher/results/holoview_results/tabulator_result.py +20 -0
  65. bencher/results/optuna_result.py +30 -115
  66. bencher/results/video_controls.py +38 -0
  67. bencher/results/video_result.py +39 -36
  68. bencher/results/video_summary.py +2 -2
  69. bencher/results/{plotly_result.py → volume_result.py} +29 -8
  70. bencher/utils.py +175 -26
  71. bencher/variables/inputs.py +122 -15
  72. bencher/video_writer.py +2 -1
  73. bencher/worker_job.py +31 -3
  74. {holobench-1.41.0.dist-info → holobench-1.43.0.dist-info}/METADATA +24 -24
  75. holobench-1.43.0.dist-info/RECORD +147 -0
  76. bencher/example/example_levels2.py +0 -37
  77. bencher/example/inputs_1D/example_1_in_1_out.py +0 -62
  78. bencher/example/inputs_1D/example_1_in_2_out.py +0 -63
  79. bencher/example/inputs_1D/example_1_in_2_out_repeats.py +0 -61
  80. bencher/results/holoview_result.py +0 -796
  81. bencher/results/panel_result.py +0 -41
  82. holobench-1.41.0.dist-info/RECORD +0 -114
  83. {holobench-1.41.0.dist-info → holobench-1.43.0.dist-info}/WHEEL +0 -0
  84. {holobench-1.41.0.dist-info → holobench-1.43.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
  from typing import Optional
3
3
  from dataclasses import dataclass
4
4
  from bencher.plotting.plt_cnt_cfg import PltCntCfg
5
+ import logging
5
6
  import panel as pn
6
7
 
7
8
 
@@ -43,9 +44,20 @@ class VarRange:
43
44
 
44
45
  return lower_match and upper_match
45
46
 
46
- def matches_info(self, val, name):
47
+ def matches_info(self, val: int, name: str) -> tuple[bool, str]:
48
+ """Get matching info for a value with a descriptive name.
49
+
50
+ Args:
51
+ val (int): A positive integer to check against the range
52
+ name (str): A descriptive name for the value being checked, used in the output string
53
+
54
+ Returns:
55
+ tuple[bool, str]: A tuple containing:
56
+ - bool: True if the value matches the range, False otherwise
57
+ - str: A formatted string describing the match result
58
+ """
47
59
  match = self.matches(val)
48
- info = f"{name}\t{match}\t{self.lower_bound}>= {val} <={self.upper_bound}"
60
+ info = f"{name}\t{self.lower_bound}>= {val} <={self.upper_bound} is {match}"
49
61
  return match, info
50
62
 
51
63
  def __str__(self) -> str:
@@ -65,13 +77,21 @@ class PlotFilter:
65
77
  input_range: VarRange = VarRange(1, None)
66
78
 
67
79
  def matches_result(
68
- self, plt_cnt_cfg: PltCntCfg, plot_name: str, override: bool = False
80
+ self, plt_cnt_cfg: PltCntCfg, plot_name: str, override: bool
69
81
  ) -> PlotMatchesResult:
70
- """Checks if the result data signature matches the type of data the plot is able to display."""
82
+ """Checks if the result data signature matches the type of data the plot is able to display.
83
+
84
+ Args:
85
+ plt_cnt_cfg (PltCntCfg): Configuration containing counts of different plot elements
86
+ plot_name (str): Name of the plot being checked
87
+ override (bool): Whether to override filter matching rules
88
+
89
+ Returns:
90
+ PlotMatchesResult: Object containing match results and information
91
+ """
71
92
  return PlotMatchesResult(self, plt_cnt_cfg, plot_name, override)
72
93
 
73
94
 
74
- # @dataclass
75
95
  class PlotMatchesResult:
76
96
  """Stores information about which properties match the requirements of a particular plotter"""
77
97
 
@@ -80,12 +100,20 @@ class PlotMatchesResult:
80
100
  plot_filter: PlotFilter,
81
101
  plt_cnt_cfg: PltCntCfg,
82
102
  plot_name: str,
83
- override: bool = False,
84
- ):
85
- match_info = []
86
- matches = []
103
+ override: bool,
104
+ ) -> None:
105
+ """Initialize a PlotMatchesResult with filter matching information.
87
106
 
88
- match_candidates = [
107
+ Args:
108
+ plot_filter (PlotFilter): The filter defining acceptable ranges for plot properties
109
+ plt_cnt_cfg (PltCntCfg): Configuration containing counts of different plot elements
110
+ plot_name (str): Name of the plot being checked
111
+ override (bool): Whether to override filter matching rules
112
+ """
113
+ match_info: list[str] = []
114
+ matches: list[bool] = []
115
+
116
+ match_candidates: list[tuple[VarRange, int, str]] = [
89
117
  (plot_filter.float_range, plt_cnt_cfg.float_cnt, "float"),
90
118
  (plot_filter.cat_range, plt_cnt_cfg.cat_cnt, "cat"),
91
119
  (plot_filter.vector_len, plt_cnt_cfg.vector_len, "vec"),
@@ -99,7 +127,7 @@ class PlotMatchesResult:
99
127
  match, info = m.matches_info(cnt, name)
100
128
  matches.append(match)
101
129
  if not match:
102
- match_info.append(info)
130
+ match_info.append(f"\t{info}")
103
131
  if override:
104
132
  match_info.append(f"override: {override}")
105
133
  self.overall = True
@@ -107,15 +135,22 @@ class PlotMatchesResult:
107
135
  self.overall = all(matches)
108
136
 
109
137
  match_info.insert(0, f"plot {plot_name} matches: {self.overall}")
110
- self.matches_info = "\n".join(match_info).strip()
111
- self.plt_cnt_cfg = plt_cnt_cfg
138
+ self.matches_info: str = "\n".join(match_info).strip()
139
+ self.plt_cnt_cfg: PltCntCfg = plt_cnt_cfg
112
140
 
113
- if self.plt_cnt_cfg.print_debug:
114
- print(f"checking {plot_name} result: {self.overall}")
115
- if not self.overall:
116
- print(self.matches_info)
141
+ # if self.plt_cnt_cfg.print_debug:
142
+ logging.info(self.matches_info)
117
143
 
118
144
  def to_panel(self, **kwargs) -> Optional[pn.pane.Markdown]:
145
+ """Convert match information to a Panel Markdown pane if debug mode is enabled.
146
+
147
+ Args:
148
+ **kwargs: Additional keyword arguments to pass to the Panel Markdown constructor
149
+
150
+ Returns:
151
+ Optional[pn.pane.Markdown]: A Markdown pane containing match information if in debug mode,
152
+ None otherwise
153
+ """
119
154
  if self.plt_cnt_cfg.print_debug:
120
155
  return pn.pane.Markdown(self.matches_info, **kwargs)
121
156
  return None
@@ -1,49 +1,129 @@
1
1
  from __future__ import annotations
2
- from typing import List
2
+ from typing import List, Optional, Any
3
3
  import panel as pn
4
+ from param import Parameter
4
5
 
5
6
  from bencher.results.bench_result_base import EmptyContainer
6
7
  from bencher.results.video_summary import VideoSummaryResult
7
- from bencher.results.panel_result import PanelResult
8
- from bencher.results.plotly_result import PlotlyResult
9
- from bencher.results.holoview_result import HoloviewResult
10
- from bencher.results.hvplot_result import HvplotResult
8
+ from bencher.results.video_result import VideoResult
9
+ from bencher.results.volume_result import VolumeResult
10
+ from bencher.results.holoview_results.holoview_result import HoloviewResult
11
+
12
+ # Updated imports for distribution result classes
13
+ from bencher.results.holoview_results.distribution_result.box_whisker_result import BoxWhiskerResult
14
+ from bencher.results.holoview_results.distribution_result.violin_result import ViolinResult
15
+ from bencher.results.holoview_results.scatter_result import ScatterResult
16
+ from bencher.results.holoview_results.distribution_result.scatter_jitter_result import (
17
+ ScatterJitterResult,
18
+ )
19
+ from bencher.results.holoview_results.bar_result import BarResult
20
+ from bencher.results.holoview_results.line_result import LineResult
21
+ from bencher.results.holoview_results.curve_result import CurveResult
22
+ from bencher.results.holoview_results.heatmap_result import HeatmapResult
23
+ from bencher.results.holoview_results.surface_result import SurfaceResult
24
+ from bencher.results.histogram_result import HistogramResult
25
+ from bencher.results.optuna_result import OptunaResult
11
26
  from bencher.results.dataset_result import DataSetResult
12
27
  from bencher.utils import listify
13
28
 
14
29
 
15
- class BenchResult(PlotlyResult, HoloviewResult, HvplotResult, VideoSummaryResult, DataSetResult): # noqa pylint: disable=too-many-ancestors
30
+ class BenchResult(
31
+ VolumeResult,
32
+ BoxWhiskerResult,
33
+ ViolinResult,
34
+ ScatterJitterResult,
35
+ ScatterResult,
36
+ LineResult,
37
+ BarResult,
38
+ HeatmapResult,
39
+ CurveResult,
40
+ SurfaceResult,
41
+ HoloviewResult,
42
+ HistogramResult,
43
+ VideoSummaryResult,
44
+ DataSetResult,
45
+ OptunaResult,
46
+ ): # noqa pylint: disable=too-many-ancestors
16
47
  """Contains the results of the benchmark and has methods to cast the results to various datatypes and graphical representations"""
17
48
 
18
49
  def __init__(self, bench_cfg) -> None:
19
- PlotlyResult.__init__(self, bench_cfg)
50
+ """Initialize a BenchResult instance.
51
+
52
+ Args:
53
+ bench_cfg: The benchmark configuration object containing settings and result data
54
+ """
55
+ VolumeResult.__init__(self, bench_cfg)
20
56
  HoloviewResult.__init__(self, bench_cfg)
21
57
  # DataSetResult.__init__(self.bench_cfg)
22
58
 
59
+ @classmethod
60
+ def from_existing(cls, original: BenchResult) -> BenchResult:
61
+ new_instance = cls(original.bench_cfg)
62
+ new_instance.ds = original.ds
63
+ new_instance.bench_cfg = original.bench_cfg
64
+ new_instance.plt_cnt_cfg = original.plt_cnt_cfg
65
+ return new_instance
66
+
67
+ def to(
68
+ self,
69
+ result_type: BenchResult,
70
+ result_var: Optional[Parameter] = None,
71
+ override: bool = True,
72
+ **kwargs: Any,
73
+ ) -> BenchResult:
74
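+ """Create a result of the given type from this result's data and plot it.
75
+
76
+ Returns:
77
+ The value returned by ``to_plot`` on a new ``result_type`` instance that shares this result's ``ds``, ``plt_cnt_cfg`` and ``dataset_list``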
78
+ """
79
+ result_instance = result_type(self.bench_cfg)
80
+ result_instance.ds = self.ds
81
+ result_instance.plt_cnt_cfg = self.plt_cnt_cfg
82
+ result_instance.dataset_list = self.dataset_list
83
+ return result_instance.to_plot(result_var=result_var, override=override, **kwargs)
84
+
23
85
  @staticmethod
24
- def default_plot_callbacks():
86
+ def default_plot_callbacks() -> List[callable]:
87
+ """Get the default list of plot callback functions.
88
+
89
+ These callbacks are used by default in the to_auto method if no specific
90
+ plot list is provided.
91
+
92
+ Returns:
93
+ List[callable]: A list of plotting callback functions
94
+ """
25
95
  return [
26
96
  # VideoSummaryResult.to_video_summary, #quite expensive so not turned on by default
27
- HoloviewResult.to_bar,
28
- HoloviewResult.to_scatter_jitter,
29
- HoloviewResult.to_curve,
30
- HoloviewResult.to_line,
31
- HoloviewResult.to_heatmap,
32
- HvplotResult.to_histogram,
33
- PlotlyResult.to_volume,
97
+ BarResult.to_plot,
98
+ BoxWhiskerResult.to_plot,
99
+ # ViolinResult.to_violin,
100
+ # ScatterJitterResult.to_plot,
101
+ CurveResult.to_plot,
102
+ LineResult.to_plot,
103
+ HeatmapResult.to_plot,
104
+ HistogramResult.to_plot,
105
+ VolumeResult.to_plot,
34
106
  # PanelResult.to_video,
35
- PanelResult.to_panes,
107
+ VideoResult.to_panes,
36
108
  ]
37
109
 
38
110
  @staticmethod
39
- def plotly_callbacks():
40
- return [HoloviewResult.to_surface, PlotlyResult.to_volume]
111
+ def plotly_callbacks() -> List[callable]:
112
+ """Get the list of Plotly-specific callback functions.
113
+
114
+ Returns:
115
+ List[callable]: A list of Plotly-based visualization callback functions
116
+ """
117
+ return [SurfaceResult.to_surface, VolumeResult.to_volume]
41
118
 
42
119
  def plot(self) -> pn.panel:
43
- """Plots the benchresult using the plot callbacks defined by the bench run
120
+ """Plots the benchresult using the plot callbacks defined by the bench run.
121
+
122
+ This method uses the plot_callbacks defined in the bench_cfg to generate
123
+ plots for the benchmark results.
44
124
 
45
125
  Returns:
46
- pn.panel: A panel representation of the results
126
+ pn.panel: A panel representation of the results, or None if no plot_callbacks defined
47
127
  """
48
128
  if self.bench_cfg.plot_callbacks is not None:
49
129
  return pn.Column(*[cb(self) for cb in self.bench_cfg.plot_callbacks])
@@ -54,8 +134,21 @@ class BenchResult(PlotlyResult, HoloviewResult, HvplotResult, VideoSummaryResult
54
134
  plot_list: List[callable] = None,
55
135
  remove_plots: List[callable] = None,
56
136
  default_container=pn.Column,
137
+ override: bool = False, # false so that plots that are not supported are not shown
57
138
  **kwargs,
58
139
  ) -> List[pn.panel]:
140
+ """Automatically generate plots based on the provided plot callbacks.
141
+
142
+ Args:
143
+ plot_list (List[callable], optional): List of plot callback functions to use. Defaults to None.
144
+ remove_plots (List[callable], optional): List of plot callback functions to exclude. Defaults to None.
145
+ default_container (type, optional): Default container type for the plots. Defaults to pn.Column.
146
+ override (bool, optional): Passed to each plot callback as ``override``. Defaults to False so that plots that are not supported are not shown.
147
+ **kwargs: Additional keyword arguments for plot configuration.
148
+
149
+ Returns:
150
+ List[pn.panel]: A list of panel objects containing the generated plots.
151
+ """
59
152
  self.plt_cnt_cfg.print_debug = False
60
153
  plot_list = listify(plot_list)
61
154
  remove_plots = listify(remove_plots)
@@ -73,7 +166,7 @@ class BenchResult(PlotlyResult, HoloviewResult, HvplotResult, VideoSummaryResult
73
166
  if self.plt_cnt_cfg.print_debug:
74
167
  print(f"checking: {plot_callback.__name__}")
75
168
  # the callbacks are passed from the static class definition, so self needs to be passed before the plotting callback can be called
76
- row.append(plot_callback(self, **kwargs))
169
+ row.append(plot_callback(self, override=override, **kwargs))
77
170
 
78
171
  self.plt_cnt_cfg.print_debug = True
79
172
  if len(row.pane) == 0:
@@ -83,15 +176,14 @@ class BenchResult(PlotlyResult, HoloviewResult, HvplotResult, VideoSummaryResult
83
176
  return row.pane
84
177
 
85
178
  def to_auto_plots(self, **kwargs) -> pn.panel:
86
- """Given the dataset result of a benchmark run, automatically dedeuce how to plot the data based on the types of variables that were sampled
179
+ """Given the dataset result of a benchmark run, automatically deduce how to plot the data based on the types of variables that were sampled.
87
180
 
88
181
  Args:
89
- bench_cfg (BenchCfg): Information on how the benchmark was sampled and the resulting data
182
+ **kwargs: Additional keyword arguments for plot configuration.
90
183
 
91
184
  Returns:
92
- pn.pane: A panel containing plot results
185
+ pn.panel: A panel containing plot results.
93
186
  """
94
-
95
187
  plot_cols = pn.Column()
96
188
  plot_cols.append(self.to_sweep_summary(name="Plots View"))
97
189
  plot_cols.append(self.to_auto(**kwargs))
@@ -6,6 +6,8 @@ from param import Parameter
6
6
  import holoviews as hv
7
7
  from functools import partial
8
8
  import panel as pn
9
+ import numpy as np
10
+ from textwrap import wrap
9
11
 
10
12
  from bencher.utils import int_to_col, color_tuple_to_css, callable_name
11
13
 
@@ -14,7 +16,6 @@ from bencher.variables.inputs import with_level
14
16
 
15
17
  from bencher.variables.results import OptDir
16
18
  from copy import deepcopy
17
- from bencher.results.optuna_result import OptunaResult
18
19
  from bencher.variables.results import ResultVar
19
20
  from bencher.plotting.plot_filter import VarRange, PlotFilter
20
21
  from bencher.utils import listify
@@ -25,6 +26,13 @@ from bencher.results.composable_container.composable_container_panel import (
25
26
  ComposableContainerPanel,
26
27
  )
27
28
 
29
+ from collections import defaultdict
30
+
31
+ import pandas as pd
32
+
33
+ from bencher.bench_cfg import BenchCfg
34
+ from bencher.plotting.plt_cnt_cfg import PltCntCfg
35
+
28
36
  # todo add plugins
29
37
  # https://gist.github.com/dorneanu/cce1cd6711969d581873a88e0257e312
30
38
  # https://kaleidoescape.github.io/decorated-plugins/
@@ -52,7 +60,87 @@ class EmptyContainer:
52
60
  return self.pane if len(self.pane) > 0 else None
53
61
 
54
62
 
55
- class BenchResultBase(OptunaResult):
63
+ def convert_dataset_bool_dims_to_str(dataset: xr.Dataset) -> xr.Dataset:
64
+ """Given a dataset that contains boolean coordinates, convert them to strings so that holoviews loads the data properly
66
+
67
+ Args:
68
+ dataset (xr.Dataset): dataset with boolean coordinates
69
+
70
+ Returns:
71
+ xr.Dataset: dataset with boolean coordinates converted to strings
71
+ """
72
+ bool_coords = {}
73
+ for c in dataset.coords:
74
+ if dataset.coords[c].dtype == bool:
75
+ bool_coords[c] = [str(vals) for vals in dataset.coords[c].values]
76
+
77
+ if len(bool_coords) > 0:
78
+ return dataset.assign_coords(bool_coords)
79
+ return dataset
80
+
81
+
82
+ class BenchResultBase:
83
+ def __init__(self, bench_cfg: BenchCfg) -> None:
84
+ self.bench_cfg = bench_cfg
85
+ # self.wrap_long_time_labels(bench_cfg) # todo remove
86
+ self.ds = xr.Dataset()
87
+ self.object_index = []
88
+ self.hmaps = defaultdict(dict)
89
+ self.result_hmaps = bench_cfg.result_hmaps
90
+ self.studies = []
91
+ self.plt_cnt_cfg = PltCntCfg()
92
+ self.plot_inputs = []
93
+ self.dataset_list = []
94
+
95
+ # self.width=600/
96
+ # self.height=600
97
+
98
+ # bench_res.objects.append(rv)
99
+ # bench_res.reference_index = len(bench_res.objects)
100
+
101
+ def to_xarray(self) -> xr.Dataset:
102
+ return self.ds
103
+
104
+ def setup_object_index(self):
105
+ self.object_index = []
106
+
107
+ def to_pandas(self, reset_index=True) -> pd.DataFrame:
108
+ """Get the xarray results as a pandas dataframe
109
+
110
+ Returns:
111
+ pd.DataFrame: The xarray results array as a pandas dataframe
112
+ """
113
+ ds = self.to_xarray().to_dataframe()
114
+ return ds.reset_index() if reset_index else ds
115
+
116
+ def wrap_long_time_labels(self, bench_cfg):
117
+ """Takes a benchCfg and wraps any index labels that are too long to be plotted easily
118
+
119
+ Args:
120
+ bench_cfg (BenchCfg):
121
+
122
+ Returns:
123
+ BenchCfg: updated config with wrapped labels
124
+ """
125
+ if bench_cfg.over_time:
126
+ if self.ds.coords["over_time"].dtype == np.datetime64:
127
+ # plotly catastrophically fails to plot anything with the default long string representation of time, so convert to a shorter time representation
128
+ self.ds.coords["over_time"] = [
129
+ pd.to_datetime(t).strftime("%d-%m-%y %H-%M-%S")
130
+ for t in self.ds.coords["over_time"].values
131
+ ]
132
+ # wrap very long time event labels because otherwise the graphs are unreadable
133
+ if bench_cfg.time_event is not None:
134
+ self.ds.coords["over_time"] = [
135
+ "\n".join(wrap(t, 20)) for t in self.ds.coords["over_time"].values
136
+ ]
137
+ return bench_cfg
138
+
139
+ def post_setup(self):
140
+ self.plt_cnt_cfg = PltCntCfg.generate_plt_cnt_cfg(self.bench_cfg)
141
+ self.bench_cfg = self.wrap_long_time_labels(self.bench_cfg)
142
+ self.ds = convert_dataset_bool_dims_to_str(self.ds)
143
+
56
144
  def result_samples(self) -> int:
57
145
  """The number of samples in the results dataframe"""
58
146
  return self.ds.count()
@@ -80,7 +168,7 @@ class BenchResultBase(OptunaResult):
80
168
  def to_dataset(
81
169
  self,
82
170
  reduce: ReduceType = ReduceType.AUTO,
83
- result_var: ResultVar = None,
171
+ result_var: ResultVar | str = None,
84
172
  level: int = None,
85
173
  ) -> xr.Dataset:
86
174
  """Generate a summarised xarray dataset.
@@ -97,7 +185,15 @@ class BenchResultBase(OptunaResult):
97
185
  ds_out = self.ds.copy()
98
186
 
99
187
  if result_var is not None:
100
- ds_out = ds_out[result_var.name].to_dataset(name=result_var.name)
188
+ if isinstance(result_var, Parameter):
189
+ var_name = result_var.name
190
+ elif isinstance(result_var, str):
191
+ var_name = result_var
192
+ else:
193
+ raise TypeError(
194
+ f"Unsupported type for result_var: {type(result_var)}. Expected Parameter or str."
195
+ )
196
+ ds_out = ds_out[var_name].to_dataset(name=var_name)
101
197
 
102
198
  def rename_ds(dataset: xr.Dataset, suffix: str):
103
199
  # var_name =
@@ -219,9 +315,6 @@ class BenchResultBase(OptunaResult):
219
315
  def describe_sweep(self):
220
316
  return self.bench_cfg.describe_sweep()
221
317
 
222
- def get_best_holomap(self, name: str = None):
223
- return self.get_hmap(name)[self.get_best_trial_params(True)]
224
-
225
318
  def get_hmap(self, name: str = None):
226
319
  try:
227
320
  if name is None:
@@ -362,7 +455,7 @@ class BenchResultBase(OptunaResult):
362
455
  pane_collection: pn.pane = None,
363
456
  override=False,
364
457
  **kwargs,
365
- ):
458
+ ) -> Optional[pn.panel]:
366
459
  plot_filter = PlotFilter(
367
460
  float_range=float_range,
368
461
  cat_range=cat_range,
@@ -558,3 +651,19 @@ class BenchResultBase(OptunaResult):
558
651
 
559
652
  def to_description(self, width: int = 800) -> pn.pane.Markdown:
560
653
  return self.bench_cfg.to_description(width)
654
+
655
+ def set_plot_size(self, **kwargs) -> dict:
656
+ if "width" not in kwargs:
657
+ if self.bench_cfg.plot_size is not None:
658
+ kwargs["width"] = self.bench_cfg.plot_size
659
+ # specific width overrides general size
660
+ if self.bench_cfg.plot_width is not None:
661
+ kwargs["width"] = self.bench_cfg.plot_width
662
+
663
+ if "height" not in kwargs:
664
+ if self.bench_cfg.plot_size is not None:
665
+ kwargs["height"] = self.bench_cfg.plot_size
666
+ # specific height overrides general size
667
+ if self.bench_cfg.plot_height is not None:
668
+ kwargs["height"] = self.bench_cfg.plot_height
669
+ return kwargs