gpjax-0.9.3-py3-none-any.whl → gpjax-0.9.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -3
- gpjax/citation.py +0 -43
- gpjax/distributions.py +3 -1
- gpjax/gps.py +2 -1
- gpjax/variational_families.py +24 -19
- {gpjax-0.9.3.dist-info → gpjax-0.9.5.dist-info}/METADATA +20 -21
- {gpjax-0.9.3.dist-info → gpjax-0.9.5.dist-info}/RECORD +9 -23
- {gpjax-0.9.3.dist-info → gpjax-0.9.5.dist-info}/WHEEL +1 -1
- gpjax-0.9.5.dist-info/licenses/LICENSE.txt +19 -0
- gpjax/decision_making/__init__.py +0 -63
- gpjax/decision_making/decision_maker.py +0 -302
- gpjax/decision_making/posterior_handler.py +0 -152
- gpjax/decision_making/search_space.py +0 -96
- gpjax/decision_making/test_functions/__init__.py +0 -31
- gpjax/decision_making/test_functions/continuous_functions.py +0 -169
- gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -90
- gpjax/decision_making/utility_functions/__init__.py +0 -37
- gpjax/decision_making/utility_functions/base.py +0 -106
- gpjax/decision_making/utility_functions/expected_improvement.py +0 -112
- gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -125
- gpjax/decision_making/utility_functions/thompson_sampling.py +0 -101
- gpjax/decision_making/utility_maximizer.py +0 -157
- gpjax/decision_making/utils.py +0 -64
- gpjax-0.9.3.dist-info/licenses/LICENSE +0 -201
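
The headline change in this release is the removal of the entire `gpjax.decision_making` package (Bayesian-optimisation utilities: decision maker, search spaces, test functions, and utility functions). A minimal sketch of the upgrade impact, assuming only that the module paths listed above were importable in 0.9.3:

```python
# Hedged sketch: code written against the 0.9.3 decision-making API.
# The whole gpjax.decision_making package is deleted in 0.9.5, so this
# import should fail with ModuleNotFoundError after upgrading.
try:
    from gpjax.decision_making.test_functions.continuous_functions import Forrester
except ModuleNotFoundError:
    # Either pin gpjax==0.9.3 or vendor the removed modules into your project.
    raise SystemExit("gpjax.decision_making was removed in gpjax 0.9.5")
```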

gpjax/decision_making/test_functions/continuous_functions.py (removed; 169 lines)

````diff
@@ -1,169 +0,0 @@
-# Copyright 2023 The GPJax Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from abc import abstractmethod
-from dataclasses import dataclass
-
-import jax.numpy as jnp
-from jaxtyping import (
-    Array,
-    Float,
-    Num,
-)
-import tensorflow_probability.substrates.jax as tfp
-
-from gpjax.dataset import Dataset
-from gpjax.decision_making.search_space import ContinuousSearchSpace
-from gpjax.gps import AbstractMeanFunction
-from gpjax.typing import KeyArray
-
-
-class AbstractContinuousTestFunction(AbstractMeanFunction):
-    """
-    Abstract base class for continuous test functions.
-
-    Attributes:
-        search_space (ContinuousSearchSpace): Search space for the function.
-        minimizer (Float[Array, '1 D']): Minimizer of the function (to 5 decimal places).
-        minimum (Float[Array, '1 1']): Minimum of the function (to 5 decimal places).
-    """
-
-    search_space: ContinuousSearchSpace
-    minimizer: Float[Array, "1 D"]
-    minimum: Float[Array, "1 1"]
-
-    def generate_dataset(
-        self, num_points: int, key: KeyArray, obs_stddev: float = 0.0
-    ) -> Dataset:
-        """
-        Generate a toy dataset from the test function.
-
-        Args:
-            num_points (int): Number of points to sample.
-            key (KeyArray): JAX PRNG key.
-            obs_stddev (float): (Optional) standard deviation of Gaussian distributed
-                noise added to observations.
-
-        Returns:
-            Dataset: Dataset of points sampled from the test function.
-        """
-        X = self.search_space.sample(num_points=num_points, key=key)
-        gaussian_noise = tfp.distributions.Normal(
-            jnp.zeros(num_points), obs_stddev * jnp.ones(num_points)
-        )
-        y = self.evaluate(X) + jnp.transpose(
-            gaussian_noise.sample(sample_shape=[1], seed=key)
-        )
-        return Dataset(X=X, y=y)
-
-    def generate_test_points(
-        self, num_points: int, key: KeyArray
-    ) -> Float[Array, "N D"]:
-        """
-        Generate test points from the search space of the test function.
-
-        Args:
-            num_points (int): Number of points to sample.
-            key (KeyArray): JAX PRNG key.
-
-        Returns:
-            Float[Array, 'N D']: Test points sampled from the search space.
-        """
-        return self.search_space.sample(num_points=num_points, key=key)
-
-    def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
-        return self.evaluate(x)
-
-    @abstractmethod
-    def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
-        """
-        Evaluate the test function at a set of points.
-
-        Args:
-            x (Float[Array, 'N D']): Points to evaluate the test function at.
-
-        Returns:
-            Float[Array, 'N 1']: Values of the test function at the points.
-        """
-        raise NotImplementedError
-
-
-@dataclass
-class Forrester(AbstractContinuousTestFunction):
-    """
-    Forrester function introduced in 'Engineering design via surrogate modelling: a
-    practical guide' (Forrester et al. 2008), rescaled to have zero mean and unit
-    variance over $[0, 1]$.
-    """
-
-    search_space = ContinuousSearchSpace(
-        lower_bounds=jnp.array([0.0]),
-        upper_bounds=jnp.array([1.0]),
-    )
-    minimizer = jnp.array([[0.75725]])
-    minimum = jnp.array([[-1.45280]])
-
-    def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
-        mean = 0.45321
-        std = jnp.sqrt(19.8577)
-        return (((6 * x - 2) ** 2) * jnp.sin(12 * x - 4) - mean) / std
-
-
-@dataclass
-class LogarithmicGoldsteinPrice(AbstractContinuousTestFunction):
-    """
-    Logarithmic Goldstein-Price function introduced in 'A benchmark of kriging-based
-    infill criteria for noisy optimization' (Picheny et al. 2013), which has zero mean
-    and unit variance over $[0, 1]^2$.
-    """
-
-    search_space = ContinuousSearchSpace(
-        lower_bounds=jnp.array([0.0, 0.0]),
-        upper_bounds=jnp.array([1.0, 1.0]),
-    )
-    minimizer = jnp.array([[0.5, 0.25]])
-    minimum = jnp.array([[-3.12913]])
-
-    def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
-        x1 = 4.0 * x[:, 0] - 2.0
-        x2 = 4.0 * x[:, 1] - 2.0
-        a = 1.0 + (x1 + x2 + 1.0) ** 2 * (
-            19.0 - 14.0 * x1 + 3.0 * (x1**2) - 14.0 * x2 + 6.0 * x1 * x2 + 3.0 * (x2**2)
-        )
-        b = 30.0 + (2.0 * x1 - 3.0 * x2) ** 2 * (
-            18.0
-            - 32.0 * x1
-            + 12.0 * (x1**2)
-            + 48.0 * x2
-            - 36.0 * x1 * x2
-            + 27.0 * (x2**2)
-        )
-        return ((jnp.log((a * b)) - 8.693) / 2.427).reshape(-1, 1)
-
-
-@dataclass
-class Quadratic(AbstractContinuousTestFunction):
-    """
-    Toy quadratic function defined over $[0, 1]$.
-    """
-
-    search_space = ContinuousSearchSpace(
-        lower_bounds=jnp.array([0.0]),
-        upper_bounds=jnp.array([1.0]),
-    )
-    minimizer = jnp.array([[0.5]])
-    minimum = jnp.array([[0.0]])
-
-    def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
-        return (x - 0.5) ** 2
````
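
For orientation, a minimal usage sketch of the removed test functions, using only the methods shown in the hunk above (note `jr.key` requires a reasonably recent JAX; older versions spell it `jr.PRNGKey`):

```python
import jax.numpy as jnp
import jax.random as jr

from gpjax.decision_making.test_functions.continuous_functions import Forrester

forrester = Forrester()
key = jr.key(123)

# Noisy training data drawn uniformly from the search space [0, 1].
train_data = forrester.generate_dataset(num_points=50, key=key, obs_stddev=0.1)

# Noise-free test inputs and targets; __call__ simply defers to evaluate().
X_test = forrester.generate_test_points(num_points=10, key=key)
y_test = forrester(X_test)

# The stored minimizer/minimum are only accurate to roughly 5 decimal places.
assert jnp.allclose(forrester.evaluate(forrester.minimizer), forrester.minimum, atol=1e-3)
```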

gpjax/decision_making/test_functions/non_conjugate_functions.py (removed; 90 lines)

````diff
@@ -1,90 +0,0 @@
-# Copyright 2023 The GPJax Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from abc import abstractmethod
-from dataclasses import dataclass
-
-import jax.numpy as jnp
-import jax.random as jr
-
-from gpjax.dataset import Dataset
-from gpjax.decision_making.search_space import ContinuousSearchSpace
-from gpjax.typing import (
-    Array,
-    Float,
-    Int,
-    KeyArray,
-)
-
-
-@dataclass
-class PoissonTestFunction:
-    """
-    Test function for GPs utilising the Poisson likelihood. Function taken from
-    https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
-
-    Attributes:
-        search_space (ContinuousSearchSpace): Search space for the function.
-    """
-
-    search_space = ContinuousSearchSpace(
-        lower_bounds=jnp.array([-2.0]),
-        upper_bounds=jnp.array([2.0]),
-    )
-
-    def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
-        """
-        Generate a toy dataset from the test function.
-
-        Args:
-            num_points (int): Number of points to sample.
-            key (KeyArray): JAX PRNG key.
-
-        Returns:
-            Dataset: Dataset of points sampled from the test function.
-        """
-        X = self.search_space.sample(num_points=num_points, key=key)
-        y = self.evaluate(X)
-        return Dataset(X=X, y=y)
-
-    def generate_test_points(
-        self, num_points: int, key: KeyArray
-    ) -> Float[Array, "N D"]:
-        """
-        Generate test points from the search space of the test function.
-
-        Args:
-            num_points (int): Number of points to sample.
-            key (KeyArray): JAX PRNG key.
-
-        Returns:
-            Float[Array, 'N D']: Test points sampled from the search space.
-        """
-        return self.search_space.sample(num_points=num_points, key=key)
-
-    @abstractmethod
-    def evaluate(self, x: Float[Array, "N 1"]) -> Int[Array, "N 1"]:
-        """
-        Evaluate the test function at a set of points. Function taken from
-        https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
-
-        Args:
-            x (Float[Array, 'N D']): Points to evaluate the test function at.
-
-        Returns:
-            Int[Array, 'N 1']: Values of the test function at the points.
-        """
-        key = jr.key(42)
-        f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
-        return jr.poisson(key, jnp.exp(f(x)))
````
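
Two notes on the removed `PoissonTestFunction`. First, despite the `@abstractmethod` decorator, `evaluate` is concrete and callable: the class never inherits from `ABC`, so the decorator has no effect. Second, because `evaluate` uses a hard-coded `jr.key(42)`, the Poisson draw is deterministic given `x`. A brief usage sketch:

```python
import jax.random as jr

from gpjax.decision_making.test_functions.non_conjugate_functions import (
    PoissonTestFunction,
)

fn = PoissonTestFunction()
D = fn.generate_dataset(num_points=100, key=jr.key(0))   # D.y holds integer counts
X_test = fn.generate_test_points(num_points=20, key=jr.key(1))
```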

gpjax/decision_making/utility_functions/__init__.py (removed; 37 lines)

````diff
@@ -1,37 +0,0 @@
-# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from gpjax.decision_making.utility_functions.base import (
-    AbstractSinglePointUtilityFunctionBuilder,
-    AbstractUtilityFunctionBuilder,
-    SinglePointUtilityFunction,
-    UtilityFunction,
-)
-from gpjax.decision_making.utility_functions.expected_improvement import (
-    ExpectedImprovement,
-)
-from gpjax.decision_making.utility_functions.probability_of_improvement import (
-    ProbabilityOfImprovement,
-)
-from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling
-
-__all__ = [
-    "UtilityFunction",
-    "AbstractUtilityFunctionBuilder",
-    "AbstractSinglePointUtilityFunctionBuilder",
-    "ExpectedImprovement",
-    "SinglePointUtilityFunction",
-    "ThompsonSampling",
-    "ProbabilityOfImprovement",
-]
````

gpjax/decision_making/utility_functions/base.py (removed; 106 lines)

````diff
@@ -1,106 +0,0 @@
-# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from abc import (
-    ABC,
-    abstractmethod,
-)
-from dataclasses import dataclass
-
-from beartype.typing import (
-    Callable,
-    Mapping,
-)
-
-from gpjax.dataset import Dataset
-from gpjax.decision_making.utils import OBJECTIVE
-from gpjax.gps import AbstractPosterior
-from gpjax.typing import (
-    Array,
-    Float,
-    KeyArray,
-)
-
-SinglePointUtilityFunction = Callable[[Float[Array, "N D"]], Float[Array, "N 1"]]
-"""
-Type alias for utility functions which don't support batching, and instead characterise
-the utility of querying a single point, rather than a batch of points. They take an array of points of shape $[N, D]$
-and return the value of the utility function at each point in an array of shape $[N, 1]$.
-"""
-
-
-UtilityFunction = SinglePointUtilityFunction
-"""
-Type alias for all utility functions. Currently we only support
-`SinglePointUtilityFunction`s, but in future may support batched utility functions too.
-Note that `UtilityFunction`s are *maximised* in order to decide which point, or batch of points, to query next.
-"""
-
-
-@dataclass
-class AbstractSinglePointUtilityFunctionBuilder(ABC):
-    """
-    Abstract class for building utility functions which don't support batches. As such,
-    they characterise the utility of querying a single point next.
-    """
-
-    def check_objective_present(
-        self,
-        posteriors: Mapping[str, AbstractPosterior],
-        datasets: Mapping[str, Dataset],
-    ) -> None:
-        """
-        Check that the objective posterior and dataset are present in the posteriors and
-        datasets.
-
-        Args:
-            posteriors: dictionary of posteriors to be used to form the utility function.
-            datasets: dictionary of datasets which may be used to form the utility function.
-
-        Raises:
-            ValueError: If the objective posterior or dataset are not present in the
-                posteriors or datasets.
-        """
-        if OBJECTIVE not in posteriors.keys():
-            raise ValueError("Objective posterior not found in posteriors")
-        elif OBJECTIVE not in datasets.keys():
-            raise ValueError("Objective dataset not found in datasets")
-
-    @abstractmethod
-    def build_utility_function(
-        self,
-        posteriors: Mapping[str, AbstractPosterior],
-        datasets: Mapping[str, Dataset],
-        key: KeyArray,
-    ) -> SinglePointUtilityFunction:
-        """
-        Build a `UtilityFunction` from a set of posteriors and datasets.
-
-        Args:
-            posteriors: dictionary of posteriors to be used to form the utility function.
-            datasets: dictionary of datasets which may be used to form the utility function.
-            key: JAX PRNG key used for random number generation.
-
-        Returns:
-            SinglePointUtilityFunction: Utility function to be *maximised* in order to
-                decide which point to query next.
-        """
-        raise NotImplementedError
-
-
-AbstractUtilityFunctionBuilder = AbstractSinglePointUtilityFunctionBuilder
-"""
-Type alias for utility function builders. For now this only includes single point utility
-function builders, but in the future we may support batched utility function builders.
-"""
````
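
To make the contract concrete, here is a hedged sketch of a custom builder against this ABC. `PosteriorMean` is a hypothetical example, not a GPJax class; it follows the same pattern as the concrete builders in the next two hunks, and `posterior.predict(x, dataset).mean()` is used exactly as in the `ProbabilityOfImprovement` code below:

```python
from dataclasses import dataclass

from beartype.typing import Mapping

from gpjax.dataset import Dataset
from gpjax.decision_making.utility_functions.base import (
    AbstractSinglePointUtilityFunctionBuilder,
    SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.gps import AbstractPosterior
from gpjax.typing import KeyArray


@dataclass
class PosteriorMean(AbstractSinglePointUtilityFunctionBuilder):
    # Hypothetical builder: scores a point by its negated posterior mean, so
    # maximising the utility drives queries towards the posterior minimiser.
    def build_utility_function(
        self,
        posteriors: Mapping[str, AbstractPosterior],
        datasets: Mapping[str, Dataset],
        key: KeyArray,  # unused: this utility is deterministic
    ) -> SinglePointUtilityFunction:
        self.check_objective_present(posteriors, datasets)  # inherited guard
        posterior = posteriors[OBJECTIVE]
        dataset = datasets[OBJECTIVE]
        return lambda x: -posterior.predict(x, dataset).mean().reshape(-1, 1)
```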

gpjax/decision_making/utility_functions/expected_improvement.py (removed; 112 lines)

````diff
@@ -1,112 +0,0 @@
-# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from dataclasses import dataclass
-from functools import partial
-
-from beartype.typing import Mapping
-import jax.numpy as jnp
-import tensorflow_probability.substrates.jax as tfp
-
-from gpjax.dataset import Dataset
-from gpjax.decision_making.utility_functions.base import (
-    AbstractSinglePointUtilityFunctionBuilder,
-    SinglePointUtilityFunction,
-)
-from gpjax.decision_making.utils import (
-    OBJECTIVE,
-    get_best_latent_observation_val,
-)
-from gpjax.gps import ConjugatePosterior
-from gpjax.typing import (
-    Array,
-    Float,
-    KeyArray,
-)
-
-
-@dataclass
-class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
-    """
-    Expected Improvement acquisition function as introduced by [Močkus,
-    1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55). The "best"
-    incumbent value is defined as the lowest posterior mean value evaluated at the
-    previously observed points. This enables the acquisition function to be utilised with noisy observations.
-    """
-
-    def build_utility_function(
-        self,
-        posteriors: Mapping[str, ConjugatePosterior],
-        datasets: Mapping[str, Dataset],
-        key: KeyArray,
-    ) -> SinglePointUtilityFunction:
-        r"""
-        Build the Expected Improvement acquisition function. This computes the expected
-        improvement over the "best" of the previously observed points, utilising the
-        posterior distribution of the surrogate model. For posterior distribution
-        $`f(\cdot)`$, and best incumbent value $`\eta`$, this is defined
-        as:
-        ```math
-        \alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
-        ```
-
-        Args:
-            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
-                be used to form the utility function. One of the posteriors must
-                correspond to the `OBJECTIVE` key, as we utilise the objective posterior
-                to form the utility function.
-            datasets (Mapping[str, Dataset]): Dictionary of datasets used to form the
-                utility function. Keys in `datasets` should correspond to keys in
-                `posteriors`. One of the datasets must correspond to the `OBJECTIVE` key.
-            key (KeyArray): JAX PRNG key used for random number generation.
-
-        Returns:
-            SinglePointUtilityFunction: The Expected Improvement acquisition function
-                to be *maximised* in order to decide which point to query next.
-        """
-        self.check_objective_present(posteriors, datasets)
-        objective_posterior = posteriors[OBJECTIVE]
-        objective_dataset = datasets[OBJECTIVE]
-
-        if not isinstance(objective_posterior, ConjugatePosterior):
-            raise ValueError(
-                "Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
-            )
-
-        if (
-            objective_dataset.X is None
-            or objective_dataset.n == 0
-            or objective_dataset.y is None
-        ):
-            raise ValueError("Objective dataset must contain at least one item")
-
-        eta = get_best_latent_observation_val(objective_posterior, objective_dataset)
-        return partial(
-            _expected_improvement, objective_posterior, objective_dataset, eta
-        )
-
-
-def _expected_improvement(
-    objective_posterior: ConjugatePosterior,
-    objective_dataset: Dataset,
-    eta: Float[Array, ""],
-    x: Float[Array, "N D"],
-) -> Float[Array, "N 1"]:
-    latent_dist = objective_posterior(x, objective_dataset)
-    mean = latent_dist.mean()
-    var = latent_dist.variance()
-    normal = tfp.distributions.Normal(mean, jnp.sqrt(var))
-    return jnp.expand_dims(
-        ((eta - mean) * normal.cdf(eta) + var * normal.prob(eta)), -1
-    )
````
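
For reference, the closed form that `_expected_improvement` evaluates. Writing $`\mu`$ and $`\sigma^2`$ for the posterior mean and variance at $`\mathbf{x}`$ and $`z = (\eta - \mu)/\sigma`$, the expectation in the docstring above reduces to the standard identity:

```math
\alpha_{\text{EI}}(\mathbf{x}) = (\eta - \mu)\,\Phi(z) + \sigma\,\varphi(z)
```

The code matches this term for term: `normal` is `Normal(mean, sqrt(var))`, so `normal.cdf(eta)` equals $`\Phi(z)`$ and `normal.prob(eta)` equals $`\varphi(z)/\sigma`$; multiplying the latter by `var` ($`\sigma^2`$) recovers the $`\sigma\,\varphi(z)`$ term.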

gpjax/decision_making/utility_functions/probability_of_improvement.py (removed; 125 lines)

````diff
@@ -1,125 +0,0 @@
-# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from dataclasses import dataclass
-
-from beartype.typing import Mapping
-from jaxtyping import Num
-import tensorflow_probability.substrates.jax as tfp
-
-from gpjax.dataset import Dataset
-from gpjax.decision_making.utility_functions.base import (
-    AbstractSinglePointUtilityFunctionBuilder,
-    SinglePointUtilityFunction,
-)
-from gpjax.decision_making.utils import (
-    OBJECTIVE,
-    get_best_latent_observation_val,
-)
-from gpjax.gps import ConjugatePosterior
-from gpjax.typing import (
-    Array,
-    KeyArray,
-)
-
-
-@dataclass
-class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder):
-    r"""
-    An acquisition function which returns the probability of improvement
-    of the objective function over the best observed value.
-
-    More precisely, given a predictive posterior distribution of the objective
-    function $`f`$, the probability of improvement at a test point $`x`$ is defined as:
-    $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$
-    where $`x_{\text{best}}`$ is the minimiser of the posterior mean
-    at previously observed values (to handle noisy observations).
-
-    The probability of improvement can be easily computed using the
-    cumulative distribution function of the standard normal distribution $`\Phi`$:
-    $$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$
-    where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the
-    predictive distribution of the objective function at $`x`$.
-
-    References
-    ----------
-    [1] Kushner, H. J. (1964).
-    A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise.
-    Journal of Basic Engineering, 86(1), 97-106.
-
-    [2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016).
-    Taking the human out of the loop: A review of Bayesian optimization.
-    Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218
-    """
-
-    def build_utility_function(
-        self,
-        posteriors: Mapping[str, ConjugatePosterior],
-        datasets: Mapping[str, Dataset],
-        key: KeyArray,
-    ) -> SinglePointUtilityFunction:
-        """
-        Constructs the probability of improvement utility function
-        using the predictive posterior of the objective function.
-
-        Args:
-            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to be
-                used to form the utility function. One of the posteriors must correspond
-                to the `OBJECTIVE` key, as we use the objective posterior to form
-                the utility function.
-            datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
-                to form the utility function. Keys in `datasets` should correspond to
-                keys in `posteriors`. One of the datasets must correspond
-                to the `OBJECTIVE` key.
-            key (KeyArray): JAX PRNG key used for random number generation. Since
-                the probability of improvement is computed deterministically
-                from the predictive posterior, the key is not used.
-
-        Returns:
-            SinglePointUtilityFunction: the probability of improvement utility function.
-        """
-        self.check_objective_present(posteriors, datasets)
-
-        objective_posterior = posteriors[OBJECTIVE]
-        if not isinstance(objective_posterior, ConjugatePosterior):
-            raise ValueError(
-                "Objective posterior must be a ConjugatePosterior to compute the Probability of Improvement using a Gaussian CDF."
-            )
-
-        objective_dataset = datasets[OBJECTIVE]
-        if (
-            objective_dataset.X is None
-            or objective_dataset.n == 0
-            or objective_dataset.y is None
-        ):
-            raise ValueError(
-                "Objective dataset must be non-empty to compute the "
-                "Probability of Improvement (since we need a "
-                "`best_y` value)."
-            )
-
-        def probability_of_improvement(x_test: Num[Array, "N D"]):
-            best_y = get_best_latent_observation_val(
-                objective_posterior, objective_dataset
-            )
-            predictive_dist = objective_posterior.predict(x_test, objective_dataset)
-
-            normal_dist = tfp.distributions.Normal(
-                loc=predictive_dist.mean(),
-                scale=predictive_dist.stddev(),
-            )
-
-            return normal_dist.cdf(best_y).reshape(-1, 1)
-
-        return probability_of_improvement
````
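
Finally, a hedged end-to-end sketch of how these builders were consumed under 0.9.3. Here `posterior`, `dataset`, and `candidates` are placeholders for a fitted `ConjugatePosterior`, its training `Dataset`, and an `[N, D]` array of query points:

```python
import jax.random as jr

from gpjax.decision_making.utility_functions import (
    ExpectedImprovement,
    ProbabilityOfImprovement,
)
from gpjax.decision_making.utils import OBJECTIVE

# posterior, dataset, candidates: placeholders, assumed to be defined elsewhere.
for builder in (ExpectedImprovement(), ProbabilityOfImprovement()):
    utility_fn = builder.build_utility_function(
        posteriors={OBJECTIVE: posterior},
        datasets={OBJECTIVE: dataset},
        key=jr.key(0),  # accepted by the API but unused by these two builders
    )
    scores = utility_fn(candidates)  # shape [N, 1]; the argmax is the next query
```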