gpjax 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -3
- gpjax/citation.py +0 -43
- gpjax/distributions.py +3 -1
- {gpjax-0.9.4.dist-info → gpjax-0.9.5.dist-info}/METADATA +3 -4
- {gpjax-0.9.4.dist-info → gpjax-0.9.5.dist-info}/RECORD +7 -21
- gpjax-0.9.5.dist-info/licenses/LICENSE.txt +19 -0
- gpjax/decision_making/__init__.py +0 -63
- gpjax/decision_making/decision_maker.py +0 -302
- gpjax/decision_making/posterior_handler.py +0 -152
- gpjax/decision_making/search_space.py +0 -96
- gpjax/decision_making/test_functions/__init__.py +0 -31
- gpjax/decision_making/test_functions/continuous_functions.py +0 -169
- gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -90
- gpjax/decision_making/utility_functions/__init__.py +0 -37
- gpjax/decision_making/utility_functions/base.py +0 -106
- gpjax/decision_making/utility_functions/expected_improvement.py +0 -112
- gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -125
- gpjax/decision_making/utility_functions/thompson_sampling.py +0 -101
- gpjax/decision_making/utility_maximizer.py +0 -157
- gpjax/decision_making/utils.py +0 -64
- gpjax-0.9.4.dist-info/licenses/LICENSE +0 -201
- {gpjax-0.9.4.dist-info → gpjax-0.9.5.dist-info}/WHEEL +0 -0
gpjax/__init__.py
CHANGED
|
@@ -19,7 +19,6 @@ from beartype.roar import BeartypeDecorHintPep585DeprecationWarning
|
|
|
19
19
|
filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning)
|
|
20
20
|
|
|
21
21
|
from gpjax import (
|
|
22
|
-
decision_making,
|
|
23
22
|
gps,
|
|
24
23
|
integrators,
|
|
25
24
|
kernels,
|
|
@@ -40,11 +39,10 @@ __license__ = "MIT"
|
|
|
40
39
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
41
40
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
41
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.9.
|
|
42
|
+
__version__ = "0.9.5"
|
|
44
43
|
|
|
45
44
|
__all__ = [
|
|
46
45
|
"base",
|
|
47
|
-
"decision_making",
|
|
48
46
|
"gps",
|
|
49
47
|
"integrators",
|
|
50
48
|
"kernels",
|
gpjax/citation.py
CHANGED
|
@@ -10,11 +10,6 @@ from beartype.typing import (
|
|
|
10
10
|
)
|
|
11
11
|
from jaxlib.xla_extension import PjitFunction
|
|
12
12
|
|
|
13
|
-
from gpjax.decision_making.test_functions import (
|
|
14
|
-
Forrester,
|
|
15
|
-
LogarithmicGoldsteinPrice,
|
|
16
|
-
)
|
|
17
|
-
from gpjax.decision_making.utility_functions import ThompsonSampling
|
|
18
13
|
from gpjax.kernels import (
|
|
19
14
|
RFF,
|
|
20
15
|
ArcCosine,
|
|
@@ -149,41 +144,3 @@ def _(tree) -> PaperCitation:
|
|
|
149
144
|
booktitle="Advances in neural information processing systems",
|
|
150
145
|
citation_type="article",
|
|
151
146
|
)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
####################
|
|
155
|
-
# Decision making citations
|
|
156
|
-
####################
|
|
157
|
-
@cite.register(ThompsonSampling)
|
|
158
|
-
def _(tree) -> PaperCitation:
|
|
159
|
-
return PaperCitation(
|
|
160
|
-
citation_key="wilson2020efficiently",
|
|
161
|
-
title="Efficiently sampling functions from Gaussian process posteriors",
|
|
162
|
-
authors="Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and Mostowsky, Peter and Deisenroth, Marc",
|
|
163
|
-
year="2020",
|
|
164
|
-
booktitle="International Conference on Machine Learning",
|
|
165
|
-
citation_type="article",
|
|
166
|
-
)
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
@cite.register(Forrester)
|
|
170
|
-
def _(tree) -> BookCitation:
|
|
171
|
-
return BookCitation(
|
|
172
|
-
citation_key="forrester2008engineering",
|
|
173
|
-
authors="Forrester, Alexander and Sobester, Andras and Keane, Andy",
|
|
174
|
-
title="Engineering design via surrogate modelling: a practical guide",
|
|
175
|
-
year="2008",
|
|
176
|
-
publisher="John Wiley & Sons",
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
@cite.register(LogarithmicGoldsteinPrice)
|
|
181
|
-
def _(tree) -> PaperCitation:
|
|
182
|
-
return PaperCitation(
|
|
183
|
-
citation_key="picheny2013benchmark",
|
|
184
|
-
authors="Picheny, Victor and Wagner, Tobias and Ginsbourger, David",
|
|
185
|
-
title="A benchmark of kriging-based infill criteria for noisy optimization",
|
|
186
|
-
year="2013",
|
|
187
|
-
booktitle="Structural and multidisciplinary optimization",
|
|
188
|
-
citation_type="article",
|
|
189
|
-
)
|
gpjax/distributions.py
CHANGED
|
@@ -162,7 +162,9 @@ class GaussianDistribution(tfd.Distribution):
|
|
|
162
162
|
|
|
163
163
|
return vmap(affine_transformation)(Z)
|
|
164
164
|
|
|
165
|
-
def sample(
|
|
165
|
+
def sample(
|
|
166
|
+
self, seed: KeyArray, sample_shape: Tuple[int, ...]
|
|
167
|
+
): # pylint: disable=useless-super-delegation
|
|
166
168
|
r"""See `Distribution.sample`."""
|
|
167
169
|
return self._sample_n(
|
|
168
170
|
seed, sample_shape[0]
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.9.
|
|
3
|
+
Version: 0.9.5
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
6
|
Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
|
|
7
7
|
Project-URL: Source, https://github.com/JaxGaussianProcesses/GPJax
|
|
8
8
|
Author-email: Thomas Pinder <tompinder@live.co.uk>
|
|
9
|
-
License
|
|
10
|
-
License-File: LICENSE
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE.txt
|
|
11
11
|
Keywords: gaussian-processes jax machine-learning bayesian
|
|
12
12
|
Classifier: Development Status :: 4 - Beta
|
|
13
13
|
Classifier: Programming Language :: Python
|
|
@@ -22,7 +22,6 @@ Requires-Dist: cola-ml==0.0.5
|
|
|
22
22
|
Requires-Dist: flax<0.10.0
|
|
23
23
|
Requires-Dist: jax<0.4.28
|
|
24
24
|
Requires-Dist: jaxlib<0.4.28
|
|
25
|
-
Requires-Dist: jaxopt==0.8.2
|
|
26
25
|
Requires-Dist: jaxtyping>0.2.10
|
|
27
26
|
Requires-Dist: numpy<2.0.0
|
|
28
27
|
Requires-Dist: optax>0.2.1
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
2
|
-
gpjax/citation.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=T-2EbsNxg5VcdTeSH_G-mWwNcMTJVqbdI55gl9HMvG8,1653
|
|
2
|
+
gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
|
|
3
3
|
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
|
-
gpjax/distributions.py,sha256=
|
|
4
|
+
gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
|
|
5
5
|
gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
|
|
6
6
|
gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
|
|
7
7
|
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
@@ -13,20 +13,6 @@ gpjax/parameters.py,sha256=Z4Wy3gEzPZG23-dtqC437_ZWnd_sPe9LcLCKn21ZBvA,4886
|
|
|
13
13
|
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
14
|
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
15
|
gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
|
|
16
|
-
gpjax/decision_making/__init__.py,sha256=SDuPQl80lJ7nhfRsiB_7c22wCMiQO5ehSNohxUGnB7w,2170
|
|
17
|
-
gpjax/decision_making/decision_maker.py,sha256=S4pOXrWcEHy0NDA0gfWzhk7pG0NJfaPpMXvq03yTy0g,13915
|
|
18
|
-
gpjax/decision_making/posterior_handler.py,sha256=UgXf1Gu7GMh2YDSmiSWJIzmWlFW06KTS44HYz3mazZQ,5905
|
|
19
|
-
gpjax/decision_making/search_space.py,sha256=bXwtMOhHZ2klnABpXm5Raxe7b0NTRDjo_cN3ecbk53Y,3545
|
|
20
|
-
gpjax/decision_making/utility_maximizer.py,sha256=VT2amwSJbB64IL_MiWNl9ZgjcqO757qK6NW2gUBKsqs,5965
|
|
21
|
-
gpjax/decision_making/utils.py,sha256=5j1GO5kcmG2laZR39NjhqgEjRekAWWzrnREv_5Zct_Y,2367
|
|
22
|
-
gpjax/decision_making/test_functions/__init__.py,sha256=GDCY9_kaAnxDWwzo1FkdxnDx-80MErAHchbGybT9xYs,1109
|
|
23
|
-
gpjax/decision_making/test_functions/continuous_functions.py,sha256=oL5ZQkvmbC3u9rEvSYI2DRAN3r7Ynf7wRZQlUWjKjt0,5612
|
|
24
|
-
gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=eJpCnTS9dRieLxpjH4L6OTsP-w9JM3XhjnzCfk2Xqn8,2957
|
|
25
|
-
gpjax/decision_making/utility_functions/__init__.py,sha256=xXI-4JKWAfTJ7XZ1vRDpqtb91MNzSPD0lP6xo0tOc7o,1445
|
|
26
|
-
gpjax/decision_making/utility_functions/base.py,sha256=FOqrsRDmtHiCVl6IHr12-AEYBLStzMT5EBs-F92e1Og,3882
|
|
27
|
-
gpjax/decision_making/utility_functions/expected_improvement.py,sha256=H6hjC-lj1oiHf2BomeQqroORQ7vtcOngiDAWxRwkNbg,4481
|
|
28
|
-
gpjax/decision_making/utility_functions/probability_of_improvement.py,sha256=O_rHH1yR34JJlpAueSDJ_yo95fPI2aAGkwphS8snBYk,5220
|
|
29
|
-
gpjax/decision_making/utility_functions/thompson_sampling.py,sha256=S-Yyn-9jsKkaXTvKFBP4sG_eCCKApGbHao5RR5tqXAo,4353
|
|
30
16
|
gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
|
|
31
17
|
gpjax/kernels/base.py,sha256=abkj3zidsBs7YSkYEfjeJ5jTs1YyDCPoBM2ZzqaqrgI,11561
|
|
32
18
|
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
@@ -56,7 +42,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
|
|
|
56
42
|
gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
|
|
57
43
|
gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
|
|
58
44
|
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
59
|
-
gpjax-0.9.
|
|
60
|
-
gpjax-0.9.
|
|
61
|
-
gpjax-0.9.
|
|
62
|
-
gpjax-0.9.
|
|
45
|
+
gpjax-0.9.5.dist-info/METADATA,sha256=T-OvGAyBe1N_QW6F9RbV-sx8wBJSAYQjpildBdhotS0,9967
|
|
46
|
+
gpjax-0.9.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
47
|
+
gpjax-0.9.5.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
48
|
+
gpjax-0.9.5.dist-info/RECORD,,
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
(C) Copyright 2019 Hewlett Packard Enterprise Development LP
|
|
2
|
+
|
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a
|
|
4
|
+
copy of this software and associated documentation files (the "Software"),
|
|
5
|
+
to deal in the Software without restriction, including without limitation
|
|
6
|
+
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
7
|
+
and/or sell copies of the Software, and to permit persons to whom the
|
|
8
|
+
Software is furnished to do so, subject to the following conditions:
|
|
9
|
+
|
|
10
|
+
The above copyright notice and this permission notice shall be included
|
|
11
|
+
in all copies or substantial portions of the Software.
|
|
12
|
+
|
|
13
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
15
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
|
16
|
+
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
|
17
|
+
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
|
18
|
+
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
|
19
|
+
OTHER DEALINGS IN THE SOFTWARE.
|
|
@@ -1,63 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from gpjax.decision_making.decision_maker import (
|
|
16
|
-
AbstractDecisionMaker,
|
|
17
|
-
UtilityDrivenDecisionMaker,
|
|
18
|
-
)
|
|
19
|
-
from gpjax.decision_making.posterior_handler import PosteriorHandler
|
|
20
|
-
from gpjax.decision_making.search_space import (
|
|
21
|
-
AbstractSearchSpace,
|
|
22
|
-
ContinuousSearchSpace,
|
|
23
|
-
)
|
|
24
|
-
from gpjax.decision_making.test_functions import (
|
|
25
|
-
AbstractContinuousTestFunction,
|
|
26
|
-
Forrester,
|
|
27
|
-
LogarithmicGoldsteinPrice,
|
|
28
|
-
Quadratic,
|
|
29
|
-
)
|
|
30
|
-
from gpjax.decision_making.utility_functions import (
|
|
31
|
-
AbstractSinglePointUtilityFunctionBuilder,
|
|
32
|
-
AbstractUtilityFunctionBuilder,
|
|
33
|
-
SinglePointUtilityFunction,
|
|
34
|
-
ThompsonSampling,
|
|
35
|
-
UtilityFunction,
|
|
36
|
-
)
|
|
37
|
-
from gpjax.decision_making.utility_maximizer import (
|
|
38
|
-
AbstractSinglePointUtilityMaximizer,
|
|
39
|
-
AbstractUtilityMaximizer,
|
|
40
|
-
ContinuousSinglePointUtilityMaximizer,
|
|
41
|
-
)
|
|
42
|
-
from gpjax.decision_making.utils import build_function_evaluator
|
|
43
|
-
|
|
44
|
-
__all__ = [
|
|
45
|
-
"AbstractUtilityFunctionBuilder",
|
|
46
|
-
"AbstractUtilityMaximizer",
|
|
47
|
-
"AbstractDecisionMaker",
|
|
48
|
-
"AbstractSearchSpace",
|
|
49
|
-
"AbstractSinglePointUtilityFunctionBuilder",
|
|
50
|
-
"AbstractSinglePointUtilityMaximizer",
|
|
51
|
-
"UtilityFunction",
|
|
52
|
-
"build_function_evaluator",
|
|
53
|
-
"ContinuousSinglePointUtilityMaximizer",
|
|
54
|
-
"ContinuousSearchSpace",
|
|
55
|
-
"UtilityDrivenDecisionMaker",
|
|
56
|
-
"AbstractContinuousTestFunction",
|
|
57
|
-
"Forrester",
|
|
58
|
-
"LogarithmicGoldsteinPrice",
|
|
59
|
-
"PosteriorHandler",
|
|
60
|
-
"Quadratic",
|
|
61
|
-
"SinglePointUtilityFunction",
|
|
62
|
-
"ThompsonSampling",
|
|
63
|
-
]
|
|
@@ -1,302 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from abc import (
|
|
16
|
-
ABC,
|
|
17
|
-
abstractmethod,
|
|
18
|
-
)
|
|
19
|
-
import copy
|
|
20
|
-
from dataclasses import dataclass
|
|
21
|
-
|
|
22
|
-
from beartype.typing import (
|
|
23
|
-
Callable,
|
|
24
|
-
Dict,
|
|
25
|
-
List,
|
|
26
|
-
Mapping,
|
|
27
|
-
)
|
|
28
|
-
import jax.numpy as jnp
|
|
29
|
-
import jax.random as jr
|
|
30
|
-
|
|
31
|
-
from gpjax.dataset import Dataset
|
|
32
|
-
from gpjax.decision_making.posterior_handler import PosteriorHandler
|
|
33
|
-
from gpjax.decision_making.search_space import AbstractSearchSpace
|
|
34
|
-
from gpjax.decision_making.utility_functions import (
|
|
35
|
-
AbstractUtilityFunctionBuilder,
|
|
36
|
-
ThompsonSampling,
|
|
37
|
-
)
|
|
38
|
-
from gpjax.decision_making.utility_maximizer import AbstractUtilityMaximizer
|
|
39
|
-
from gpjax.decision_making.utils import FunctionEvaluator
|
|
40
|
-
from gpjax.gps import AbstractPosterior
|
|
41
|
-
from gpjax.typing import (
|
|
42
|
-
Array,
|
|
43
|
-
Float,
|
|
44
|
-
KeyArray,
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@dataclass
|
|
49
|
-
class AbstractDecisionMaker(ABC):
|
|
50
|
-
"""
|
|
51
|
-
AbstractDecisionMaker abstract base class which handles the core decision making
|
|
52
|
-
loop, where we sequentially decide on points to query our function of interest at.
|
|
53
|
-
The decision making loop is split into two key steps, `ask` and `tell`. The `ask`
|
|
54
|
-
step is typically used to decide which point to query next. The `tell` step is
|
|
55
|
-
typically used to update models and datasets with newly queried points. These steps
|
|
56
|
-
can be combined in a 'run' loop which alternates between asking which point to query
|
|
57
|
-
next and telling the decision maker about the newly queried point having evaluated
|
|
58
|
-
the black-box function of interest at this point.
|
|
59
|
-
|
|
60
|
-
Attributes:
|
|
61
|
-
search_space: Search space over which we can evaluate the function(s) of interest.
|
|
62
|
-
posterior_handlers: dictionary of posterior handlers, which are used to update
|
|
63
|
-
posteriors throughout the decision making loop. Note that the word `posteriors`
|
|
64
|
-
is used for consistency with GPJax, but these objects are typically referred to
|
|
65
|
-
as `models` in the model-based decision making literature. Tags are used to
|
|
66
|
-
distinguish between posteriors. In a typical Bayesian optimisation setup one of
|
|
67
|
-
the tags will be `OBJECTIVE`, defined in `decision_making.utils`.
|
|
68
|
-
datasets: dictionary of datasets, which are augmented with
|
|
69
|
-
observations throughout the decision making loop. In a typical setup they are
|
|
70
|
-
also used to update the posteriors, using the `posterior_handlers`. Tags are used
|
|
71
|
-
to distinguish datasets, and correspond to tags in `posterior_handlers`.
|
|
72
|
-
key: JAX random key, used to generate random numbers.
|
|
73
|
-
batch_size: Number of points to query at each step of the decision making
|
|
74
|
-
loop. Note that `SinglePointUtilityFunction`s are only capable of generating
|
|
75
|
-
one point to be queried at each iteration of the decision making loop.
|
|
76
|
-
post_ask: List of functions to be executed after each ask step.
|
|
77
|
-
post_tell: List of functions to be executed after each tell step.
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
|
-
search_space: AbstractSearchSpace
|
|
81
|
-
posterior_handlers: Dict[str, PosteriorHandler]
|
|
82
|
-
datasets: Dict[str, Dataset]
|
|
83
|
-
key: KeyArray
|
|
84
|
-
batch_size: int
|
|
85
|
-
post_ask: List[
|
|
86
|
-
Callable
|
|
87
|
-
] # Specific type is List[Callable[[AbstractDecisionMaker, Float[Array, ["B D"]]], None]] but causes Beartype issues
|
|
88
|
-
post_tell: List[
|
|
89
|
-
Callable
|
|
90
|
-
] # Specific type is List[Callable[[AbstractDecisionMaker], None]] but causes Beartype issues
|
|
91
|
-
|
|
92
|
-
def __post_init__(self):
|
|
93
|
-
"""
|
|
94
|
-
At initialisation we check that the posterior handlers and datasets are
|
|
95
|
-
consistent (i.e. have the same tags), and then initialise the posteriors, optimizing them using the
|
|
96
|
-
corresponding datasets.
|
|
97
|
-
"""
|
|
98
|
-
self.datasets = copy.copy(
|
|
99
|
-
self.datasets
|
|
100
|
-
) # Ensure initial datasets passed in to DecisionMaker are not mutated from within
|
|
101
|
-
|
|
102
|
-
if self.batch_size < 1:
|
|
103
|
-
raise ValueError(
|
|
104
|
-
f"Batch size must be greater than 0, got {self.batch_size}."
|
|
105
|
-
)
|
|
106
|
-
|
|
107
|
-
# Check that posterior handlers and datasets are consistent
|
|
108
|
-
if self.posterior_handlers.keys() != self.datasets.keys():
|
|
109
|
-
raise ValueError(
|
|
110
|
-
"Posterior handlers and datasets must have the same keys. "
|
|
111
|
-
f"Got posterior handlers keys {self.posterior_handlers.keys()} and "
|
|
112
|
-
f"datasets keys {self.datasets.keys()}."
|
|
113
|
-
)
|
|
114
|
-
|
|
115
|
-
# Initialize posteriors
|
|
116
|
-
self.posteriors: Dict[str, AbstractPosterior] = {}
|
|
117
|
-
for tag, posterior_handler in self.posterior_handlers.items():
|
|
118
|
-
self.posteriors[tag] = posterior_handler.get_posterior(
|
|
119
|
-
self.datasets[tag], optimize=True, key=self.key
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
@abstractmethod
|
|
123
|
-
def ask(self, key: KeyArray) -> Float[Array, "B D"]:
|
|
124
|
-
"""
|
|
125
|
-
Get the point(s) to be queried next.
|
|
126
|
-
|
|
127
|
-
Args:
|
|
128
|
-
key (KeyArray): JAX PRNG key for controlling random state.
|
|
129
|
-
|
|
130
|
-
Returns:
|
|
131
|
-
Float[Array, "1 D"]: Point to be queried next
|
|
132
|
-
"""
|
|
133
|
-
raise NotImplementedError
|
|
134
|
-
|
|
135
|
-
def tell(self, observation_datasets: Mapping[str, Dataset], key: KeyArray):
|
|
136
|
-
"""
|
|
137
|
-
Add newly observed data to datasets and update the corresponding posteriors.
|
|
138
|
-
|
|
139
|
-
Args:
|
|
140
|
-
observation_datasets: dictionary of datasets containing new observations.
|
|
141
|
-
Tags are used to distinguish datasets, and correspond to tags in
|
|
142
|
-
`posterior_handlers` and `self.datasets`.
|
|
143
|
-
key: JAX PRNG key for controlling random state.
|
|
144
|
-
"""
|
|
145
|
-
if observation_datasets.keys() != self.datasets.keys():
|
|
146
|
-
raise ValueError(
|
|
147
|
-
"Observation datasets and existing datasets must have the same keys. "
|
|
148
|
-
f"Got observation datasets keys {observation_datasets.keys()} and "
|
|
149
|
-
f"existing datasets keys {self.datasets.keys()}."
|
|
150
|
-
)
|
|
151
|
-
|
|
152
|
-
for tag, observation_dataset in observation_datasets.items():
|
|
153
|
-
self.datasets[tag] += observation_dataset
|
|
154
|
-
|
|
155
|
-
for tag, posterior_handler in self.posterior_handlers.items():
|
|
156
|
-
key, _ = jr.split(key)
|
|
157
|
-
self.posteriors[tag] = posterior_handler.update_posterior(
|
|
158
|
-
self.datasets[tag], self.posteriors[tag], optimize=True, key=key
|
|
159
|
-
)
|
|
160
|
-
|
|
161
|
-
def run(
|
|
162
|
-
self, n_steps: int, black_box_function_evaluator: FunctionEvaluator
|
|
163
|
-
) -> Mapping[str, Dataset]:
|
|
164
|
-
"""
|
|
165
|
-
Run the decision making loop continuously for for `n_steps`. This is broken down
|
|
166
|
-
into three main steps:
|
|
167
|
-
1. Call the `ask` method to get the point to be queried next.
|
|
168
|
-
2. Call the `black_box_function_evaluator` to evaluate the black box functions
|
|
169
|
-
of interest at the point chosen to be queried.
|
|
170
|
-
3. Call the `tell` method to update the datasets and posteriors with the newly
|
|
171
|
-
observed data.
|
|
172
|
-
|
|
173
|
-
In addition to this, after the `ask` step, the functions in the `post_ask` list
|
|
174
|
-
are executed, taking as arguments the decision maker and the point chosen to be
|
|
175
|
-
queried next. Similarly, after the `tell` step, the functions in the `post_tell`
|
|
176
|
-
list are executed, taking the decision maker as the sole argument.
|
|
177
|
-
|
|
178
|
-
Args:
|
|
179
|
-
n_steps (int): Number of steps to run the decision making loop for.
|
|
180
|
-
black_box_function_evaluator (FunctionEvaluator): Function evaluator which
|
|
181
|
-
evaluates the black box functions of interest at supplied points.
|
|
182
|
-
|
|
183
|
-
Returns:
|
|
184
|
-
Mapping[str, Dataset]: Dictionary of datasets containing the observations
|
|
185
|
-
made throughout the decision making loop, as well as the initial data
|
|
186
|
-
supplied when initialising the `DecisionMaker`.
|
|
187
|
-
"""
|
|
188
|
-
for _ in range(n_steps):
|
|
189
|
-
query_point = self.ask(self.key)
|
|
190
|
-
|
|
191
|
-
for post_ask_method in self.post_ask:
|
|
192
|
-
post_ask_method(self, query_point)
|
|
193
|
-
|
|
194
|
-
self.key, _ = jr.split(self.key)
|
|
195
|
-
observation_datasets = black_box_function_evaluator(query_point)
|
|
196
|
-
self.tell(observation_datasets, self.key)
|
|
197
|
-
|
|
198
|
-
for post_tell_method in self.post_tell:
|
|
199
|
-
post_tell_method(self)
|
|
200
|
-
|
|
201
|
-
return self.datasets
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
@dataclass
|
|
205
|
-
class UtilityDrivenDecisionMaker(AbstractDecisionMaker):
|
|
206
|
-
"""
|
|
207
|
-
UtilityDrivenDecisionMaker class which handles the core decision making loop in a
|
|
208
|
-
typical model-based decision making setup. In this setup we use surrogate model(s)
|
|
209
|
-
for the function(s) of interest, and define a utility function (often called the
|
|
210
|
-
'acquisition function' in the context of Bayesian optimisation) which characterises
|
|
211
|
-
how useful it would be to query a given point within the search space given the data
|
|
212
|
-
we have observed so far. This can then be used to decide which point(s) to query
|
|
213
|
-
next.
|
|
214
|
-
|
|
215
|
-
The decision making loop is split into two key steps, `ask` and `tell`. The `ask`
|
|
216
|
-
step forms a `UtilityFunction` from the current `posteriors` and `datasets` and
|
|
217
|
-
returns the point which maximises it. It also stores the formed utility function
|
|
218
|
-
under the attribute `self.current_utility_function` so that it can be called,
|
|
219
|
-
for instance for plotting, after the `ask` function has been called. The `tell` step
|
|
220
|
-
adds a newly queried point to the `datasets` and updates the `posteriors`.
|
|
221
|
-
|
|
222
|
-
This can be run as a typical ask-tell loop, or the `run` method can be used to run
|
|
223
|
-
the decision making loop for a fixed number of steps. Moreover, the `run` method executes
|
|
224
|
-
the functions in `post_ask` and `post_tell` after each ask and tell step
|
|
225
|
-
respectively. This enables the user to add custom functionality, such as the ability
|
|
226
|
-
to plot values of interest during the optimization process.
|
|
227
|
-
|
|
228
|
-
Attributes:
|
|
229
|
-
utility_function_builder (AbstractUtilityFunctionBuilder): Object which
|
|
230
|
-
builds utility functions from posteriors and datasets, to decide where
|
|
231
|
-
to query next. In a typical Bayesian optimisation setup the point chosen to
|
|
232
|
-
be queried next is the point which maximizes the utility function.
|
|
233
|
-
utility_maximizer (AbstractUtilityMaximizer): Object which maximizes
|
|
234
|
-
utility functions over the search space.
|
|
235
|
-
"""
|
|
236
|
-
|
|
237
|
-
utility_function_builder: AbstractUtilityFunctionBuilder
|
|
238
|
-
utility_maximizer: AbstractUtilityMaximizer
|
|
239
|
-
|
|
240
|
-
def __post_init__(self):
|
|
241
|
-
super().__post_init__()
|
|
242
|
-
if self.batch_size > 1 and not isinstance(
|
|
243
|
-
self.utility_function_builder, ThompsonSampling
|
|
244
|
-
):
|
|
245
|
-
raise NotImplementedError(
|
|
246
|
-
"Batch size > 1 currently only supported for Thompson sampling."
|
|
247
|
-
)
|
|
248
|
-
|
|
249
|
-
def ask(self, key: KeyArray) -> Float[Array, "B D"]:
|
|
250
|
-
"""
|
|
251
|
-
Get updated utility function(s) and return the point(s) which maximises it/them. This
|
|
252
|
-
method also stores the utility function(s) in
|
|
253
|
-
`self.current_utility_functions` so that they can be accessed after the ask
|
|
254
|
-
function has been called. This is useful for non-deterministic utility
|
|
255
|
-
functions, which may differ between calls to `ask` due to the splitting of
|
|
256
|
-
`self.key`.
|
|
257
|
-
|
|
258
|
-
Note that in general `SinglePointUtilityFunction`s are only capable of
|
|
259
|
-
generating one point to be queried at each iteration of the decision making loop
|
|
260
|
-
(i.e. `self.batch_size` must be 1). However, Thompson sampling can be used in a
|
|
261
|
-
batched setting by drawing a batch of different samples from the GP posterior.
|
|
262
|
-
This is done by calling `build_utility_function` with different keys
|
|
263
|
-
sequentilly, and optimising each of these individual samples in sequence in
|
|
264
|
-
order to obtain `self.batch_size` points to query next.
|
|
265
|
-
|
|
266
|
-
Args:
|
|
267
|
-
key (KeyArray): JAX PRNG key for controlling random state.
|
|
268
|
-
|
|
269
|
-
Returns:
|
|
270
|
-
Float[Array, "B D"]: Point(s) to be queried next.
|
|
271
|
-
"""
|
|
272
|
-
self.current_utility_functions = []
|
|
273
|
-
maximizers = []
|
|
274
|
-
# We currently only allow Thompson sampling to be run with batch size > 1. More
|
|
275
|
-
# batched utility functions may be added in the future.
|
|
276
|
-
if isinstance(self.utility_function_builder, ThompsonSampling) or (
|
|
277
|
-
(not isinstance(self.utility_function_builder, ThompsonSampling))
|
|
278
|
-
and (self.batch_size == 1)
|
|
279
|
-
):
|
|
280
|
-
# Draw 'self.batch_size' Thompson samples and optimize each of them in order to
|
|
281
|
-
# obtain 'self.batch_size' points to query next.
|
|
282
|
-
for _ in range(self.batch_size):
|
|
283
|
-
decision_function = (
|
|
284
|
-
self.utility_function_builder.build_utility_function(
|
|
285
|
-
self.posteriors, self.datasets, key
|
|
286
|
-
)
|
|
287
|
-
)
|
|
288
|
-
self.current_utility_functions.append(decision_function)
|
|
289
|
-
|
|
290
|
-
_, key = jr.split(key)
|
|
291
|
-
maximizer = self.utility_maximizer.maximize(
|
|
292
|
-
decision_function, self.search_space, key
|
|
293
|
-
)
|
|
294
|
-
maximizers.append(maximizer)
|
|
295
|
-
_, key = jr.split(key)
|
|
296
|
-
|
|
297
|
-
maximizers = jnp.concatenate(maximizers)
|
|
298
|
-
return maximizers
|
|
299
|
-
else:
|
|
300
|
-
raise NotImplementedError(
|
|
301
|
-
"Only Thompson sampling currently supports batch size > 1."
|
|
302
|
-
)
|