gpjax 0.8.2__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +9 -6
- gpjax/citation.py +0 -47
- gpjax/dataset.py +19 -13
- gpjax/decision_making/decision_maker.py +16 -19
- gpjax/decision_making/posterior_handler.py +20 -23
- gpjax/decision_making/search_space.py +1 -1
- gpjax/decision_making/test_functions/continuous_functions.py +23 -16
- gpjax/decision_making/test_functions/non_conjugate_functions.py +8 -8
- gpjax/decision_making/utility_functions/__init__.py +8 -0
- gpjax/decision_making/utility_functions/base.py +9 -13
- gpjax/decision_making/utility_functions/expected_improvement.py +112 -0
- gpjax/decision_making/utility_functions/probability_of_improvement.py +125 -0
- gpjax/decision_making/utility_functions/thompson_sampling.py +4 -4
- gpjax/decision_making/utility_maximizer.py +7 -10
- gpjax/decision_making/utils.py +16 -2
- gpjax/distributions.py +54 -63
- gpjax/fit.py +150 -120
- gpjax/gps.py +205 -213
- gpjax/integrators.py +39 -41
- gpjax/kernels/__init__.py +3 -0
- gpjax/kernels/approximations/rff.py +52 -42
- gpjax/kernels/base.py +205 -74
- gpjax/kernels/computations/base.py +53 -33
- gpjax/kernels/computations/basis_functions.py +31 -47
- gpjax/kernels/computations/constant_diagonal.py +11 -48
- gpjax/kernels/computations/dense.py +4 -17
- gpjax/kernels/computations/diagonal.py +6 -31
- gpjax/kernels/computations/eigen.py +13 -22
- gpjax/kernels/non_euclidean/graph.py +60 -54
- gpjax/kernels/non_euclidean/utils.py +4 -5
- gpjax/kernels/nonstationary/arccosine.py +75 -37
- gpjax/kernels/nonstationary/linear.py +42 -22
- gpjax/kernels/nonstationary/polynomial.py +56 -31
- gpjax/kernels/stationary/__init__.py +2 -0
- gpjax/kernels/stationary/base.py +194 -0
- gpjax/kernels/stationary/matern12.py +14 -30
- gpjax/kernels/stationary/matern32.py +20 -38
- gpjax/kernels/stationary/matern52.py +17 -35
- gpjax/kernels/stationary/periodic.py +54 -27
- gpjax/kernels/stationary/powered_exponential.py +57 -32
- gpjax/kernels/stationary/rational_quadratic.py +56 -32
- gpjax/kernels/stationary/rbf.py +16 -34
- gpjax/kernels/stationary/utils.py +0 -1
- gpjax/kernels/stationary/white.py +30 -28
- gpjax/likelihoods.py +69 -50
- gpjax/lower_cholesky.py +23 -14
- gpjax/mean_functions.py +44 -47
- gpjax/objectives.py +326 -407
- gpjax/parameters.py +167 -0
- gpjax/scan.py +3 -7
- gpjax/typing.py +5 -3
- gpjax/variational_families.py +207 -172
- {gpjax-0.8.2.dist-info → gpjax-0.9.1.dist-info}/METADATA +31 -32
- gpjax-0.9.1.dist-info/RECORD +62 -0
- {gpjax-0.8.2.dist-info → gpjax-0.9.1.dist-info}/WHEEL +1 -1
- gpjax/base/__init__.py +0 -38
- gpjax/base/module.py +0 -416
- gpjax/base/param.py +0 -73
- gpjax/flax_base/bijectors.py +0 -8
- gpjax/flax_base/param.py +0 -16
- gpjax/flax_base/types.py +0 -15
- gpjax/progress_bar.py +0 -131
- gpjax-0.8.2.dist-info/LICENSE +0 -201
- gpjax-0.8.2.dist-info/RECORD +0 -66
- /LICENSE → /gpjax-0.9.1.dist-info/licenses/LICENSE +0 -0
gpjax/__init__.py
CHANGED
@@ -12,8 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from warnings import filterwarnings
+
+from beartype.roar import BeartypeDecorHintPep585DeprecationWarning
+
+filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning)
+
 from gpjax import (
-    base,
     decision_making,
     gps,
     integrators,
@@ -21,12 +26,9 @@ from gpjax import (
     likelihoods,
     mean_functions,
     objectives,
+    parameters,
     variational_families,
 )
-from gpjax.base import (
-    Module,
-    param_field,
-)
 from gpjax.citation import cite
 from gpjax.dataset import Dataset
 from gpjax.fit import (
@@ -38,7 +40,7 @@ __license__ = "MIT"
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.8.2"
+__version__ = "0.9.1"
 
 __all__ = [
     "base",
@@ -49,6 +51,7 @@ __all__ = [
     "likelihoods",
     "mean_functions",
     "objectives",
+    "parameters",
     "variational_families",
     "Dataset",
     "cite",
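The headline change here is the removal of the `gpjax.base` entry points (`Module`, `param_field`) from the public API and the addition of a `gpjax.parameters` module. A minimal sketch of what this means for downstream imports, assuming only what the diff above shows:

```python
import gpjax

# The version bump is visible at the top level.
assert gpjax.__version__ == "0.9.1"

# `gpjax.parameters` is newly re-exported ...
from gpjax import parameters  # noqa: F401

# ... while the 0.8.x `gpjax.base` entry points are gone
# (gpjax/base/__init__.py is deleted in this release):
try:
    from gpjax.base import Module, param_field
except ImportError:
    pass  # downstream code must migrate off Module/param_field
```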
gpjax/citation.py
CHANGED
@@ -23,13 +23,6 @@ from gpjax.kernels import (
     Matern32,
     Matern52,
 )
-from gpjax.objectives import (
-    ELBO,
-    CollapsedELBO,
-    ConjugateMLL,
-    LogPosteriorDensity,
-    NonConjugateMLL,
-)
 
 CitationType = Union[None, str, Dict[str, str]]
 
@@ -158,46 +151,6 @@ def _(tree) -> PaperCitation:
     )
 
 
-####################
-# Objective citations
-####################
-@cite.register(ConjugateMLL)
-@cite.register(NonConjugateMLL)
-@cite.register(LogPosteriorDensity)
-def _(tree) -> BookCitation:
-    return BookCitation(
-        citation_key="rasmussen2006gaussian",
-        title="Gaussian Processes for Machine Learning",
-        authors="Rasmussen, Carl Edward and Williams, Christopher K",
-        year="2006",
-        publisher="MIT press Cambridge, MA",
-        volume="2",
-    )
-
-
-@cite.register(CollapsedELBO)
-def _(tree) -> PaperCitation:
-    return PaperCitation(
-        citation_key="titsias2009variational",
-        title="Variational learning of inducing variables in sparse Gaussian processes",
-        authors="Titsias, Michalis",
-        year="2009",
-        booktitle="International Conference on Artificial Intelligence and Statistics",
-    )
-
-
-@cite.register(ELBO)
-def _(tree) -> PaperCitation:
-    return PaperCitation(
-        citation_key="hensman2013gaussian",
-        title="Gaussian Processes for Big Data",
-        authors="Hensman, James and Fusi, Nicolo and Lawrence, Neil D",
-        year="2013",
-        booktitle="Uncertainty in Artificial Intelligence",
-        citation_type="article",
-    )
-
-
 ####################
 # Decision making citations
 ####################
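With the objective registrations removed, `cite` keeps its kernel (and decision-making) entries but no longer knows about `ConjugateMLL`, `ELBO`, and friends. A hedged sketch, assuming `Matern52` remains default-constructible:

```python
import gpjax as gpx

# Kernel citations survive the 0.9.1 clean-up:
print(gpx.cite(gpx.kernels.Matern52()))

# The objective citations (rasmussen2006gaussian, titsias2009variational,
# hensman2013gaussian) are no longer registered, so citing an objective now
# falls through to `cite`'s default singledispatch handler.
```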
gpjax/dataset.py
CHANGED
@@ -12,40 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 from dataclasses import dataclass
 import warnings
 
 from beartype.typing import Optional
+import jax
 import jax.numpy as jnp
 from jaxtyping import Num
-from simple_pytree import Pytree
 
 from gpjax.typing import Array
 
 
 @dataclass
-class Dataset(Pytree):
+@jax.tree_util.register_pytree_node_class
+class Dataset:
     r"""Base class for datasets.
 
-
-
-
-        y (Optional[Num[Array, "N Q"]]): output data.
+    Args:
+        X: input data.
+        y: output data.
     """
 
     X: Optional[Num[Array, "N D"]] = None
     y: Optional[Num[Array, "N Q"]] = None
 
     def __post_init__(self) -> None:
-        r"""Checks that the shapes of
-        and provides warnings regarding the precision of
+        r"""Checks that the shapes of $X$ and $y$ are compatible,
+        and provides warnings regarding the precision of $X$ and $y$."""
         _check_shape(self.X, self.y)
         _check_precision(self.X, self.y)
 
     def __repr__(self) -> str:
         r"""Returns a string representation of the dataset."""
-        repr = f"
+        repr = f"Dataset(Number of observations: {self.n:=} - Input dimension: {self.in_dim})"
         return repr
 
     def is_supervised(self) -> bool:
@@ -76,14 +75,21 @@ class Dataset(Pytree):
 
     @property
     def in_dim(self) -> int:
-        r"""Dimension of the inputs,
+        r"""Dimension of the inputs, $X$."""
         return self.X.shape[1]
 
+    def tree_flatten(self):
+        return (self.X, self.y), None
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls(*children)
+
 
 def _check_shape(
     X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
 ) -> None:
-    r"""Checks that the shapes of
+    r"""Checks that the shapes of $X$ and $y$ are compatible."""
     if X is not None and y is not None and X.shape[0] != y.shape[0]:
         raise ValueError(
             "Inputs, X, and outputs, y, must have the same number of rows."
@@ -104,7 +110,7 @@ def _check_shape(
 def _check_precision(
     X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
 ) -> None:
-    r"""Checks the precision of
+    r"""Checks the precision of $X$ and $y$."""
     if X is not None and X.dtype != jnp.float64:
         warnings.warn(
             "X is not of type float64. "
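`Dataset` drops its `simple_pytree.Pytree` base class in favour of explicit registration with `jax.tree_util`; `X` and `y` become the leaves and there is no static auxiliary data. A small sketch of the resulting behaviour (the `jax_enable_x64` flag sidesteps the float64 warning issued by `_check_precision`):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # avoid the float64 precision warning

from gpjax import Dataset

D = Dataset(X=jnp.linspace(0.0, 1.0, 10).reshape(-1, 1), y=jnp.ones((10, 1)))

# tree_flatten returns (X, y) with `None` aux data, so Dataset behaves like
# any other pytree under jax transformations:
assert len(jax.tree_util.tree_leaves(D)) == 2
D2 = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, D)  # still a Dataset
```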
gpjax/decision_making/decision_maker.py
CHANGED
@@ -58,26 +58,23 @@ class AbstractDecisionMaker(ABC):
     the black-box function of interest at this point.
 
     Attributes:
-        search_space
-
-
-
-
-
-
-
-        decision_making.utils.
-        datasets (Dict[str, Dataset]): Dictionary of datasets, which are augmented with
+        search_space: Search space over which we can evaluate the function(s) of interest.
+        posterior_handlers: dictionary of posterior handlers, which are used to update
+            posteriors throughout the decision making loop. Note that the word `posteriors`
+            is used for consistency with GPJax, but these objects are typically referred to
+            as `models` in the model-based decision making literature. Tags are used to
+            distinguish between posteriors. In a typical Bayesian optimisation setup one of
+            the tags will be `OBJECTIVE`, defined in `decision_making.utils`.
+        datasets: dictionary of datasets, which are augmented with
             observations throughout the decision making loop. In a typical setup they are
             also used to update the posteriors, using the `posterior_handlers`. Tags are used
             to distinguish datasets, and correspond to tags in `posterior_handlers`.
-        key
-        batch_size
+        key: JAX random key, used to generate random numbers.
+        batch_size: Number of points to query at each step of the decision making
             loop. Note that `SinglePointUtilityFunction`s are only capable of generating
             one point to be queried at each iteration of the decision making loop.
-        post_ask
-        post_tell
-        step.
+        post_ask: List of functions to be executed after each ask step.
+        post_tell: List of functions to be executed after each tell step.
     """
 
     search_space: AbstractSearchSpace
@@ -140,10 +137,10 @@ class AbstractDecisionMaker(ABC):
        Add newly observed data to datasets and update the corresponding posteriors.
 
        Args:
-            observation_datasets
-
-
-            key
+            observation_datasets: dictionary of datasets containing new observations.
+                Tags are used to distinguish datasets, and correspond to tags in
+                `posterior_handlers` and `self.datasets`.
+            key: JAX PRNG key for controlling random state.
        """
        if observation_datasets.keys() != self.datasets.keys():
            raise ValueError(
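The reworded attribute docs make the tagging contract explicit: the keys of `datasets` and `posterior_handlers` must line up, and `tell` enforces the same thing for incoming observations. A tiny illustration of the key check (`OBJECTIVE` is just a string tag from `decision_making.utils`; the handler values are elided here):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from gpjax import Dataset
from gpjax.decision_making.utils import OBJECTIVE

datasets = {OBJECTIVE: Dataset(X=jnp.zeros((1, 1)), y=jnp.zeros((1, 1)))}

# `tell` raises a ValueError unless the observation dict uses exactly the
# same tags as `self.datasets`:
observations = {"WRONG_TAG": Dataset(X=jnp.ones((1, 1)), y=jnp.ones((1, 1)))}
assert observations.keys() != datasets.keys()  # this would be rejected
```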
gpjax/decision_making/posterior_handler.py
CHANGED
@@ -27,7 +27,7 @@ from gpjax.gps import (
     AbstractPosterior,
     AbstractPrior,
 )
-from gpjax.objectives import
+from gpjax.objectives import Objective
 from gpjax.typing import KeyArray
 
 LikelihoodBuilder = Callable[[int], AbstractLikelihood]
@@ -42,21 +42,18 @@ class PosteriorHandler:
     observed.
 
     Attributes:
-        prior
-        likelihood_builder
-
-
-
-
-
-        posterior hyperparameters.
-        num_optimization_iterations (int): Number of iterations to optimize
-        the posterior hyperparameters for.
+        prior: prior to use when forming the posterior.
+        likelihood_builder: function which takes the number of datapoints as input and
+            returns a likelihood object initialised with the given number of datapoints.
+        optimization_objective: objective to use for optimizing the posterior hyperparameters.
+        optimizer: an optax optimizer to use for optimizing the posterior hyperparameters.
+        num_optimization_iterations: the number of iterations to optimize
+            the posterior hyperparameters for.
     """
 
     prior: AbstractPrior
     likelihood_builder: LikelihoodBuilder
-    optimization_objective:
+    optimization_objective: Objective
     optimizer: ox.GradientTransformation
     num_optimization_iters: int
 
@@ -71,10 +68,10 @@ class PosteriorHandler:
        Initialise (and optionally optimize) a posterior using the given dataset.
 
        Args:
-            dataset
-            optimize
-            key
-
+            dataset: dataset to get posterior for.
+            optimize: whether to optimize the posterior hyperparameters.
+            key: a JAX PRNG key which is used for optimizing the posterior
+                hyperparameters.
 
        Returns:
            Posterior for the given dataset.
@@ -108,14 +105,14 @@ class PosteriorHandler:
        set as in the `likelihood_builder` function.
 
        Args:
-            dataset:
-            previous_posterior:
-
-
-
-            optimize:
+            dataset: dataset to get posterior for.
+            previous_posterior: posterior being updated. This is supplied as one may
+                wish to simply increase the number of datapoints in the likelihood, without
+                optimizing the posterior hyperparameters, in which case the previous
+                posterior can be used to obtain the previously set prior hyperparameters.
+            optimize: whether to optimize the posterior hyperparameters.
            key: A JAX PRNG key which is used for optimizing the posterior
-
+                hyperparameters.
        """
        posterior = previous_posterior.prior * self.likelihood_builder(dataset.n)
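`likelihood_builder` is just a `Callable[[int], AbstractLikelihood]`, which is what lets `update_posterior` rebuild the likelihood as `dataset.n` grows. A hedged construction sketch: the `conjugate_mll` objective name and the sign convention (assuming the handler minimises its objective) are assumptions, not shown in this diff.

```python
import optax as ox
import gpjax as gpx
from gpjax.decision_making.posterior_handler import PosteriorHandler

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.Matern52()
)

handler = PosteriorHandler(
    prior=prior,
    # Rebuilt with the right `num_datapoints` every time data is appended:
    likelihood_builder=lambda n: gpx.likelihoods.Gaussian(num_datapoints=n),
    # Assumed objective: negated conjugate marginal log-likelihood, so that
    # minimisation maximises the MLL.
    optimization_objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    optimizer=ox.adam(learning_rate=0.01),
    num_optimization_iters=500,
)
```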
gpjax/decision_making/search_space.py
CHANGED
@@ -56,7 +56,7 @@ class AbstractSearchSpace(ABC):
 
 @dataclass
 class ContinuousSearchSpace(AbstractSearchSpace):
-    """The `ContinuousSearchSpace` class is used to bound the domain of continuous real functions of dimension
+    """The `ContinuousSearchSpace` class is used to bound the domain of continuous real functions of dimension $D$."""
 
     lower_bounds: Float[Array, " D"]
     upper_bounds: Float[Array, " D"]
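A `ContinuousSearchSpace` is just a pair of bound arrays plus a `sample` method (the same `sample(num_points=..., key=...)` call used by `generate_dataset` below). A minimal sketch:

```python
import jax.numpy as jnp
import jax.random as jr
from gpjax.decision_making.search_space import ContinuousSearchSpace

# A box-bounded 2D domain over [0, 1] x [0, 2]:
space = ContinuousSearchSpace(
    lower_bounds=jnp.array([0.0, 0.0]),
    upper_bounds=jnp.array([1.0, 2.0]),
)
X = space.sample(num_points=5, key=jr.key(0))  # points inside the bounds
```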
gpjax/decision_making/test_functions/continuous_functions.py
CHANGED
@@ -12,24 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from abc import (
-    ABC,
-    abstractmethod,
-)
+from abc import abstractmethod
 from dataclasses import dataclass
 
 import jax.numpy as jnp
 from jaxtyping import (
     Array,
     Float,
+    Num,
 )
+import tensorflow_probability.substrates.jax as tfp
 
 from gpjax.dataset import Dataset
 from gpjax.decision_making.search_space import ContinuousSearchSpace
+from gpjax.gps import AbstractMeanFunction
 from gpjax.typing import KeyArray
 
 
-class AbstractContinuousTestFunction(ABC):
+class AbstractContinuousTestFunction(AbstractMeanFunction):
     """
     Abstract base class for continuous test functions.
 
@@ -43,19 +43,28 @@ class AbstractContinuousTestFunction(ABC):
     minimizer: Float[Array, "1 D"]
     minimum: Float[Array, "1 1"]
 
-    def generate_dataset(
+    def generate_dataset(
+        self, num_points: int, key: KeyArray, obs_stddev: float = 0.0
+    ) -> Dataset:
         """
         Generate a toy dataset from the test function.
 
         Args:
             num_points (int): Number of points to sample.
             key (KeyArray): JAX PRNG key.
+            obs_stddev (float): (Optional) standard deviation of Gaussian distributed
+                noise added to observations.
 
         Returns:
             Dataset: Dataset of points sampled from the test function.
         """
         X = self.search_space.sample(num_points=num_points, key=key)
-
+        gaussian_noise = tfp.distributions.Normal(
+            jnp.zeros(num_points), obs_stddev * jnp.ones(num_points)
+        )
+        y = self.evaluate(X) + jnp.transpose(
+            gaussian_noise.sample(sample_shape=[1], seed=key)
+        )
         return Dataset(X=X, y=y)
 
     def generate_test_points(
@@ -73,6 +82,9 @@ class AbstractContinuousTestFunction(ABC):
         """
         return self.search_space.sample(num_points=num_points, key=key)
 
+    def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
+        return self.evaluate(x)
+
     @abstractmethod
     def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
         """
@@ -92,7 +104,7 @@ class Forrester(AbstractContinuousTestFunction):
     """
     Forrester function introduced in 'Engineering design via surrogate modelling: a
     practical guide' (Forrester et al. 2008), rescaled to have zero mean and unit
-    variance over
+    variance over $[0, 1]$.
     """
 
     search_space = ContinuousSearchSpace(
@@ -113,7 +125,7 @@ class LogarithmicGoldsteinPrice(AbstractContinuousTestFunction):
     """
     Logarithmic Goldstein-Price function introduced in 'A benchmark of kriging-based
     infill criteria for noisy optimization' (Picheny et al. 2013), which has zero mean
-    and unit variance over
+    and unit variance over $[0, 1]^2$.
     """
 
     search_space = ContinuousSearchSpace(
@@ -127,12 +139,7 @@ class LogarithmicGoldsteinPrice(AbstractContinuousTestFunction):
         x1 = 4.0 * x[:, 0] - 2.0
         x2 = 4.0 * x[:, 1] - 2.0
         a = 1.0 + (x1 + x2 + 1.0) ** 2 * (
-            19.0
-            - 14.0 * x1
-            + 3.0 * (x1**2)
-            - 14.0 * x2
-            + 6.0 * x1 * x2
-            + 3.0 * (x2**2)
+            19.0 - 14.0 * x1 + 3.0 * (x1**2) - 14.0 * x2 + 6.0 * x1 * x2 + 3.0 * (x2**2)
         )
         b = 30.0 + (2.0 * x1 - 3.0 * x2) ** 2 * (
             18.0
@@ -148,7 +155,7 @@ class LogarithmicGoldsteinPrice(AbstractContinuousTestFunction):
 @dataclass
 class Quadratic(AbstractContinuousTestFunction):
     """
-    Toy quadratic function defined over
+    Toy quadratic function defined over $[0, 1]$.
     """
 
     search_space = ContinuousSearchSpace(
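The practical change for users is the new `obs_stddev` argument and the fact that test functions are now `AbstractMeanFunction`s, hence callable. A sketch, assuming `Forrester` remains default-constructible (its attributes are set at class level):

```python
import jax.random as jr
from gpjax.decision_making.test_functions.continuous_functions import Forrester

forrester = Forrester()

# New optional Gaussian observation noise on the sampled targets:
noisy = forrester.generate_dataset(num_points=20, key=jr.key(1), obs_stddev=0.1)
clean = forrester.generate_dataset(num_points=20, key=jr.key(1))  # obs_stddev=0.0

# Being a mean function, the test function is now directly callable:
vals = forrester(clean.X)  # equivalent to forrester.evaluate(clean.X)
```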
gpjax/decision_making/test_functions/non_conjugate_functions.py
CHANGED
@@ -17,15 +17,15 @@ from dataclasses import dataclass
 
 import jax.numpy as jnp
 import jax.random as jr
-from jaxtyping import (
-    Array,
-    Float,
-    Integer,
-)
 
 from gpjax.dataset import Dataset
 from gpjax.decision_making.search_space import ContinuousSearchSpace
-from gpjax.typing import
+from gpjax.typing import (
+    Array,
+    Float,
+    Int,
+    KeyArray,
+)
 
 
 @dataclass
@@ -74,7 +74,7 @@ class PoissonTestFunction:
         return self.search_space.sample(num_points=num_points, key=key)
 
     @abstractmethod
-    def evaluate(self, x: Float[Array, "N 1"]) ->
+    def evaluate(self, x: Float[Array, "N 1"]) -> Int[Array, "N 1"]:
         """
         Evaluate the test function at a set of points. Function taken from
         https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
@@ -83,7 +83,7 @@ class PoissonTestFunction:
             x (Float[Array, 'N D']): Points to evaluate the test function at.
 
         Returns:
-
+            Float[Array, 'N 1']: Values of the test function at the points.
         """
         key = jr.key(42)
         f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
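Only the typing changed here: the array types now come via `gpjax.typing`, and `evaluate` is annotated as returning integer counts. A quick sketch, assuming `PoissonTestFunction` is default-constructible like the continuous test functions:

```python
import jax.random as jr
from gpjax.decision_making.test_functions.non_conjugate_functions import (
    PoissonTestFunction,
)

fn = PoissonTestFunction()
X = fn.generate_test_points(num_points=4, key=jr.key(0))
y = fn.evaluate(X)  # Poisson counts, matching the new Int[Array, "N 1"] annotation
```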
gpjax/decision_making/utility_functions/__init__.py
CHANGED
@@ -18,12 +18,20 @@ from gpjax.decision_making.utility_functions.base import (
     SinglePointUtilityFunction,
     UtilityFunction,
 )
+from gpjax.decision_making.utility_functions.expected_improvement import (
+    ExpectedImprovement,
+)
+from gpjax.decision_making.utility_functions.probability_of_improvement import (
+    ProbabilityOfImprovement,
+)
 from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling
 
 __all__ = [
     "UtilityFunction",
     "AbstractUtilityFunctionBuilder",
     "AbstractSinglePointUtilityFunctionBuilder",
+    "ExpectedImprovement",
     "SinglePointUtilityFunction",
     "ThompsonSampling",
+    "ProbabilityOfImprovement",
 ]
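Both new builders are exported next to `ThompsonSampling` (the `ProbabilityOfImprovement` implementation lives in the new probability_of_improvement.py, +125 lines, not reproduced in this section):

```python
from gpjax.decision_making.utility_functions import (
    ExpectedImprovement,
    ProbabilityOfImprovement,
    ThompsonSampling,
)
```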
gpjax/decision_making/utility_functions/base.py
CHANGED
@@ -35,8 +35,8 @@ from gpjax.typing import (
 SinglePointUtilityFunction = Callable[[Float[Array, "N D"]], Float[Array, "N 1"]]
 """
 Type alias for utility functions which don't support batching, and instead characterise
-the utility of querying a single point, rather than a batch of points. They take an array of points of shape
-and return the value of the utility function at each point in an array of shape
+the utility of querying a single point, rather than a batch of points. They take an array of points of shape $[N, D]$
+and return the value of the utility function at each point in an array of shape $[N, 1]$.
 """
 
 
@@ -65,14 +65,12 @@ class AbstractSinglePointUtilityFunctionBuilder(ABC):
         datasets.
 
         Args:
-            posteriors
-            used to form the utility function.
-            datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
-            to form the utility function.
+            posteriors: dictionary of posteriors to be used to form the utility function.
+            datasets: dictionary of datasets which may be used to form the utility function.
 
         Raises:
             ValueError: If the objective posterior or dataset are not present in the
-
+                posteriors or datasets.
         """
         if OBJECTIVE not in posteriors.keys():
             raise ValueError("Objective posterior not found in posteriors")
@@ -90,15 +88,13 @@ class AbstractSinglePointUtilityFunctionBuilder(ABC):
         Build a `UtilityFunction` from a set of posteriors and datasets.
 
         Args:
-            posteriors
-            used to form the utility function.
-
-            to form the utility function.
-            key (KeyArray): JAX PRNG key used for random number generation.
+            posteriors: dictionary of posteriors to be used to form the utility function.
+            datasets: dictionary of datasets which may be used to form the utility function.
+            key: JAX PRNG key used for random number generation.
 
         Returns:
             SinglePointUtilityFunction: Utility function to be *maximised* in order to
-
+                decide which point to query next.
         """
         raise NotImplementedError
gpjax/decision_making/utility_functions/expected_improvement.py
ADDED
@@ -0,0 +1,112 @@
+# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from dataclasses import dataclass
+from functools import partial
+
+from beartype.typing import Mapping
+import jax.numpy as jnp
+import tensorflow_probability.substrates.jax as tfp
+
+from gpjax.dataset import Dataset
+from gpjax.decision_making.utility_functions.base import (
+    AbstractSinglePointUtilityFunctionBuilder,
+    SinglePointUtilityFunction,
+)
+from gpjax.decision_making.utils import (
+    OBJECTIVE,
+    get_best_latent_observation_val,
+)
+from gpjax.gps import ConjugatePosterior
+from gpjax.typing import (
+    Array,
+    Float,
+    KeyArray,
+)
+
+
+@dataclass
+class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
+    """
+    Expected Improvement acquisition function as introduced by [Močkus,
+    1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55). The "best"
+    incumbent value is defined as the lowest posterior mean value evaluated at the the
+    previously observed points. This enables the acquisition function to be utilised with noisy observations.
+    """
+
+    def build_utility_function(
+        self,
+        posteriors: Mapping[str, ConjugatePosterior],
+        datasets: Mapping[str, Dataset],
+        key: KeyArray,
+    ) -> SinglePointUtilityFunction:
+        r"""
+        Build the Expected Improvement acquisition function. This computes the expected
+        improvement over the "best" of the previously observed points, utilising the
+        posterior distribution of the surrogate model. For posterior distribution
+        $`f(\cdot)`$, and best incumbent value $`\eta`$, this is defined
+        as:
+        ```math
+        \alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
+        ```
+
+        Args:
+            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
+                used to form the utility function. One posteriors must correspond to the
+                `OBJECTIVE` key, as we utilise the objective posterior to form the utility
+                function.
+            datasets (Mapping[str, Dataset]): Dictionary of datasets used to form the
+                utility function. Keys in `datasets` should correspond to keys in
+                `posteriors`. One of the datasets must correspond to the `OBJECTIVE` key.
+            key (KeyArray): JAX PRNG key used for random number generation.
+
+        Returns:
+            SinglePointUtilityFunction: The Expected Improvement acquisition function to
+                to be *maximised* in order to decide which point to query next.
+        """
+        self.check_objective_present(posteriors, datasets)
+        objective_posterior = posteriors[OBJECTIVE]
+        objective_dataset = datasets[OBJECTIVE]
+
+        if not isinstance(objective_posterior, ConjugatePosterior):
+            raise ValueError(
+                "Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
+            )
+
+        if (
+            objective_dataset.X is None
+            or objective_dataset.n == 0
+            or objective_dataset.y is None
+        ):
+            raise ValueError("Objective dataset must contain at least one item")
+
+        eta = get_best_latent_observation_val(objective_posterior, objective_dataset)
+        return partial(
+            _expected_improvement, objective_posterior, objective_dataset, eta
+        )
+
+
+def _expected_improvement(
+    objective_posterior: ConjugatePosterior,
+    objective_dataset: Dataset,
+    eta: Float[Array, ""],
+    x: Float[Array, "N D"],
+) -> Float[Array, "N 1"]:
+    latent_dist = objective_posterior(x, objective_dataset)
+    mean = latent_dist.mean()
+    var = latent_dist.variance()
+    normal = tfp.distributions.Normal(mean, jnp.sqrt(var))
+    return jnp.expand_dims(
+        ((eta - mean) * normal.cdf(eta) + var * normal.prob(eta)), -1
+    )
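Note the closed form in `_expected_improvement`: since `normal = Normal(mean, sqrt(var))`, `normal.prob(eta)` equals φ(z)/σ with z = (η − μ)/σ, so `(eta - mean) * normal.cdf(eta) + var * normal.prob(eta)` reduces to the textbook minimisation EI, (η − μ)Φ(z) + σφ(z). An end-to-end usage sketch; the prior/likelihood construction uses standard GPJax pieces not shown in this diff:

```python
import jax
import jax.numpy as jnp
import jax.random as jr

jax.config.update("jax_enable_x64", True)

import gpjax as gpx
from gpjax.decision_making.utility_functions import ExpectedImprovement
from gpjax.decision_making.utils import OBJECTIVE

key = jr.key(0)
X = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)
D = gpx.Dataset(X=X, y=jnp.sin(6.0 * X))

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.Matern52()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# build_utility_function returns partial(_expected_improvement, ..., eta),
# i.e. a SinglePointUtilityFunction mapping [N, D] points to [N, 1] utilities:
utility = ExpectedImprovement().build_utility_function(
    posteriors={OBJECTIVE: posterior}, datasets={OBJECTIVE: D}, key=key
)
ei = utility(jnp.linspace(0.0, 1.0, 100).reshape(-1, 1))  # maximise to pick next x
```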