holobench-1.3.4-py3-none-any.whl → holobench-1.22.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. bencher/__init__.py +4 -1
  2. bencher/bench_cfg.py +37 -34
  3. bencher/bench_plot_server.py +5 -1
  4. bencher/bench_report.py +14 -14
  5. bencher/bench_runner.py +2 -1
  6. bencher/bencher.py +87 -50
  7. bencher/class_enum.py +52 -0
  8. bencher/job.py +6 -4
  9. bencher/optuna_conversions.py +1 -1
  10. bencher/utils.py +58 -3
  11. bencher/video_writer.py +110 -6
  12. holobench-1.22.2.data/data/share/bencher/package.xml +33 -0
  13. holobench-1.22.2.dist-info/LICENSE +21 -0
  14. {holobench-1.3.4.dist-info → holobench-1.22.2.dist-info}/METADATA +39 -32
  15. holobench-1.22.2.dist-info/RECORD +20 -0
  16. {holobench-1.3.4.dist-info → holobench-1.22.2.dist-info}/WHEEL +2 -1
  17. holobench-1.22.2.dist-info/top_level.txt +1 -0
  18. bencher/example/benchmark_data.py +0 -200
  19. bencher/example/example_all.py +0 -45
  20. bencher/example/example_categorical.py +0 -99
  21. bencher/example/example_custom_sweep.py +0 -59
  22. bencher/example/example_docs.py +0 -34
  23. bencher/example/example_float3D.py +0 -101
  24. bencher/example/example_float_cat.py +0 -98
  25. bencher/example/example_floats.py +0 -89
  26. bencher/example/example_floats2D.py +0 -93
  27. bencher/example/example_holosweep.py +0 -104
  28. bencher/example/example_holosweep_objects.py +0 -111
  29. bencher/example/example_holosweep_tap.py +0 -144
  30. bencher/example/example_image.py +0 -82
  31. bencher/example/example_levels.py +0 -181
  32. bencher/example/example_pareto.py +0 -53
  33. bencher/example/example_sample_cache.py +0 -85
  34. bencher/example/example_sample_cache_context.py +0 -116
  35. bencher/example/example_simple.py +0 -134
  36. bencher/example/example_simple_bool.py +0 -34
  37. bencher/example/example_simple_cat.py +0 -47
  38. bencher/example/example_simple_float.py +0 -38
  39. bencher/example/example_strings.py +0 -46
  40. bencher/example/example_time_event.py +0 -62
  41. bencher/example/example_video.py +0 -98
  42. bencher/example/example_workflow.py +0 -189
  43. bencher/example/experimental/example_bokeh_plotly.py +0 -38
  44. bencher/example/experimental/example_hover_ex.py +0 -45
  45. bencher/example/experimental/example_hvplot_explorer.py +0 -39
  46. bencher/example/experimental/example_interactive.py +0 -75
  47. bencher/example/experimental/example_streamnd.py +0 -49
  48. bencher/example/experimental/example_streams.py +0 -36
  49. bencher/example/experimental/example_template.py +0 -40
  50. bencher/example/experimental/example_updates.py +0 -84
  51. bencher/example/experimental/example_vector.py +0 -84
  52. bencher/example/meta/example_meta.py +0 -171
  53. bencher/example/meta/example_meta_cat.py +0 -25
  54. bencher/example/meta/example_meta_float.py +0 -23
  55. bencher/example/meta/example_meta_levels.py +0 -26
  56. bencher/example/optuna/example_optuna.py +0 -78
  57. bencher/example/shelved/example_float2D_scatter.py +0 -109
  58. bencher/example/shelved/example_float3D_cone.py +0 -96
  59. bencher/example/shelved/example_kwargs.py +0 -63
  60. bencher/plotting/__init__.py +0 -0
  61. bencher/plotting/plot_filter.py +0 -110
  62. bencher/plotting/plt_cnt_cfg.py +0 -74
  63. bencher/results/__init__.py +0 -0
  64. bencher/results/bench_result.py +0 -83
  65. bencher/results/bench_result_base.py +0 -401
  66. bencher/results/float_formatter.py +0 -44
  67. bencher/results/holoview_result.py +0 -535
  68. bencher/results/optuna_result.py +0 -332
  69. bencher/results/panel_result.py +0 -113
  70. bencher/results/plotly_result.py +0 -65
  71. bencher/variables/inputs.py +0 -193
  72. bencher/variables/parametrised_sweep.py +0 -206
  73. bencher/variables/results.py +0 -176
  74. bencher/variables/sweep_base.py +0 -167
  75. bencher/variables/time.py +0 -74
  76. holobench-1.3.4.dist-info/RECORD +0 -74
  77. /bencher/example/__init__.py → /holobench-1.22.2.data/data/share/ament_index/resource_index/packages/bencher +0 -0
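The per-file line counts above come straight from the registry diff. To confirm what a given wheel actually ships once installed, the distribution's file manifest can be read back with `importlib.metadata`; the sketch below assumes the wheel has already been installed under the distribution name `holobench`.

```python
# Minimal sketch: list the bencher modules recorded for the installed wheel.
from importlib.metadata import files, version

dist = "holobench"
print(dist, version(dist))
for path in files(dist) or []:
    # keep only the Python modules shipped under the bencher package
    if str(path).startswith("bencher/"):
        print(path)
```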
bencher/results/optuna_result.py
@@ -1,332 +0,0 @@
-from __future__ import annotations
-from typing import List
-import numpy as np
-import optuna
-import panel as pn
-from collections import defaultdict
-from textwrap import wrap
-
-import pandas as pd
-import xarray as xr
-
-
-from optuna.visualization import (
-    plot_param_importances,
-    plot_pareto_front,
-)
-from bencher.utils import hmap_canonical_input
-from bencher.variables.time import TimeSnapshot, TimeEvent
-from bencher.bench_cfg import BenchCfg
-from bencher.plotting.plt_cnt_cfg import PltCntCfg
-
-
-# from bencher.results.bench_result_base import BenchResultBase
-from bencher.optuna_conversions import (
-    sweep_var_to_optuna_dist,
-    summarise_trial,
-    param_importance,
-    optuna_grid_search,
-    summarise_optuna_study,
-    sweep_var_to_suggest,
-)
-
-
-def convert_dataset_bool_dims_to_str(dataset: xr.Dataset) -> xr.Dataset:
-    """Given a dataset that contains boolean coordinates, convert them to strings so that holoviews loads the data properly
-
-    Args:
-        dataset (xr.Dataset): dataset with boolean coordinates
-
-    Returns:
-        xr.Dataset: dataset with boolean coordinates converted to strings
-    """
-    bool_coords = {}
-    for c in dataset.coords:
-        if dataset.coords[c].dtype == bool:
-            bool_coords[c] = [str(vals) for vals in dataset.coords[c].values]
-
-    if len(bool_coords) > 0:
-        return dataset.assign_coords(bool_coords)
-    return dataset
-
-
-class OptunaResult:
-    def __init__(self, bench_cfg: BenchCfg) -> None:
-        self.bench_cfg = bench_cfg
-        # self.wrap_long_time_labels(bench_cfg) # todo remove
-        self.ds = xr.Dataset()
-        self.object_index = []
-        self.hmaps = defaultdict(dict)
-        self.result_hmaps = bench_cfg.result_hmaps
-        self.studies = []
-        self.plt_cnt_cfg = PltCntCfg()
-        self.plot_inputs = []
-
-        # self.width=600/
-        # self.height=600
-
-        # bench_res.objects.append(rv)
-        # bench_res.reference_index = len(bench_res.objects)
-
-    def post_setup(self):
-        self.plt_cnt_cfg = PltCntCfg.generate_plt_cnt_cfg(self.bench_cfg)
-        self.bench_cfg = self.wrap_long_time_labels(self.bench_cfg)
-        self.ds = convert_dataset_bool_dims_to_str(self.ds)
-
-    def to_xarray(self) -> xr.Dataset:
-        return self.ds
-
-    def setup_object_index(self):
-        self.object_index = []
-
-    def to_pandas(self, reset_index=True) -> pd.DataFrame:
-        """Get the xarray results as a pandas dataframe
-
-        Returns:
-            pd.DataFrame: The xarray results array as a pandas dataframe
-        """
-        ds = self.to_xarray().to_dataframe()
-        if reset_index:
-            return ds.reset_index()
-        return ds
-
-    def wrap_long_time_labels(self, bench_cfg):
-        """Takes a BenchCfg and wraps any index labels that are too long to be plotted easily
-
-        Args:
-            bench_cfg (BenchCfg):
-
-        Returns:
-            BenchCfg: updated config with wrapped labels
-        """
-        if bench_cfg.over_time:
-            if self.ds.coords["over_time"].dtype == np.datetime64:
-                # plotly catastrophically fails to plot anything with the default long string representation of time, so convert to a shorter time representation
-                self.ds.coords["over_time"] = [
-                    pd.to_datetime(t).strftime("%d-%m-%y %H-%M-%S")
-                    for t in self.ds.coords["over_time"].values
-                ]
-            # wrap very long time event labels because otherwise the graphs are unreadable
-            if bench_cfg.time_event is not None:
-                self.ds.coords["over_time"] = [
-                    "\n".join(wrap(t, 20)) for t in self.ds.coords["over_time"].values
-                ]
-        return bench_cfg
-
-    def to_optuna_plots(self) -> List[pn.pane.panel]:
-        """Create an optuna summary from the benchmark results
-
-        Returns:
-            List[pn.pane.panel]: A list of optuna plots summarising the benchmark process
-        """
-
-        return self.collect_optuna_plots()
-
-    def to_optuna_from_sweep(self, bench, n_trials=30):
-        optu = self.to_optuna_from_results(
-            bench.worker, n_trials=n_trials, extra_results=bench.results
-        )
-        return summarise_optuna_study(optu)
-
-    def to_optuna_from_results(
-        self,
-        worker,
-        n_trials=100,
-        extra_results: List[OptunaResult] = None,
-        sampler=optuna.samplers.TPESampler(),
-    ):
-        directions = []
-        for rv in self.bench_cfg.optuna_targets(True):
-            directions.append(rv.direction)
-
-        study = optuna.create_study(
-            sampler=sampler, directions=directions, study_name=self.bench_cfg.title
-        )
-
-        # add already calculated results
-        results_list = extra_results if extra_results is not None else [self]
-        for res in results_list:
-            if len(res.ds.sizes) > 0:
-                study.add_trials(res.bench_results_to_optuna_trials(True))
-
-        def wrapped(trial) -> tuple:
-            kwargs = {}
-            for iv in self.bench_cfg.input_vars:
-                kwargs[iv.name] = sweep_var_to_suggest(iv, trial)
-            result = worker(**kwargs)
-            output = []
-            for rv in self.bench_cfg.result_vars:
-                output.append(result[rv.name])
-            return tuple(output)
-
-        study.optimize(wrapped, n_trials=n_trials)
-        return study
-
-    def bench_results_to_optuna_trials(self, include_meta: bool = True) -> optuna.Study:
-        """Convert an xarray dataset to an optuna study so optuna can further optimise or plot the statespace
-
-        Args:
-            bench_cfg (BenchCfg): benchmark config to convert
-
-        Returns:
-            optuna.Study: optuna description of the study
-        """
-        if include_meta:
-            df = self.to_pandas()
-            all_vars = []
-            for v in self.bench_cfg.all_vars:
-                if type(v) != TimeEvent:
-                    all_vars.append(v)
-
-            print("All vars", all_vars)
-        else:
-            all_vars = self.bench_cfg.input_vars
-            # df = self.ds.
-            # if "repeat" in self.
-            # if self.bench_cfg.repeats>1:
-            #     df = self.bench_cfg.ds.mean("repeat").to_dataframe().reset_index()
-            # else:
-            df = self.to_pandas().reset_index()
-        # df = self.bench_cfg.ds.mean("repeat").to_dataframe.reset_index()
-        # self.bench_cfg.all_vars
-        # del self.bench_cfg.meta_vars[1]
-
-        trials = []
-        distributions = {}
-        for i in all_vars:
-            distributions[i.name] = sweep_var_to_optuna_dist(i)
-
-        for row in df.iterrows():
-            params = {}
-            values = []
-            for i in all_vars:
-                if type(i) == TimeSnapshot:
-                    if type(row[1][i.name]) == np.datetime64:
-                        params[i.name] = row[1][i.name].timestamp()
-                else:
-                    params[i.name] = row[1][i.name]
-
-            for r in self.bench_cfg.optuna_targets():
-                values.append(row[1][r])
-
-            trials.append(
-                optuna.trial.create_trial(
-                    params=params,
-                    distributions=distributions,
-                    values=values,
-                )
-            )
-        return trials
-
-    def bench_result_to_study(self, include_meta: bool) -> optuna.Study:
-        trials = self.bench_results_to_optuna_trials(include_meta)
-        study = optuna_grid_search(self.bench_cfg)
-        optuna.logging.set_verbosity(optuna.logging.CRITICAL)
-        import warnings
-
-        # /usr/local/lib/python3.10/dist-packages/optuna/samplers/_grid.py:224: UserWarning: over_time contains a value with the type of <class 'pandas._libs.tslibs.timestamps.Timestamp'>, which is not supported by `GridSampler`. Please make sure a value is `str`, `int`, `float`, `bool` or `None` for persistent storage.
-
-        # this is not disabling the warning
-        warnings.filterwarnings(action="ignore", category=UserWarning)
-        # remove optuna gridsearch warning as we are not using their gridsearch because it has the most inexplicably terrible performance I have ever seen in my life. How can a for loop of 400 iterations start out with 100ms per loop and increase to greater than a 1000ms after 250ish iterations!?
-        study.add_trials(trials)
-        return study
-
-    def get_best_trial_params(self, canonical=False):
-        studies = self.bench_result_to_study(True)
-        out = studies.best_trials[0].params
-        if canonical:
-            return hmap_canonical_input(out)
-        return out
-
-    def get_pareto_front_params(self):
-        return [p.params for p in self.studies[0].trials]
-
-    def collect_optuna_plots(self) -> List[pn.pane.panel]:
-        """Use optuna to plot various summaries of the optimisation
-
-        Args:
-            study (optuna.Study): The study to plot
-            bench_cfg (BenchCfg): Benchmark config with options used to generate the study
-
-        Returns:
-            List[pn.pane.Pane]: A list of plots
-        """
-
-        self.studies = [self.bench_result_to_study(True)]
-        titles = ["# Analysis"]
-        if self.bench_cfg.repeats > 1:
-            self.studies.append(self.bench_result_to_study(False))
-            titles = [
-                "# Parameter Importance With Repeats",
-                "# Parameter Importance Without Repeats",
-            ]
-
-        study_repeats_pane = pn.Row()
-        for study, title in zip(self.studies, titles):
-            study_pane = pn.Column()
-            target_names = self.bench_cfg.optuna_targets()
-            param_str = []
-
-            study_pane.append(pn.pane.Markdown(title))
-
-            if len(target_names) > 1:
-                if len(target_names) <= 3:
-                    study_pane.append(
-                        plot_pareto_front(
-                            study, target_names=target_names, include_dominated_trials=False
-                        )
-                    )
-                else:
-                    print("plotting pareto front of first 3 result variables")
-                    study_pane.append(
-                        plot_pareto_front(
-                            study,
-                            targets=lambda t: (t.values[0], t.values[1], t.values[2]),
-                            target_names=target_names[:3],
-                            include_dominated_trials=False,
-                        )
-                    )
-
-                study_pane.append(param_importance(self.bench_cfg, study))
-                param_str.append(
-                    f" Number of trials on the Pareto front: {len(study.best_trials)}"
-                )
-                for t in study.best_trials:
-                    param_str.extend(summarise_trial(t, self.bench_cfg))
-
-            else:
-                # cols.append(plot_optimization_history(study)) #TODO, maybe more clever when this is plotted?
-
-                # If there is only 1 parameter then there is no point in plotting relative importance. Only worth plotting if there are multiple repeats of the same value so that you can compare the parameter to the repeat to get a sense of how much chance affects the results
-                # if bench_cfg.repeats > 1 and len(bench_cfg.input_vars) > 1: #old code, not sure if its right
-                if len(self.bench_cfg.input_vars) > 1:
-                    study_pane.append(plot_param_importances(study, target_name=target_names[0]))
-
-                param_str.extend(summarise_trial(study.best_trial, self.bench_cfg))
-
-            kwargs = {"height": 500, "scroll": True} if len(param_str) > 30 else {}
-
-            param_str = "\n".join(param_str)
-            study_pane.append(
-                pn.Row(pn.pane.Markdown(f"## Best Parameters\n```text\n{param_str}"), **kwargs),
-            )
-
-            study_repeats_pane.append(study_pane)
-
-        return study_repeats_pane
-
-# def extract_study_to_dataset(study: optuna.Study, bench_cfg: BenchCfg) -> BenchCfg:
-#     """Extract an optuna study into an xarray dataset for easy plotting
-
-#     Args:
-#         study (optuna.Study): The result of a gridsearch
-#         bench_cfg (BenchCfg): Options for the grid search
-
-#     Returns:
-#         BenchCfg: An updated config with the results included
-#     """
-#     for t in study.trials:
-#         for it, rv in enumerate(bench_cfg.result_vars):
-#             bench_cfg.ds[rv.name].loc[t.params] = t.values[it]
-#     return bench_cfg
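The removed `OptunaResult` class above seeds an Optuna study with rows that were already computed by a sweep, using `optuna.trial.create_trial` plus `study.add_trials`, and only then lets `study.optimize` run new trials. Below is a minimal standalone sketch of that same pattern; the parameter name `x`, the precomputed values, and the toy objective are illustrative and not part of bencher.

```python
# Sketch: reuse precomputed results as Optuna trials before optimising further,
# mirroring bench_results_to_optuna_trials / to_optuna_from_results above.
import optuna

distributions = {"x": optuna.distributions.FloatDistribution(0.0, 10.0)}

# Pretend these (x, score) pairs were measured by an earlier benchmark sweep.
precomputed = [(1.0, 4.2), (5.0, 1.1), (9.0, 7.8)]
trials = [
    optuna.trial.create_trial(
        params={"x": x},
        distributions=distributions,
        values=[score],  # one entry per objective direction
    )
    for x, score in precomputed
]

study = optuna.create_study(directions=["minimize"])
study.add_trials(trials)  # seed the study instead of re-evaluating old points
study.optimize(lambda t: (t.suggest_float("x", 0.0, 10.0) - 3.0) ** 2, n_trials=10)
print(study.best_trials[0].params)
```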
bencher/results/panel_result.py
@@ -1,113 +0,0 @@
-from typing import Optional, Any
-from pathlib import Path
-from functools import partial
-import panel as pn
-import xarray as xr
-from param import Parameter
-from bencher.results.bench_result_base import BenchResultBase, ReduceType
-from bencher.variables.results import (
-    ResultVideo,
-    ResultContainer,
-    ResultReference,
-    PANEL_TYPES,
-)
-
-
-class PanelResult(BenchResultBase):
-    def to_video(self, **kwargs):
-        return self.map_plots(partial(self.to_video_multi, **kwargs))
-
-    def to_video_multi(self, result_var: Parameter, **kwargs) -> Optional[pn.Column]:
-        if isinstance(result_var, (ResultVideo, ResultContainer)):
-            vid_p = []
-
-            xr_dataset = self.to_hv_dataset(ReduceType.SQUEEZE)
-
-            def to_video_da(da, **kwargs):
-                if da is not None and Path(da).exists():
-                    vid = pn.pane.Video(da, autoplay=True, **kwargs)
-                    vid.loop = True
-                    vid_p.append(vid)
-                    return vid
-                return pn.pane.Markdown(f"video does not exist {da}")
-
-            plot_callback = partial(self.ds_to_container, container=partial(to_video_da, **kwargs))
-
-            panes = self.to_panes_multi_panel(
-                xr_dataset, result_var, plot_callback=plot_callback, target_dimension=0
-            )
-
-            def play_vid(_): # pragma: no cover
-                for r in vid_p:
-                    r.paused = False
-                    r.loop = False
-
-            def pause_vid(_): # pragma: no cover
-                for r in vid_p:
-                    r.paused = True
-
-            def reset_vid(_): # pragma: no cover
-                for r in vid_p:
-                    r.paused = False
-                    r.time = 0
-
-            def loop_vid(_): # pragma: no cover
-                for r in vid_p:
-                    r.paused = False
-                    r.time = 0
-                    r.loop = True
-
-            button_names = ["Play Videos", "Pause Videos", "Loop Videos", "Reset Videos"]
-            buttom_cb = [play_vid, pause_vid, reset_vid, loop_vid]
-            buttons = pn.Row()
-
-            for name, cb in zip(button_names, buttom_cb):
-                button = pn.widgets.Button(name=name)
-                pn.bind(cb, button, watch=True)
-                buttons.append(button)
-
-            return pn.Column(buttons, panes)
-        return None
-
-    def zero_dim_da_to_val(self, da_ds: xr.DataArray | xr.Dataset) -> Any:
-        # todo this is really horrible, need to improve
-        dim = None
-        if isinstance(da_ds, xr.Dataset):
-            dim = list(da_ds.keys())[0]
-            da = da_ds[dim]
-        else:
-            da = da_ds
-
-        for k in da.coords.keys():
-            dim = k
-            break
-        if dim is None:
-            return da_ds.values.squeeze().item()
-        return da.expand_dims(dim).values[0]
-
-    def ds_to_container(
-        self, dataset: xr.Dataset, result_var: Parameter, container, **kwargs
-    ) -> Any:
-        val = self.zero_dim_da_to_val(dataset[result_var.name])
-        if isinstance(result_var, ResultReference):
-            ref = self.object_index[val]
-            val = ref.obj
-            if ref.container is not None:
-                return ref.container(val, **kwargs)
-        if container is not None:
-            return container(val, styles={"background": "white"}, **kwargs)
-        return val
-
-    def to_panes(
-        self, result_var: Parameter = None, target_dimension: int = 0, container=None, **kwargs
-    ) -> Optional[pn.pane.panel]:
-        if container is None:
-            container = pn.pane.panel
-        return self.map_plot_panes(
-            partial(self.ds_to_container, container=container),
-            hv_dataset=self.to_hv_dataset(ReduceType.SQUEEZE),
-            target_dimension=target_dimension,
-            result_var=result_var,
-            result_types=PANEL_TYPES,
-            **kwargs,
-        )
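The removed `PanelResult.to_video_multi` wires every `pn.pane.Video` it creates to a shared row of play/pause/reset buttons via `pn.bind`. A reduced sketch of that control pattern with a single video pane is shown below; `example.mp4` is a placeholder path, not a file shipped with bencher.

```python
# Sketch: one video pane with shared play/pause buttons, as in to_video_multi above.
import panel as pn

pn.extension()

video = pn.pane.Video("example.mp4", loop=True, autoplay=False, width=400)

def play(_):  # called when the Play button is clicked
    video.paused = False

def pause(_):  # called when the Pause button is clicked
    video.paused = True

play_btn = pn.widgets.Button(name="Play Videos")
pause_btn = pn.widgets.Button(name="Pause Videos")
pn.bind(play, play_btn, watch=True)
pn.bind(pause, pause_btn, watch=True)

layout = pn.Column(pn.Row(play_btn, pause_btn), video)
# layout.show()  # serve the layout in a browser to try the buttons
```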
bencher/results/plotly_result.py
@@ -1,65 +0,0 @@
-import panel as pn
-import plotly.graph_objs as go
-from typing import Optional
-import xarray as xr
-
-from param import Parameter
-
-from bencher.plotting.plot_filter import VarRange
-from bencher.results.bench_result_base import BenchResultBase, ReduceType
-from bencher.variables.results import ResultVar
-
-
-class PlotlyResult(BenchResultBase):
-    def to_volume(self, result_var: Parameter = None, **kwargs):
-        return self.filter(
-            self.to_volume_da,
-            float_range=VarRange(3, 3),
-            cat_range=VarRange(-1, 0),
-            reduce=ReduceType.REDUCE,
-            target_dimension=3,
-            result_var=result_var,
-            result_types=(ResultVar),
-            **kwargs,
-        )
-
-    def to_volume_da(
-        self, dataset: xr.Dataset, result_var: Parameter, width=600, height=600
-    ) -> Optional[pn.pane.Plotly]:
-        """Given a BenchCfg generate a 3D volume plot
-        Returns:
-            pn.pane.Plotly: A 3D volume plot in a Plotly pane
-        """
-        x = self.bench_cfg.input_vars[0]
-        y = self.bench_cfg.input_vars[1]
-        z = self.bench_cfg.input_vars[2]
-        opacity = 0.1
-        meandf = dataset[result_var.name].to_dataframe().reset_index()
-        data = [
-            go.Volume(
-                x=meandf[x.name],
-                y=meandf[y.name],
-                z=meandf[z.name],
-                value=meandf[result_var.name],
-                isomin=meandf[result_var.name].min(),
-                isomax=meandf[result_var.name].max(),
-                opacity=opacity,
-                surface_count=20,
-            )
-        ]
-
-        layout = go.Layout(
-            title=f"{result_var.name} vs ({x.name} vs {y.name} vs {z.name})",
-            width=width,
-            height=height,
-            margin=dict(t=50, b=50, r=50, l=50),
-            scene=dict(
-                xaxis_title=f"{x.name} [{x.units}]",
-                yaxis_title=f"{y.name} [{y.units}]",
-                zaxis_title=f"{z.name} [{z.units}]",
-            ),
-        )
-
-        fig = dict(data=data, layout=layout)
-
-        return pn.pane.Plotly(fig, name="volume_plotly")
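`PlotlyResult.to_volume_da` flattens the selected 3D result variable to a dataframe and hands the three coordinate columns plus the values to `go.Volume`, spanning the full value range between `isomin` and `isomax`. The sketch below reproduces that call on a synthetic scalar field so it can run without a benchmark dataset; the field itself is made up.

```python
# Sketch: a go.Volume figure over a synthetic 3D scalar field.
import numpy as np
import plotly.graph_objs as go

x, y, z = np.mgrid[0:1:10j, 0:1:10j, 0:1:10j]
value = np.sin(4 * x) * np.cos(4 * y) * z  # placeholder scalar field

fig = go.Figure(
    data=go.Volume(
        x=x.flatten(),
        y=y.flatten(),
        z=z.flatten(),
        value=value.flatten(),
        isomin=value.min(),
        isomax=value.max(),
        opacity=0.1,       # keep nested isosurfaces visible
        surface_count=20,  # number of isosurfaces between isomin and isomax
    )
)
fig.write_html("volume.html")  # or wrap in pn.pane.Plotly(fig) for a Panel report
```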
bencher/variables/inputs.py
@@ -1,193 +0,0 @@
-from enum import Enum
-from typing import List, Any
-
-import numpy as np
-from param import Integer, Number, Selector
-from bencher.variables.sweep_base import SweepBase, shared_slots
-
-
-class SweepSelector(Selector, SweepBase):
-    """A class to represent a parameter sweep over a discrete set of objects"""
-
-    __slots__ = shared_slots
-
-    def __init__(self, units: str = "ul", samples: int = None, samples_debug: int = 2, **params):
-        SweepBase.__init__(self)
-        Selector.__init__(self, **params)
-
-        self.units = units
-        if samples is None:
-            self.samples = len(self.objects)
-        else:
-            self.samples = samples
-        self.samples_debug = min(self.samples, samples_debug)
-
-    def values(self, debug=False) -> List[Any]:
-        """return all the values for a parameter sweep. If debug is true return a reduced list"""
-        return self.indices_to_samples(self.samples_debug if debug else self.samples, self.objects)
-
-
-class BoolSweep(SweepSelector):
-    """A class to represent a parameter sweep of bools"""
-
-    def __init__(
-        self, units: str = "ul", samples: int = None, samples_debug: int = 2, default=True, **params
-    ):
-        SweepSelector.__init__(
-            self,
-            units=units,
-            samples=samples,
-            samples_debug=samples_debug,
-            default=default,
-            objects=[True, False] if default else [False, True],
-            **params,
-        )
-
-
-class StringSweep(SweepSelector):
-    """A class to represent a parameter sweep of strings"""
-
-    def __init__(
-        self,
-        string_list: List[str],
-        units: str = "",
-        samples: int = None,
-        samples_debug: int = 2,
-        **params,
-    ):
-        SweepSelector.__init__(
-            self,
-            objects=string_list,
-            instantiate=True,
-            units=units,
-            samples=samples,
-            samples_debug=samples_debug,
-            **params,
-        )
-
-
-class EnumSweep(SweepSelector):
-    """A class to represent a parameter sweep of enums"""
-
-    __slots__ = shared_slots
-
-    def __init__(
-        self, enum_type: Enum | List[Enum], units="", samples=None, samples_debug=2, **params
-    ):
-        # The enum can either be an Enum type or a list of enums
-        list_of_enums = isinstance(enum_type, list)
-        selector_list = enum_type if list_of_enums else list(enum_type)
-        SweepSelector.__init__(
-            self,
-            objects=selector_list,
-            instantiate=True,
-            units=units,
-            samples=samples,
-            samples_debug=samples_debug,
-            **params,
-        )
-        if not list_of_enums: # Grab the docs from the enum type def
-            self.doc = enum_type.__doc__
-
-
-class IntSweep(Integer, SweepBase):
-    """A class to represent a parameter sweep of ints"""
-
-    __slots__ = shared_slots + ["sample_values"]
-
-    def __init__(self, units="ul", samples=None, samples_debug=2, sample_values=None, **params):
-        SweepBase.__init__(self)
-        Integer.__init__(self, **params)
-
-        self.units = units
-        self.samples_debug = samples_debug
-
-        if sample_values is None:
-            if samples is None:
-                if self.bounds is None:
-                    raise RuntimeError("You must define bounds for integer types")
-                self.samples = 1 + self.bounds[1] - self.bounds[0]
-            else:
-                self.samples = samples
-            self.sample_values = None
-        else:
-            self.sample_values = sample_values
-            self.samples = len(self.sample_values)
-            if "default" not in params:
-                self.default = sample_values[0]
-
-    def values(self, debug=False) -> List[int]:
-        """return all the values for a parameter sweep. If debug is true return a reduced list"""
-        sample_values = (
-            self.sample_values
-            if self.sample_values is not None
-            else list(range(int(self.bounds[0]), int(self.bounds[1] + 1)))
-        )
-
-        return self.indices_to_samples(self.samples_debug if debug else self.samples, sample_values)
-
-    ###THESE ARE COPIES OF INTEGER VALIDATION BUT ALSO ALLOW NUMPY INT TYPES
-    def _validate_value(self, val, allow_None):
-        if callable(val):
-            return
-
-        if allow_None and val is None:
-            return
-
-        if not isinstance(val, (int, np.integer)):
-            raise ValueError(
-                "Integer parameter %r must be an integer, " "not type %r." % (self.name, type(val))
-            )
-
-    ###THESE ARE COPIES OF INTEGER VALIDATION BUT ALSO ALLOW NUMPY INT TYPES
-    def _validate_step(self, val, step):
-        if step is not None and not isinstance(step, (int, np.integer)):
-            raise ValueError(
-                "Step can only be None or an " "integer value, not type %r" % type(step)
-            )
-
-
-class FloatSweep(Number, SweepBase):
-    """A class to represent a parameter sweep of floats"""
-
-    __slots__ = shared_slots + ["sample_values"]
-
-    def __init__(
-        self, units="ul", samples=10, samples_debug=2, sample_values=None, step=None, **params
-    ):
-        SweepBase.__init__(self)
-        Number.__init__(self, step=step, **params)
-
-        self.units = units
-        self.samples_debug = samples_debug
-
-        self.sample_values = sample_values
-
-        if sample_values is None:
-            self.samples = samples
-        else:
-            self.samples = len(self.sample_values)
-            if "default" not in params:
-                self.default = sample_values[0]
-
-    def values(self, debug=False) -> List[float]:
-        """return all the values for a parameter sweep. If debug is true return a reduced list"""
-        samps = self.samples_debug if debug else self.samples
-        if self.sample_values is None:
-            if self.step is None:
-                return np.linspace(self.bounds[0], self.bounds[1], samps)
-
-            return np.arange(self.bounds[0], self.bounds[1], self.step)
-        if debug:
-            indices = [
-                int(i)
-                for i in np.linspace(0, len(self.sample_values) - 1, self.samples_debug, dtype=int)
-            ]
-            return [self.sample_values[i] for i in indices]
-        return self.sample_values
-
-
-def box(name, center, width):
-    var = FloatSweep(default=center, bounds=(center - width, center + width))
-    var.name = name
-    return var
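The sweep classes in `inputs.py` all expose a `values()` method that turns bounds, `samples`, or an explicit `sample_values` list into the grid a benchmark iterates over. A small usage sketch against the 1.3.4 layout shown above (the variable names are illustrative):

```python
# Sketch: how FloatSweep and IntSweep expand into sample grids (bencher 1.3.4 layout).
from bencher.variables.inputs import FloatSweep, IntSweep

radius = FloatSweep(default=1.0, bounds=(0.0, 2.0), samples=5, units="m")
count = IntSweep(default=1, bounds=(1, 4), units="ul")

print(radius.values())  # 5 evenly spaced floats from 0.0 to 2.0 (np.linspace)
print(count.values())   # every integer in the bounds: 1, 2, 3, 4
```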