holobench 1.3.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. bencher/__init__.py +41 -0
  2. bencher/bench_cfg.py +462 -0
  3. bencher/bench_plot_server.py +100 -0
  4. bencher/bench_report.py +268 -0
  5. bencher/bench_runner.py +136 -0
  6. bencher/bencher.py +805 -0
  7. bencher/caching.py +51 -0
  8. bencher/example/__init__.py +0 -0
  9. bencher/example/benchmark_data.py +200 -0
  10. bencher/example/example_all.py +45 -0
  11. bencher/example/example_categorical.py +99 -0
  12. bencher/example/example_custom_sweep.py +59 -0
  13. bencher/example/example_docs.py +34 -0
  14. bencher/example/example_float3D.py +101 -0
  15. bencher/example/example_float_cat.py +98 -0
  16. bencher/example/example_floats.py +89 -0
  17. bencher/example/example_floats2D.py +93 -0
  18. bencher/example/example_holosweep.py +104 -0
  19. bencher/example/example_holosweep_objects.py +111 -0
  20. bencher/example/example_holosweep_tap.py +144 -0
  21. bencher/example/example_image.py +82 -0
  22. bencher/example/example_levels.py +181 -0
  23. bencher/example/example_pareto.py +53 -0
  24. bencher/example/example_sample_cache.py +85 -0
  25. bencher/example/example_sample_cache_context.py +116 -0
  26. bencher/example/example_simple.py +134 -0
  27. bencher/example/example_simple_bool.py +34 -0
  28. bencher/example/example_simple_cat.py +47 -0
  29. bencher/example/example_simple_float.py +38 -0
  30. bencher/example/example_strings.py +46 -0
  31. bencher/example/example_time_event.py +62 -0
  32. bencher/example/example_video.py +124 -0
  33. bencher/example/example_workflow.py +189 -0
  34. bencher/example/experimental/example_bokeh_plotly.py +38 -0
  35. bencher/example/experimental/example_hover_ex.py +45 -0
  36. bencher/example/experimental/example_hvplot_explorer.py +39 -0
  37. bencher/example/experimental/example_interactive.py +75 -0
  38. bencher/example/experimental/example_streamnd.py +49 -0
  39. bencher/example/experimental/example_streams.py +36 -0
  40. bencher/example/experimental/example_template.py +40 -0
  41. bencher/example/experimental/example_updates.py +84 -0
  42. bencher/example/experimental/example_vector.py +84 -0
  43. bencher/example/meta/example_meta.py +171 -0
  44. bencher/example/meta/example_meta_cat.py +25 -0
  45. bencher/example/meta/example_meta_float.py +23 -0
  46. bencher/example/meta/example_meta_levels.py +26 -0
  47. bencher/example/optuna/example_optuna.py +78 -0
  48. bencher/example/shelved/example_float2D_scatter.py +109 -0
  49. bencher/example/shelved/example_float3D_cone.py +96 -0
  50. bencher/example/shelved/example_kwargs.py +63 -0
  51. bencher/job.py +184 -0
  52. bencher/optuna_conversions.py +168 -0
  53. bencher/plotting/__init__.py +0 -0
  54. bencher/plotting/plot_filter.py +110 -0
  55. bencher/plotting/plt_cnt_cfg.py +74 -0
  56. bencher/results/__init__.py +0 -0
  57. bencher/results/bench_result.py +80 -0
  58. bencher/results/bench_result_base.py +405 -0
  59. bencher/results/float_formatter.py +44 -0
  60. bencher/results/holoview_result.py +592 -0
  61. bencher/results/optuna_result.py +354 -0
  62. bencher/results/panel_result.py +113 -0
  63. bencher/results/plotly_result.py +65 -0
  64. bencher/utils.py +148 -0
  65. bencher/variables/inputs.py +193 -0
  66. bencher/variables/parametrised_sweep.py +206 -0
  67. bencher/variables/results.py +176 -0
  68. bencher/variables/sweep_base.py +167 -0
  69. bencher/variables/time.py +74 -0
  70. bencher/video_writer.py +30 -0
  71. bencher/worker_job.py +40 -0
  72. holobench-1.3.6.dist-info/METADATA +85 -0
  73. holobench-1.3.6.dist-info/RECORD +74 -0
  74. holobench-1.3.6.dist-info/WHEEL +5 -0
@@ -0,0 +1,109 @@
1
+ # import random
2
+
3
+ # import bencher as bch
4
+
5
+
6
+ # class GaussianDist(bch.ParametrizedSweep):
7
+ # """A class to represent a gaussian distribution."""
8
+
9
+ # mean = bch.FloatSweep(
10
+ # default=0, bounds=[-1.0, 1.0], doc="mean of the gaussian distribution", samples=3
11
+ # )
12
+ # sigma = bch.FloatSweep(
13
+ # default=1, bounds=[0, 1.0], doc="standard deviation of gaussian distribution", samples=4
14
+ # )
15
+
16
+
17
+ # class Example2DGaussianResult(bch.ParametrizedSweep):
18
+ # """A class to represent the properties of a volume sample."""
19
+
20
+ # gauss_x = bch.ResultVar("m", doc="x value of the 2D gaussian")
21
+ # gauss_y = bch.ResultVar("m", doc="y value of the 2D gaussian")
22
+
23
+ # point2D = bch.ResultVec(2, "m", doc="2D vector of the point")
24
+
25
+
26
+ # def bench_fn(dist: GaussianDist) -> Example2DGaussianResult:
27
+ # """This function samples a point from a gaussian distribution.
28
+
29
+ # Args:
30
+ # dist (GaussianDist): Sample point
31
+
32
+ # Returns:
33
+ # Example2DGaussianResult: Value at that point
34
+ # """
35
+ # output = Example2DGaussianResult()
36
+
37
+ # output.gauss_x = random.gauss(dist.mean, dist.sigma)
38
+ # output.gauss_y = random.gauss(dist.mean, dist.sigma)
39
+ # output.point2D = [output.gauss_x, output.gauss_y]
40
+
41
+ # return output
42
+
43
+
44
+ # def example_floats2D_scatter(run_cfg: bch.BenchRunCfg) -> bch.Bench:
45
+ # """Example of how to plot 2D scatter results sampled from a gaussian distribution
46
+
47
+ # Args:
48
+ # run_cfg (BenchRunCfg): configuration of how to perform the param sweep
49
+
50
+ # Returns:
51
+ # Bench: results of the parameter sweep
52
+ # """
53
+ # bench = bch.Bench(
54
+ # "Bencher_Example_Floats_Scatter", bench_fn, GaussianDist, plot_lib=bch.PlotLibrary.default()
55
+ # )
56
+
57
+ # bench.plot_sweep(
58
+ # result_vars=[
59
+ # Example2DGaussianResult.param.point2D,
60
+ # Example2DGaussianResult.param.gauss_x,
61
+ # Example2DGaussianResult.param.gauss_y,
62
+ # ],
63
+ # title="Float 2D Scatter Example",
64
+ # run_cfg=run_cfg,
65
+ # )
66
+
67
+ # bench.plot_sweep(
68
+ # input_vars=[GaussianDist.param.mean],
69
+ # result_vars=[
70
+ # Example2DGaussianResult.param.point2D,
71
+ # Example2DGaussianResult.param.gauss_x,
72
+ # Example2DGaussianResult.param.gauss_y,
73
+ # ],
74
+ # title="Float 2D Scatter With Changing Mean",
75
+ # run_cfg=run_cfg,
76
+ # )
77
+
78
+ # bench.plot_sweep(
79
+ # input_vars=[GaussianDist.param.sigma],
80
+ # result_vars=[
81
+ # Example2DGaussianResult.param.point2D,
82
+ # Example2DGaussianResult.param.gauss_x,
83
+ # Example2DGaussianResult.param.gauss_y,
84
+ # ],
85
+ # title="Float 2D Scatter With Changing Sigma",
86
+ # run_cfg=run_cfg,
87
+ # )
88
+
89
+ # # future work
90
+ # # bench.plot_sweep(
91
+ # # input_vars=[GaussianDist.param.mean, GaussianDist.param.sigma],
92
+ # # result_vars=[
93
+ # # GaussianResult.param.point2D,
94
+ # # GaussianResult.param.gauss_x,
95
+ # # GaussianResult.param.gauss_y,
96
+ # # ],
97
+ # # title="Float 2D Scatter With Changing Sigma",
98
+ # # run_cfg=run_cfg,
99
+ # # )
100
+
101
+ # return bench
102
+
103
+
104
+ # if __name__ == "__main__":
105
+ # ex_run_cfg = bch.BenchRunCfg()
106
+ # ex_run_cfg.repeats = 50
107
+ # ex_run_cfg.over_time = True
108
+ # # ex_run_cfg.clear_history = True
109
+ # example_floats2D_scatter(ex_run_cfg).report.show()
@@ -0,0 +1,96 @@
1
+ # import numpy as np
2
+
3
+ # import bencher as bch
4
+
5
+
6
+ # class VolumeSample(bch.ParametrizedSweep):
7
+ # """A class to represent a 3D point in space."""
8
+
9
+ # pos_samples = 1
10
+
11
+ # x = bch.FloatSweep(
12
+ # default=0, bounds=[-1.0, 1.0], samples=pos_samples, doc="x coordinate of the sample volume"
13
+ # )
14
+ # y = bch.FloatSweep(
15
+ # default=0, bounds=[-1.0, 1.0], samples=pos_samples, doc="y coordinate of the sample volume"
16
+ # )
17
+ # z = bch.FloatSweep(
18
+ # default=0, bounds=[-1.0, 1.0], samples=pos_samples, doc="z coordinate of the sample volume"
19
+ # )
20
+
21
+ # vec_samples = 5
22
+
23
+ # vec_x = bch.FloatSweep(
24
+ # default=0, bounds=[-1.0, 1.0], samples=vec_samples, doc="x coordinate of the sample volume"
25
+ # )
26
+ # vec_y = bch.FloatSweep(
27
+ # default=0, bounds=[-1.0, 1.0], samples=vec_samples, doc="y coordinate of the sample volume"
28
+ # )
29
+ # vec_z = bch.FloatSweep(
30
+ # default=0, bounds=[-1.0, 1.0], samples=vec_samples, doc="z coordinate of the sample volume"
31
+ # )
32
+
33
+
34
+ # class VolumeResult(bch.ParametrizedSweep):
35
+ # """A class to represent the properties of a volume sample."""
36
+
37
+ # vec_dir = bch.ResultVec(3, "vec", doc="A vector field with an interesting shape")
38
+
39
+
40
+ # def bench_fn(point: VolumeSample, normalise=False) -> VolumeResult:
41
+ # """This function takes a sample point as input and returns its (vec_x, vec_y, vec_z) vector, optionally normalised to unit length.
42
+
43
+ # Args:
44
+ # point (VolumeSample): Sample point
45
+
46
+ # Returns:
47
+ # VolumeResult: Value at that point
48
+ # """
49
+ # output = VolumeResult()
50
+
51
+ # vec = np.array([point.vec_x, point.vec_y, point.vec_z])
52
+
53
+ # if normalise:
54
+ # norm = np.linalg.norm(vec)
55
+ # if norm > 0:
56
+ # vec /= norm
57
+
58
+ # output.vec_dir = list(vec)
59
+ # return output
60
+
61
+
62
+ # def example_cone(run_cfg: bch.BenchRunCfg) -> bch.Bench:
63
+ # """Example of how to perform a 3D floating point parameter sweep
64
+
65
+ # Args:
66
+ # run_cfg (BenchRunCfg): configuration of how to perform the param sweep
67
+
68
+ # Returns:
69
+ # Bench: results of the parameter sweep
70
+ # """
71
+ # bench = bch.Bench("Bencher_Example_Cone", bench_fn, VolumeSample)
72
+
73
+ # bench.plot_sweep(
74
+ # input_vars=[
75
+ # VolumeSample.param.x,
76
+ # VolumeSample.param.y,
77
+ # VolumeSample.param.z,
78
+ # VolumeSample.param.vec_x,
79
+ # VolumeSample.param.vec_y,
80
+ # VolumeSample.param.vec_z,
81
+ # ],
82
+ # result_vars=[
83
+ # VolumeResult.param.vec_dir,
84
+ # ],
85
+ # title="Float 3D cone Example",
86
+ # description="""This example shows how to sample 3 floating point variables and plot a vector field representation of the results. The benchmark function returns the distance to the origin""",
87
+ # run_cfg=run_cfg,
88
+ # )
89
+
90
+ # return bench
91
+
92
+
93
+ # if __name__ == "__main__":
94
+ # ex_run_cfg = bch.BenchRunCfg()
95
+ # ex_run_cfg.use_cache = True
96
+ # example_cone(ex_run_cfg).report.show()
@@ -0,0 +1,63 @@
1
+ # import math
2
+
3
+ # import bencher as bch
4
+
5
+
6
+ # def bench_function(
7
+ # theta: float = 0,
8
+ # offset: float = 0,
9
+ # scale: float = 1.0,
10
+ # trig_func: str = "sin",
11
+ # **kwargs, # pylint: disable=unused-argument
12
+ # ) -> dict:
13
+ # """All the other examples use classes and parameters to define the inputs and outputs to the function. However, that makes the code less flexible when integrating with other systems, so this example shows a more basic interface that accepts and returns dictionaries. The classes still need to be defined, however, because that is how the sweep and plotting settings are calculated"""
14
+ # output = {}
15
+
16
+ # if trig_func == "sin":
17
+ # output["voltage"] = offset + math.sin(theta) * scale
18
+ # elif trig_func == "cos":
19
+ # output["voltage"] = offset + math.cos(theta) * scale
20
+
21
+ # return output
22
+
23
+
24
+ # class InputCfg(bch.ParametrizedSweep):
25
+ # """This class is used to define the default values and bounds of the variables to benchmark."""
26
+
27
+ # theta = bch.FloatSweep(
28
+ # default=0.0,
29
+ # bounds=[0.0, 6.0],
30
+ # doc="Input angle to the trig function",
31
+ # units="rad",
32
+ # samples=10,
33
+ # )
34
+
35
+ # offset = bch.FloatSweep(
36
+ # default=0.0,
37
+ # bounds=[0.0, 3.0],
38
+ # doc="Add an offset voltage to the result of the trig function",
39
+ # units="v",
40
+ # samples=5,
41
+ # )
42
+
43
+ # trig_func = bch.StringSweep(["sin", "cos"], doc="Select which trigonometric function to use")
44
+
45
+
46
+ # class OutputVoltage(bch.ParametrizedSweep):
47
+ # voltage = bch.ResultVar(units="v", doc="Output voltage")
48
+
49
+
50
+ # if __name__ == "__main__":
51
+ # # pass the objective function you have defined to bencher. The other examples pass the InputCfg type, but this benchmark function accepts a kwargs dictionary so you don't need to pass the inputCfg type.
52
+ # bench = bch.Bench("Bencher_Example_Categorical", bench_function)
53
+
54
+ # # Bencher needs to know the metadata of the variable in order to automatically sweep and plot it, so it is passed by using param's metadata syntax. InputCfg.param.* is how to access the metadata defined in the class description.
55
+ # bench.plot_sweep(
56
+ # input_vars=[InputCfg.param.theta, InputCfg.param.offset],
57
+ # result_vars=[OutputVoltage.param.voltage],
58
+ # title="Example with kwarg inputs and dict output",
59
+ # description=bench_function.__doc__,
60
+ # )
61
+
62
+ # # launch web server and view
63
+ # bench.report.show()
bencher/job.py ADDED
@@ -0,0 +1,184 @@
1
+ from __future__ import annotations
2
+ from typing import Callable
3
+ from sortedcontainers import SortedDict
4
+ import logging
5
+ from diskcache import Cache
6
+ from concurrent.futures import Future, ProcessPoolExecutor
7
+ from .utils import hash_sha1
8
+ from strenum import StrEnum
9
+ from enum import auto
10
+
11
+ try:
12
+ from scoop import futures as scoop_future_executor
13
+ except ImportError as e:
14
+ logging.warning(e.msg)
15
+ scoop_future_executor = None
16
+
17
+
18
class Job:
    """A unit of work: a callable together with the keyword arguments to call it with.

    Args:
        job_id: Identifier of the job.
        function: Callable to execute for this job.
        job_args: Keyword arguments that are passed to ``function``.
        job_key: Optional explicit key for the job. When ``None`` the key is the
            hash of the job arguments sorted by argument name.
        tag: Tag stored with the job. Defaults to the empty string.
    """

    def __init__(
        self, job_id: str, function: Callable, job_args: dict, job_key=None, tag=""
    ) -> None:
        self.job_id = job_id
        self.function = function
        self.job_args = job_args
        # Sorting the arguments makes the derived key independent of insertion order.
        self.job_key = (
            job_key
            if job_key is not None
            else hash_sha1(tuple(SortedDict(self.job_args).items()))
        )
        self.tag = tag
30
+
31
+
32
+ # @dataclass
33
class JobFuture:
    """Pairs a job with either an already available result or a pending future.

    Args:
        job: The job this result belongs to.
        res: Result of the job, if it is already known.
        future: Future that will produce the result, if the job is still running.
        cache: Optional cache object; when given, the result is written to it
            under the job's key and tag the moment ``result`` is called.
    """

    def __init__(self, job: Job, res: dict = None, future: Future = None, cache=None) -> None:
        self.job = job
        self.res = res
        self.future = future
        # at least one of result / future has to be supplied
        assert not (self.res is None and self.future is None)
        self.cache = cache

    def result(self):
        """Return the job result, waiting on the future first when there is one."""
        pending = self.future
        if pending is not None:
            self.res = pending.result()
        outcome = self.res
        if outcome is not None and self.cache is not None:
            self.cache.set(self.job.job_key, outcome, tag=self.job.tag)
        return outcome
48
+
49
+
50
def run_job(job: Job) -> dict:
    """Call ``job.function`` with ``job.job_args`` as keyword arguments and return what it returns."""
    return job.function(**job.job_args)
53
+
54
+
55
+ class Executors(StrEnum):
56
+ SERIAL = auto() # slow but reliable
57
+ MULTIPROCESSING = auto() # breaks for large number of futures
58
+ SCOOP = auto() # requires running with python -m scoop your_file.py
59
+ # THREADS=auto() #not that useful as most bench code is cpu bound
60
+
61
+ @staticmethod
62
+ def factory(provider: Executors) -> Future():
63
+ providers = {
64
+ Executors.SERIAL: None,
65
+ Executors.MULTIPROCESSING: ProcessPoolExecutor(),
66
+ Executors.SCOOP: scoop_future_executor,
67
+ }
68
+ return providers[provider]
69
+
70
+
71
class FutureCache:
    """Unified interface for running jobs serially or in parallel, with optional result caching.

    Args:
        executor: Member of ``Executors`` selecting how submitted jobs are run.
        overwrite: When True, results already in the cache are ignored and the
            job is run again.
        cache_name: Name of the cache directory created under ``cachedir/``.
        tag_index: Forwarded to the ``Cache`` constructor.
        size_limit: Forwarded to the ``Cache`` constructor and reported by ``stats``.
        use_cache: When False no cache is created and results are not stored.
    """

    def __init__(
        self,
        executor=Executors.SERIAL,
        overwrite: bool = True,
        cache_name: str = "fcache",
        tag_index: bool = True,
        size_limit: int = int(20e9),  # 20 GB
        use_cache=True,
    ):
        self.executor = Executors.factory(executor)
        if use_cache:
            self.cache = Cache(f"cachedir/{cache_name}", tag_index=tag_index, size_limit=size_limit)
            logging.info(f"cache dir: {self.cache.directory}")
        else:
            self.cache = None

        self.overwrite = overwrite
        self.call_count = 0
        self.size_limit = size_limit

        self.worker_wrapper_call_count = 0
        self.worker_fn_call_count = 0
        self.worker_cache_call_count = 0

    def submit(self, job: Job) -> JobFuture:
        """Run ``job`` (or load its cached result) and return a ``JobFuture`` for it.

        A cached result is only used when a cache exists, ``overwrite`` is False
        and the job's key is present in the cache. Otherwise the job is handed
        to the executor, or run immediately when there is no executor.
        """
        self.worker_wrapper_call_count += 1

        if self.cache is not None:
            if not self.overwrite and job.job_key in self.cache:
                logging.info(f"Found job: {job.job_id} in cache, loading...")
                # logging.info(f"Found key: {job.job_key} in cache")
                self.worker_cache_call_count += 1
                # no cache is passed on: the value is already stored
                return JobFuture(
                    job=job,
                    res=self.cache[job.job_key],
                )

        self.worker_fn_call_count += 1

        if self.executor is not None:
            self.overwrite_msg(job, " starting parallel job...")
            return JobFuture(
                job=job,
                future=self.executor.submit(run_job, job),
                cache=self.cache,
            )
        self.overwrite_msg(job, " starting serial job...")
        return JobFuture(
            job=job,
            res=run_job(job),
            cache=self.cache,
        )

    def overwrite_msg(self, job: Job, suffix: str) -> None:
        """Log whether ``job`` overwrites a cache entry or was not found, followed by ``suffix``."""
        msg = "OVERWRITING" if self.overwrite else "NOT in"
        logging.info(f"{job.job_id} {msg} cache{suffix}")

    def clear_call_counts(self) -> None:
        """Clear the worker and cache call counts, to help debug and assert caching is happening properly"""
        self.worker_wrapper_call_count = 0
        self.worker_fn_call_count = 0
        self.worker_cache_call_count = 0

    def clear_cache(self) -> None:
        """Remove every entry from the cache, if there is one."""
        # explicit None check: an empty Cache has length 0 and is therefore falsy
        if self.cache is not None:
            self.cache.clear()

    def clear_tag(self, tag: str) -> None:
        """Evict all cache entries stored under ``tag``; does nothing without a cache."""
        if self.cache is None:
            return
        logging.info(f"clearing the sample cache for tag: {tag}")
        removed_vals = self.cache.evict(tag)
        logging.info(f"removed: {removed_vals} items from the cache")

    def close(self) -> None:
        """Close the cache and shut down the executor, whichever of them exist."""
        # explicit None check so that an empty (falsy) cache is still closed
        if self.cache is not None:
            self.cache.close()
        if self.executor is not None:
            self.executor.shutdown()

    # def __del__(self):
    #     self.close()

    def stats(self) -> str:
        """Log the call counters and return a cache size summary ("" when there is no cache)."""
        logging.info(f"job calls: {self.worker_wrapper_call_count}")
        logging.info(f"cache calls: {self.worker_cache_call_count}")
        logging.info(f"worker calls: {self.worker_fn_call_count}")
        # explicit None check so that an empty cache still reports its size
        if self.cache is not None:
            return f"cache size :{int(self.cache.volume() / 1000000)}MB / {int(self.size_limit/1000000)}MB"
        return ""
162
+
163
+
164
class JobFunctionCache(FutureCache):
    """A ``FutureCache`` bound to a single function, so jobs can be submitted with just kwargs.

    Args:
        function: The callable every submitted job will execute.
        overwrite: Forwarded to ``FutureCache``.
        executor: Either a member of ``Executors`` or a boolean. ``False`` (the
            default) selects ``Executors.SERIAL`` and ``True`` selects
            ``Executors.MULTIPROCESSING``.
        cache_name: Forwarded to ``FutureCache``.
        tag_index: Forwarded to ``FutureCache``.
        size_limit: Forwarded to ``FutureCache``.
    """

    def __init__(
        self,
        function: Callable,
        overwrite=False,
        executor: Executors | bool = False,
        cache_name: str = "fcache",
        tag_index: bool = True,
        size_limit: int = int(100e8),
    ):
        # Executors.factory only accepts Executors members; passing the boolean
        # default straight through raised KeyError, so translate booleans first.
        if isinstance(executor, bool):
            executor = Executors.MULTIPROCESSING if executor else Executors.SERIAL
        super().__init__(
            executor=executor,
            cache_name=cache_name,
            tag_index=tag_index,
            size_limit=size_limit,
            overwrite=overwrite,
        )
        self.function = function

    def call(self, **kwargs) -> JobFuture:
        """Submit a job that calls the bound function with ``kwargs``."""
        # NOTE(review): call_count is used as the job id but nothing visible here
        # increments it, so every job gets the same id - confirm intent.
        return self.submit(Job(self.call_count, self.function, kwargs))
@@ -0,0 +1,168 @@
1
+ from typing import List
2
+
3
+ import optuna
4
+ import panel as pn
5
+ import param
6
+ from optuna.visualization import (
7
+ plot_param_importances,
8
+ plot_pareto_front,
9
+ plot_optimization_history,
10
+ )
11
+
12
+ from bencher.bench_cfg import BenchCfg
13
+
14
+
15
+ from bencher.variables.inputs import IntSweep, FloatSweep, StringSweep, EnumSweep, BoolSweep
16
+ from bencher.variables.time import TimeSnapshot, TimeEvent
17
+
18
+ from bencher.variables.parametrised_sweep import ParametrizedSweep
19
+
20
+
21
+ # BENCH_CFG
22
def optuna_grid_search(bench_cfg: BenchCfg) -> optuna.Study:
    """Create an optuna study that grid-searches the variables of a benchmark.

    Args:
        bench_cfg (BenchCfg): setting for grid search

    Returns:
        optuna.Study: study configured with a grid sampler over all variables
    """
    # one grid axis per variable, keyed by the variable name
    grid = {var.name: var.values(bench_cfg.debug) for var in bench_cfg.all_vars}
    # one optimisation direction per optuna target
    target_directions = [target.direction for target in bench_cfg.optuna_targets(True)]

    return optuna.create_study(
        sampler=optuna.samplers.GridSampler(grid),
        directions=target_directions,
        study_name=bench_cfg.title,
    )
44
+
45
+
46
+ # BENCH_CFG
47
def param_importance(bench_cfg: BenchCfg, study: optuna.Study) -> pn.Column:
    """Build a column of parameter importance plots, one per optuna target.

    Args:
        bench_cfg (BenchCfg): supplies the optuna targets to plot
        study (optuna.Study): study whose trials are analysed

    Returns:
        pn.Column: a titled importance plot for every target
    """
    col_importance = pn.Column()
    for idx, tgt in enumerate(bench_cfg.optuna_targets()):
        col_importance.append(
            pn.Column(
                pn.pane.Markdown(f"## Parameter importance for: {tgt}"),
                plot_param_importances(
                    study,
                    # pick the value that belongs to this target (bound via a
                    # default argument) instead of always using the first one
                    target=lambda t, idx=idx: t.values[idx],
                    target_name=tgt,
                ),
            )
        )
    return col_importance
57
+
58
+
59
+ # BENCH_CFG
60
def summarise_trial(trial: optuna.trial, bench_cfg: BenchCfg) -> List[str]:
    """Produce a line-by-line text summary of a trial's inputs and results.

    Args:
        trial (optuna.trial): trial to summarise
        bench_cfg (BenchCfg): supplies the names of the result targets

    Returns:
        List[str]: Summary of trial, one entry per line
    """
    sep = " "
    nested = f"{sep}{sep}"

    lines = [f"Trial id:{trial.number}:", f"{sep}Inputs:"]
    lines.extend(f"{nested}{name}:{value}" for name, value in trial.params.items())
    lines.append(f"{sep}Results:")
    # the n-th target is reported with the n-th value of the trial
    lines.extend(
        f"{nested}{target}:{trial.values[idx]}"
        for idx, target in enumerate(bench_cfg.optuna_targets())
    )
    return lines
80
+
81
+
82
def sweep_var_to_optuna_dist(var: param.Parameter) -> optuna.distributions.BaseDistribution:
    """Map a sweep variable onto the optuna distribution that represents it.

    Args:
        var (param.Parameter): A sweep var

    Raises:
        ValueError: Unsupported var type

    Returns:
        optuna.distributions.BaseDistribution: Optuna representation of a sweep var
    """
    dists = optuna.distributions
    kind = type(var)

    # numeric sweeps become ranges spanning the variable's bounds
    if kind == IntSweep:
        low, high = var.bounds[0], var.bounds[1]
        return dists.IntDistribution(low, high)
    if kind == FloatSweep:
        low, high = var.bounds[0], var.bounds[1]
        return dists.FloatDistribution(low, high)

    # discrete sweeps become categorical choices
    if kind in (EnumSweep, StringSweep):
        return dists.CategoricalDistribution(var.objects)
    if kind == BoolSweep:
        return dists.CategoricalDistribution([False, True])

    if kind == TimeSnapshot:
        # time snapshots are represented as a float range starting at 0
        return dists.FloatDistribution(0, 1e20)

    raise ValueError(f"This input type {kind} is not supported")
113
+
114
+
115
def sweep_var_to_suggest(iv: ParametrizedSweep, trial: optuna.trial) -> object:
    """Ask an optuna trial to suggest a value for a sweep variable.

    Args:
        iv (ParametrizedSweep): A parametrized sweep input variable
        trial (optuna.trial): Optuna trial used to define the sample

    Raises:
        ValueError: Unsupported var type

    Returns:
        Any: A sampled variable (can be any type), or None for time variables,
        which are not sampled through optuna
    """
    iv_type = type(iv)

    if iv_type == IntSweep:
        return trial.suggest_int(iv.name, iv.bounds[0], iv.bounds[1])
    if iv_type == FloatSweep:
        return trial.suggest_float(iv.name, iv.bounds[0], iv.bounds[1])
    if iv_type in (EnumSweep, StringSweep):
        return trial.suggest_categorical(iv.name, iv.objects)
    if iv_type in (TimeSnapshot, TimeEvent):
        # optuna does not like time: skip the variable. This branch used to be a
        # bare `pass`, which fell through to the ValueError below.
        return None
    if iv_type == BoolSweep:
        return trial.suggest_categorical(iv.name, [True, False])
    raise ValueError(f"This input type {iv_type} is not supported")
141
+
142
+
143
def cfg_from_optuna_trial(
    trial: optuna.trial, bench_cfg: BenchCfg, cfg_type: ParametrizedSweep
) -> ParametrizedSweep:
    """Build a ``cfg_type`` instance whose input variables are sampled from an optuna trial.

    Args:
        trial (optuna.trial): trial that suggests the values
        bench_cfg (BenchCfg): supplies the input and meta variables
        cfg_type (ParametrizedSweep): class to instantiate

    Returns:
        ParametrizedSweep: the populated instance
    """
    sampled_cfg = cfg_type()
    for input_var in bench_cfg.input_vars:
        suggestion = sweep_var_to_suggest(input_var, trial)
        sampled_cfg.param.set_param(input_var.name, suggestion)
    # meta variables are sampled as well, but their values are not stored on the cfg
    for meta_var in bench_cfg.meta_vars:
        sweep_var_to_suggest(meta_var, trial)
    return sampled_cfg
152
+
153
+
154
def summarise_optuna_study(study: optuna.study.Study) -> pn.pane.panel:
    """Summarise an optuna study in a panel format.

    Args:
        study (optuna.study.Study): the study to summarise

    Returns:
        A panel column holding the optimisation history, the parameter
        importances, the pareto front (when it can be plotted) and a text
        summary of the best value and parameters.
    """
    import logging  # function-scope import keeps this block self-contained

    row = pn.Column(name="Optimisation Results")
    row.append(plot_optimization_history(study))
    row.append(plot_param_importances(study))
    try:
        row.append(plot_pareto_front(study))
    except Exception as err:  # pylint: disable=broad-except
        # Best effort: the pareto front is optional, so a failure to plot it must
        # not abort the summary - but record why it was skipped instead of
        # discarding the error silently.
        logging.debug("Skipping pareto front plot: %s", err)

    row.append(
        pn.pane.Markdown(f"```\nBest value: {study.best_value}\nParams: {study.best_params}```")
    )

    return row
File without changes