gpjax 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -19,7 +19,6 @@ from beartype.roar import BeartypeDecorHintPep585DeprecationWarning
19
19
  filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning)
20
20
 
21
21
  from gpjax import (
22
- decision_making,
23
22
  gps,
24
23
  integrators,
25
24
  kernels,
@@ -40,11 +39,10 @@ __license__ = "MIT"
40
39
  __description__ = "Didactic Gaussian processes in JAX"
41
40
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
41
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.9.4"
42
+ __version__ = "0.10.0"
44
43
 
45
44
  __all__ = [
46
45
  "base",
47
- "decision_making",
48
46
  "gps",
49
47
  "integrators",
50
48
  "kernels",
gpjax/citation.py CHANGED
@@ -10,11 +10,6 @@ from beartype.typing import (
10
10
  )
11
11
  from jaxlib.xla_extension import PjitFunction
12
12
 
13
- from gpjax.decision_making.test_functions import (
14
- Forrester,
15
- LogarithmicGoldsteinPrice,
16
- )
17
- from gpjax.decision_making.utility_functions import ThompsonSampling
18
13
  from gpjax.kernels import (
19
14
  RFF,
20
15
  ArcCosine,
@@ -149,41 +144,3 @@ def _(tree) -> PaperCitation:
149
144
  booktitle="Advances in neural information processing systems",
150
145
  citation_type="article",
151
146
  )
152
-
153
-
154
- ####################
155
- # Decision making citations
156
- ####################
157
- @cite.register(ThompsonSampling)
158
- def _(tree) -> PaperCitation:
159
- return PaperCitation(
160
- citation_key="wilson2020efficiently",
161
- title="Efficiently sampling functions from Gaussian process posteriors",
162
- authors="Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and Mostowsky, Peter and Deisenroth, Marc",
163
- year="2020",
164
- booktitle="International Conference on Machine Learning",
165
- citation_type="article",
166
- )
167
-
168
-
169
- @cite.register(Forrester)
170
- def _(tree) -> BookCitation:
171
- return BookCitation(
172
- citation_key="forrester2008engineering",
173
- authors="Forrester, Alexander and Sobester, Andras and Keane, Andy",
174
- title="Engineering design via surrogate modelling: a practical guide",
175
- year="2008",
176
- publisher="John Wiley & Sons",
177
- )
178
-
179
-
180
- @cite.register(LogarithmicGoldsteinPrice)
181
- def _(tree) -> PaperCitation:
182
- return PaperCitation(
183
- citation_key="picheny2013benchmark",
184
- authors="Picheny, Victor and Wagner, Tobias and Ginsbourger, David",
185
- title="A benchmark of kriging-based infill criteria for noisy optimization",
186
- year="2013",
187
- booktitle="Structural and multidisciplinary optimization",
188
- citation_type="article",
189
- )
gpjax/distributions.py CHANGED
@@ -162,7 +162,9 @@ class GaussianDistribution(tfd.Distribution):
162
162
 
163
163
  return vmap(affine_transformation)(Z)
164
164
 
165
- def sample(self, seed: KeyArray, sample_shape: Tuple[int, ...]): # pylint: disable=useless-super-delegation
165
+ def sample(
166
+ self, seed: KeyArray, sample_shape: Tuple[int, ...]
167
+ ): # pylint: disable=useless-super-delegation
166
168
  r"""See `Distribution.sample`."""
167
169
  return self._sample_n(
168
170
  seed, sample_shape[0]
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.9.4
3
+ Version: 0.10.0
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
7
7
  Project-URL: Source, https://github.com/JaxGaussianProcesses/GPJax
8
8
  Author-email: Thomas Pinder <tompinder@live.co.uk>
9
- License-Expression: Apache-2.0
10
- License-File: LICENSE
9
+ License: MIT
10
+ License-File: LICENSE.txt
11
11
  Keywords: gaussian-processes jax machine-learning bayesian
12
12
  Classifier: Development Status :: 4 - Beta
13
13
  Classifier: Programming Language :: Python
@@ -18,13 +18,12 @@ Classifier: Programming Language :: Python :: Implementation :: CPython
18
18
  Classifier: Programming Language :: Python :: Implementation :: PyPy
19
19
  Requires-Python: <3.13,>=3.10
20
20
  Requires-Dist: beartype>0.16.1
21
- Requires-Dist: cola-ml==0.0.5
22
- Requires-Dist: flax<0.10.0
23
- Requires-Dist: jax<0.4.28
24
- Requires-Dist: jaxlib<0.4.28
25
- Requires-Dist: jaxopt==0.8.2
21
+ Requires-Dist: cola-ml>=0.0.7
22
+ Requires-Dist: flax>=0.10.0
23
+ Requires-Dist: jax>=0.5.0
24
+ Requires-Dist: jaxlib>=0.5.0
26
25
  Requires-Dist: jaxtyping>0.2.10
27
- Requires-Dist: numpy<2.0.0
26
+ Requires-Dist: numpy>=2.0.0
28
27
  Requires-Dist: optax>0.2.1
29
28
  Requires-Dist: tensorflow-probability>=0.24.0
30
29
  Requires-Dist: tqdm>4.66.2
@@ -1,7 +1,7 @@
1
- gpjax/__init__.py,sha256=f1Sl-8Oz6YuEueKxvzIAL0iH_9b9xGzQv07tddS5wto,1697
2
- gpjax/citation.py,sha256=R4Pmvjt0ndA0avEDSvIbxDxKapkRRYXWX7RRWBvZCRQ,5306
1
+ gpjax/__init__.py,sha256=LeAdMRx9XYvLf6csLhCIv6IHnDbAFB9rP--TYXECgz0,1654
2
+ gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
- gpjax/distributions.py,sha256=zxkSEZIlTg0PHvvgj0BQuIFEg-ugx6_NkEwSsbqWUM0,9325
4
+ gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
5
5
  gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
6
6
  gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
@@ -13,20 +13,6 @@ gpjax/parameters.py,sha256=Z4Wy3gEzPZG23-dtqC437_ZWnd_sPe9LcLCKn21ZBvA,4886
13
13
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
14
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
15
  gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
16
- gpjax/decision_making/__init__.py,sha256=SDuPQl80lJ7nhfRsiB_7c22wCMiQO5ehSNohxUGnB7w,2170
17
- gpjax/decision_making/decision_maker.py,sha256=S4pOXrWcEHy0NDA0gfWzhk7pG0NJfaPpMXvq03yTy0g,13915
18
- gpjax/decision_making/posterior_handler.py,sha256=UgXf1Gu7GMh2YDSmiSWJIzmWlFW06KTS44HYz3mazZQ,5905
19
- gpjax/decision_making/search_space.py,sha256=bXwtMOhHZ2klnABpXm5Raxe7b0NTRDjo_cN3ecbk53Y,3545
20
- gpjax/decision_making/utility_maximizer.py,sha256=VT2amwSJbB64IL_MiWNl9ZgjcqO757qK6NW2gUBKsqs,5965
21
- gpjax/decision_making/utils.py,sha256=5j1GO5kcmG2laZR39NjhqgEjRekAWWzrnREv_5Zct_Y,2367
22
- gpjax/decision_making/test_functions/__init__.py,sha256=GDCY9_kaAnxDWwzo1FkdxnDx-80MErAHchbGybT9xYs,1109
23
- gpjax/decision_making/test_functions/continuous_functions.py,sha256=oL5ZQkvmbC3u9rEvSYI2DRAN3r7Ynf7wRZQlUWjKjt0,5612
24
- gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=eJpCnTS9dRieLxpjH4L6OTsP-w9JM3XhjnzCfk2Xqn8,2957
25
- gpjax/decision_making/utility_functions/__init__.py,sha256=xXI-4JKWAfTJ7XZ1vRDpqtb91MNzSPD0lP6xo0tOc7o,1445
26
- gpjax/decision_making/utility_functions/base.py,sha256=FOqrsRDmtHiCVl6IHr12-AEYBLStzMT5EBs-F92e1Og,3882
27
- gpjax/decision_making/utility_functions/expected_improvement.py,sha256=H6hjC-lj1oiHf2BomeQqroORQ7vtcOngiDAWxRwkNbg,4481
28
- gpjax/decision_making/utility_functions/probability_of_improvement.py,sha256=O_rHH1yR34JJlpAueSDJ_yo95fPI2aAGkwphS8snBYk,5220
29
- gpjax/decision_making/utility_functions/thompson_sampling.py,sha256=S-Yyn-9jsKkaXTvKFBP4sG_eCCKApGbHao5RR5tqXAo,4353
30
16
  gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
31
17
  gpjax/kernels/base.py,sha256=abkj3zidsBs7YSkYEfjeJ5jTs1YyDCPoBM2ZzqaqrgI,11561
32
18
  gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
@@ -56,7 +42,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
56
42
  gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
57
43
  gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
58
44
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
59
- gpjax-0.9.4.dist-info/METADATA,sha256=Qx_Qv91sE7_Y-c9CGuF40QBFk2FjLW0Fo2SHqFAgQFQ,10010
60
- gpjax-0.9.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
61
- gpjax-0.9.4.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
62
- gpjax-0.9.4.dist-info/RECORD,,
45
+ gpjax-0.10.0.dist-info/METADATA,sha256=wZyZSD1p2t_K5m25TrvGJr6lTlfaUxoB12F-0f1d9Co,9970
46
+ gpjax-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
47
+ gpjax-0.10.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
48
+ gpjax-0.10.0.dist-info/RECORD,,
@@ -0,0 +1,19 @@
1
+ (C) Copyright 2019 Hewlett Packard Enterprise Development LP
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a
4
+ copy of this software and associated documentation files (the "Software"),
5
+ to deal in the Software without restriction, including without limitation
6
+ the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ and/or sell copies of the Software, and to permit persons to whom the
8
+ Software is furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included
11
+ in all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
17
+ OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
18
+ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
19
+ OTHER DEALINGS IN THE SOFTWARE.
@@ -1,63 +0,0 @@
1
- # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from gpjax.decision_making.decision_maker import (
16
- AbstractDecisionMaker,
17
- UtilityDrivenDecisionMaker,
18
- )
19
- from gpjax.decision_making.posterior_handler import PosteriorHandler
20
- from gpjax.decision_making.search_space import (
21
- AbstractSearchSpace,
22
- ContinuousSearchSpace,
23
- )
24
- from gpjax.decision_making.test_functions import (
25
- AbstractContinuousTestFunction,
26
- Forrester,
27
- LogarithmicGoldsteinPrice,
28
- Quadratic,
29
- )
30
- from gpjax.decision_making.utility_functions import (
31
- AbstractSinglePointUtilityFunctionBuilder,
32
- AbstractUtilityFunctionBuilder,
33
- SinglePointUtilityFunction,
34
- ThompsonSampling,
35
- UtilityFunction,
36
- )
37
- from gpjax.decision_making.utility_maximizer import (
38
- AbstractSinglePointUtilityMaximizer,
39
- AbstractUtilityMaximizer,
40
- ContinuousSinglePointUtilityMaximizer,
41
- )
42
- from gpjax.decision_making.utils import build_function_evaluator
43
-
44
- __all__ = [
45
- "AbstractUtilityFunctionBuilder",
46
- "AbstractUtilityMaximizer",
47
- "AbstractDecisionMaker",
48
- "AbstractSearchSpace",
49
- "AbstractSinglePointUtilityFunctionBuilder",
50
- "AbstractSinglePointUtilityMaximizer",
51
- "UtilityFunction",
52
- "build_function_evaluator",
53
- "ContinuousSinglePointUtilityMaximizer",
54
- "ContinuousSearchSpace",
55
- "UtilityDrivenDecisionMaker",
56
- "AbstractContinuousTestFunction",
57
- "Forrester",
58
- "LogarithmicGoldsteinPrice",
59
- "PosteriorHandler",
60
- "Quadratic",
61
- "SinglePointUtilityFunction",
62
- "ThompsonSampling",
63
- ]
@@ -1,302 +0,0 @@
1
- # Copyright 2023 The GPJax Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from abc import (
16
- ABC,
17
- abstractmethod,
18
- )
19
- import copy
20
- from dataclasses import dataclass
21
-
22
- from beartype.typing import (
23
- Callable,
24
- Dict,
25
- List,
26
- Mapping,
27
- )
28
- import jax.numpy as jnp
29
- import jax.random as jr
30
-
31
- from gpjax.dataset import Dataset
32
- from gpjax.decision_making.posterior_handler import PosteriorHandler
33
- from gpjax.decision_making.search_space import AbstractSearchSpace
34
- from gpjax.decision_making.utility_functions import (
35
- AbstractUtilityFunctionBuilder,
36
- ThompsonSampling,
37
- )
38
- from gpjax.decision_making.utility_maximizer import AbstractUtilityMaximizer
39
- from gpjax.decision_making.utils import FunctionEvaluator
40
- from gpjax.gps import AbstractPosterior
41
- from gpjax.typing import (
42
- Array,
43
- Float,
44
- KeyArray,
45
- )
46
-
47
-
48
- @dataclass
49
- class AbstractDecisionMaker(ABC):
50
- """
51
- AbstractDecisionMaker abstract base class which handles the core decision making
52
- loop, where we sequentially decide on points to query our function of interest at.
53
- The decision making loop is split into two key steps, `ask` and `tell`. The `ask`
54
- step is typically used to decide which point to query next. The `tell` step is
55
- typically used to update models and datasets with newly queried points. These steps
56
- can be combined in a 'run' loop which alternates between asking which point to query
57
- next and telling the decision maker about the newly queried point having evaluated
58
- the black-box function of interest at this point.
59
-
60
- Attributes:
61
- search_space: Search space over which we can evaluate the function(s) of interest.
62
- posterior_handlers: dictionary of posterior handlers, which are used to update
63
- posteriors throughout the decision making loop. Note that the word `posteriors`
64
- is used for consistency with GPJax, but these objects are typically referred to
65
- as `models` in the model-based decision making literature. Tags are used to
66
- distinguish between posteriors. In a typical Bayesian optimisation setup one of
67
- the tags will be `OBJECTIVE`, defined in `decision_making.utils`.
68
- datasets: dictionary of datasets, which are augmented with
69
- observations throughout the decision making loop. In a typical setup they are
70
- also used to update the posteriors, using the `posterior_handlers`. Tags are used
71
- to distinguish datasets, and correspond to tags in `posterior_handlers`.
72
- key: JAX random key, used to generate random numbers.
73
- batch_size: Number of points to query at each step of the decision making
74
- loop. Note that `SinglePointUtilityFunction`s are only capable of generating
75
- one point to be queried at each iteration of the decision making loop.
76
- post_ask: List of functions to be executed after each ask step.
77
- post_tell: List of functions to be executed after each tell step.
78
- """
79
-
80
- search_space: AbstractSearchSpace
81
- posterior_handlers: Dict[str, PosteriorHandler]
82
- datasets: Dict[str, Dataset]
83
- key: KeyArray
84
- batch_size: int
85
- post_ask: List[
86
- Callable
87
- ] # Specific type is List[Callable[[AbstractDecisionMaker, Float[Array, ["B D"]]], None]] but causes Beartype issues
88
- post_tell: List[
89
- Callable
90
- ] # Specific type is List[Callable[[AbstractDecisionMaker], None]] but causes Beartype issues
91
-
92
- def __post_init__(self):
93
- """
94
- At initialisation we check that the posterior handlers and datasets are
95
- consistent (i.e. have the same tags), and then initialise the posteriors, optimizing them using the
96
- corresponding datasets.
97
- """
98
- self.datasets = copy.copy(
99
- self.datasets
100
- ) # Ensure initial datasets passed in to DecisionMaker are not mutated from within
101
-
102
- if self.batch_size < 1:
103
- raise ValueError(
104
- f"Batch size must be greater than 0, got {self.batch_size}."
105
- )
106
-
107
- # Check that posterior handlers and datasets are consistent
108
- if self.posterior_handlers.keys() != self.datasets.keys():
109
- raise ValueError(
110
- "Posterior handlers and datasets must have the same keys. "
111
- f"Got posterior handlers keys {self.posterior_handlers.keys()} and "
112
- f"datasets keys {self.datasets.keys()}."
113
- )
114
-
115
- # Initialize posteriors
116
- self.posteriors: Dict[str, AbstractPosterior] = {}
117
- for tag, posterior_handler in self.posterior_handlers.items():
118
- self.posteriors[tag] = posterior_handler.get_posterior(
119
- self.datasets[tag], optimize=True, key=self.key
120
- )
121
-
122
- @abstractmethod
123
- def ask(self, key: KeyArray) -> Float[Array, "B D"]:
124
- """
125
- Get the point(s) to be queried next.
126
-
127
- Args:
128
- key (KeyArray): JAX PRNG key for controlling random state.
129
-
130
- Returns:
131
- Float[Array, "1 D"]: Point to be queried next
132
- """
133
- raise NotImplementedError
134
-
135
- def tell(self, observation_datasets: Mapping[str, Dataset], key: KeyArray):
136
- """
137
- Add newly observed data to datasets and update the corresponding posteriors.
138
-
139
- Args:
140
- observation_datasets: dictionary of datasets containing new observations.
141
- Tags are used to distinguish datasets, and correspond to tags in
142
- `posterior_handlers` and `self.datasets`.
143
- key: JAX PRNG key for controlling random state.
144
- """
145
- if observation_datasets.keys() != self.datasets.keys():
146
- raise ValueError(
147
- "Observation datasets and existing datasets must have the same keys. "
148
- f"Got observation datasets keys {observation_datasets.keys()} and "
149
- f"existing datasets keys {self.datasets.keys()}."
150
- )
151
-
152
- for tag, observation_dataset in observation_datasets.items():
153
- self.datasets[tag] += observation_dataset
154
-
155
- for tag, posterior_handler in self.posterior_handlers.items():
156
- key, _ = jr.split(key)
157
- self.posteriors[tag] = posterior_handler.update_posterior(
158
- self.datasets[tag], self.posteriors[tag], optimize=True, key=key
159
- )
160
-
161
- def run(
162
- self, n_steps: int, black_box_function_evaluator: FunctionEvaluator
163
- ) -> Mapping[str, Dataset]:
164
- """
165
- Run the decision making loop continuously for for `n_steps`. This is broken down
166
- into three main steps:
167
- 1. Call the `ask` method to get the point to be queried next.
168
- 2. Call the `black_box_function_evaluator` to evaluate the black box functions
169
- of interest at the point chosen to be queried.
170
- 3. Call the `tell` method to update the datasets and posteriors with the newly
171
- observed data.
172
-
173
- In addition to this, after the `ask` step, the functions in the `post_ask` list
174
- are executed, taking as arguments the decision maker and the point chosen to be
175
- queried next. Similarly, after the `tell` step, the functions in the `post_tell`
176
- list are executed, taking the decision maker as the sole argument.
177
-
178
- Args:
179
- n_steps (int): Number of steps to run the decision making loop for.
180
- black_box_function_evaluator (FunctionEvaluator): Function evaluator which
181
- evaluates the black box functions of interest at supplied points.
182
-
183
- Returns:
184
- Mapping[str, Dataset]: Dictionary of datasets containing the observations
185
- made throughout the decision making loop, as well as the initial data
186
- supplied when initialising the `DecisionMaker`.
187
- """
188
- for _ in range(n_steps):
189
- query_point = self.ask(self.key)
190
-
191
- for post_ask_method in self.post_ask:
192
- post_ask_method(self, query_point)
193
-
194
- self.key, _ = jr.split(self.key)
195
- observation_datasets = black_box_function_evaluator(query_point)
196
- self.tell(observation_datasets, self.key)
197
-
198
- for post_tell_method in self.post_tell:
199
- post_tell_method(self)
200
-
201
- return self.datasets
202
-
203
-
204
- @dataclass
205
- class UtilityDrivenDecisionMaker(AbstractDecisionMaker):
206
- """
207
- UtilityDrivenDecisionMaker class which handles the core decision making loop in a
208
- typical model-based decision making setup. In this setup we use surrogate model(s)
209
- for the function(s) of interest, and define a utility function (often called the
210
- 'acquisition function' in the context of Bayesian optimisation) which characterises
211
- how useful it would be to query a given point within the search space given the data
212
- we have observed so far. This can then be used to decide which point(s) to query
213
- next.
214
-
215
- The decision making loop is split into two key steps, `ask` and `tell`. The `ask`
216
- step forms a `UtilityFunction` from the current `posteriors` and `datasets` and
217
- returns the point which maximises it. It also stores the formed utility function
218
- under the attribute `self.current_utility_function` so that it can be called,
219
- for instance for plotting, after the `ask` function has been called. The `tell` step
220
- adds a newly queried point to the `datasets` and updates the `posteriors`.
221
-
222
- This can be run as a typical ask-tell loop, or the `run` method can be used to run
223
- the decision making loop for a fixed number of steps. Moreover, the `run` method executes
224
- the functions in `post_ask` and `post_tell` after each ask and tell step
225
- respectively. This enables the user to add custom functionality, such as the ability
226
- to plot values of interest during the optimization process.
227
-
228
- Attributes:
229
- utility_function_builder (AbstractUtilityFunctionBuilder): Object which
230
- builds utility functions from posteriors and datasets, to decide where
231
- to query next. In a typical Bayesian optimisation setup the point chosen to
232
- be queried next is the point which maximizes the utility function.
233
- utility_maximizer (AbstractUtilityMaximizer): Object which maximizes
234
- utility functions over the search space.
235
- """
236
-
237
- utility_function_builder: AbstractUtilityFunctionBuilder
238
- utility_maximizer: AbstractUtilityMaximizer
239
-
240
- def __post_init__(self):
241
- super().__post_init__()
242
- if self.batch_size > 1 and not isinstance(
243
- self.utility_function_builder, ThompsonSampling
244
- ):
245
- raise NotImplementedError(
246
- "Batch size > 1 currently only supported for Thompson sampling."
247
- )
248
-
249
- def ask(self, key: KeyArray) -> Float[Array, "B D"]:
250
- """
251
- Get updated utility function(s) and return the point(s) which maximises it/them. This
252
- method also stores the utility function(s) in
253
- `self.current_utility_functions` so that they can be accessed after the ask
254
- function has been called. This is useful for non-deterministic utility
255
- functions, which may differ between calls to `ask` due to the splitting of
256
- `self.key`.
257
-
258
- Note that in general `SinglePointUtilityFunction`s are only capable of
259
- generating one point to be queried at each iteration of the decision making loop
260
- (i.e. `self.batch_size` must be 1). However, Thompson sampling can be used in a
261
- batched setting by drawing a batch of different samples from the GP posterior.
262
- This is done by calling `build_utility_function` with different keys
263
- sequentilly, and optimising each of these individual samples in sequence in
264
- order to obtain `self.batch_size` points to query next.
265
-
266
- Args:
267
- key (KeyArray): JAX PRNG key for controlling random state.
268
-
269
- Returns:
270
- Float[Array, "B D"]: Point(s) to be queried next.
271
- """
272
- self.current_utility_functions = []
273
- maximizers = []
274
- # We currently only allow Thompson sampling to be run with batch size > 1. More
275
- # batched utility functions may be added in the future.
276
- if isinstance(self.utility_function_builder, ThompsonSampling) or (
277
- (not isinstance(self.utility_function_builder, ThompsonSampling))
278
- and (self.batch_size == 1)
279
- ):
280
- # Draw 'self.batch_size' Thompson samples and optimize each of them in order to
281
- # obtain 'self.batch_size' points to query next.
282
- for _ in range(self.batch_size):
283
- decision_function = (
284
- self.utility_function_builder.build_utility_function(
285
- self.posteriors, self.datasets, key
286
- )
287
- )
288
- self.current_utility_functions.append(decision_function)
289
-
290
- _, key = jr.split(key)
291
- maximizer = self.utility_maximizer.maximize(
292
- decision_function, self.search_space, key
293
- )
294
- maximizers.append(maximizer)
295
- _, key = jr.split(key)
296
-
297
- maximizers = jnp.concatenate(maximizers)
298
- return maximizers
299
- else:
300
- raise NotImplementedError(
301
- "Only Thompson sampling currently supports batch size > 1."
302
- )