gpjax 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,169 +0,0 @@
1
- # Copyright 2023 The GPJax Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from abc import abstractmethod
16
- from dataclasses import dataclass
17
-
18
- import jax.numpy as jnp
19
- from jaxtyping import (
20
- Array,
21
- Float,
22
- Num,
23
- )
24
- import tensorflow_probability.substrates.jax as tfp
25
-
26
- from gpjax.dataset import Dataset
27
- from gpjax.decision_making.search_space import ContinuousSearchSpace
28
- from gpjax.gps import AbstractMeanFunction
29
- from gpjax.typing import KeyArray
30
-
31
-
32
- class AbstractContinuousTestFunction(AbstractMeanFunction):
33
- """
34
- Abstract base class for continuous test functions.
35
-
36
- Attributes:
37
- search_space (ContinuousSearchSpace): Search space for the function.
38
- minimizer (Float[Array, '1 D']): Minimizer of the function (to 5 decimal places)
39
- minimum (Float[Array, '1 1']): Minimum of the function (to 5 decimal places).
40
- """
41
-
42
- search_space: ContinuousSearchSpace
43
- minimizer: Float[Array, "1 D"]
44
- minimum: Float[Array, "1 1"]
45
-
46
- def generate_dataset(
47
- self, num_points: int, key: KeyArray, obs_stddev: float = 0.0
48
- ) -> Dataset:
49
- """
50
- Generate a toy dataset from the test function.
51
-
52
- Args:
53
- num_points (int): Number of points to sample.
54
- key (KeyArray): JAX PRNG key.
55
- obs_stddev (float): (Optional) standard deviation of Gaussian distributed
56
- noise added to observations.
57
-
58
- Returns:
59
- Dataset: Dataset of points sampled from the test function.
60
- """
61
- X = self.search_space.sample(num_points=num_points, key=key)
62
- gaussian_noise = tfp.distributions.Normal(
63
- jnp.zeros(num_points), obs_stddev * jnp.ones(num_points)
64
- )
65
- y = self.evaluate(X) + jnp.transpose(
66
- gaussian_noise.sample(sample_shape=[1], seed=key)
67
- )
68
- return Dataset(X=X, y=y)
69
-
70
- def generate_test_points(
71
- self, num_points: int, key: KeyArray
72
- ) -> Float[Array, "N D"]:
73
- """
74
- Generate test points from the search space of the test function.
75
-
76
- Args:
77
- num_points (int): Number of points to sample.
78
- key (KeyArray): JAX PRNG key.
79
-
80
- Returns:
81
- Float[Array, 'N D']: Test points sampled from the search space.
82
- """
83
- return self.search_space.sample(num_points=num_points, key=key)
84
-
85
- def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
86
- return self.evaluate(x)
87
-
88
- @abstractmethod
89
- def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
90
- """
91
- Evaluate the test function at a set of points.
92
-
93
- Args:
94
- x (Float[Array, 'N D']): Points to evaluate the test function at.
95
-
96
- Returns:
97
- Float[Array, 'N 1']: Values of the test function at the points.
98
- """
99
- raise NotImplementedError
100
-
101
-
102
- @dataclass
103
- class Forrester(AbstractContinuousTestFunction):
104
- """
105
- Forrester function introduced in 'Engineering design via surrogate modelling: a
106
- practical guide' (Forrester et al. 2008), rescaled to have zero mean and unit
107
- variance over $[0, 1]$.
108
- """
109
-
110
- search_space = ContinuousSearchSpace(
111
- lower_bounds=jnp.array([0.0]),
112
- upper_bounds=jnp.array([1.0]),
113
- )
114
- minimizer = jnp.array([[0.75725]])
115
- minimum = jnp.array([[-1.45280]])
116
-
117
- def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
118
- mean = 0.45321
119
- std = jnp.sqrt(19.8577)
120
- return (((6 * x - 2) ** 2) * jnp.sin(12 * x - 4) - mean) / std
121
-
122
-
123
- @dataclass
124
- class LogarithmicGoldsteinPrice(AbstractContinuousTestFunction):
125
- """
126
- Logarithmic Goldstein-Price function introduced in 'A benchmark of kriging-based
127
- infill criteria for noisy optimization' (Picheny et al. 2013), which has zero mean
128
- and unit variance over $[0, 1]^2$.
129
- """
130
-
131
- search_space = ContinuousSearchSpace(
132
- lower_bounds=jnp.array([0.0, 0.0]),
133
- upper_bounds=jnp.array([1.0, 1.0]),
134
- )
135
- minimizer = jnp.array([[0.5, 0.25]])
136
- minimum = jnp.array([[-3.12913]])
137
-
138
- def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
139
- x1 = 4.0 * x[:, 0] - 2.0
140
- x2 = 4.0 * x[:, 1] - 2.0
141
- a = 1.0 + (x1 + x2 + 1.0) ** 2 * (
142
- 19.0 - 14.0 * x1 + 3.0 * (x1**2) - 14.0 * x2 + 6.0 * x1 * x2 + 3.0 * (x2**2)
143
- )
144
- b = 30.0 + (2.0 * x1 - 3.0 * x2) ** 2 * (
145
- 18.0
146
- - 32.0 * x1
147
- + 12.0 * (x1**2)
148
- + 48.0 * x2
149
- - 36.0 * x1 * x2
150
- + 27.0 * (x2**2)
151
- )
152
- return ((jnp.log((a * b)) - 8.693) / 2.427).reshape(-1, 1)
153
-
154
-
155
- @dataclass
156
- class Quadratic(AbstractContinuousTestFunction):
157
- """
158
- Toy quadratic function defined over $[0, 1]$.
159
- """
160
-
161
- search_space = ContinuousSearchSpace(
162
- lower_bounds=jnp.array([0.0]),
163
- upper_bounds=jnp.array([1.0]),
164
- )
165
- minimizer = jnp.array([[0.5]])
166
- minimum = jnp.array([[0.0]])
167
-
168
- def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
169
- return (x - 0.5) ** 2
@@ -1,90 +0,0 @@
1
- # Copyright 2023 The GPJax Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from abc import abstractmethod
16
- from dataclasses import dataclass
17
-
18
- import jax.numpy as jnp
19
- import jax.random as jr
20
-
21
- from gpjax.dataset import Dataset
22
- from gpjax.decision_making.search_space import ContinuousSearchSpace
23
- from gpjax.typing import (
24
- Array,
25
- Float,
26
- Int,
27
- KeyArray,
28
- )
29
-
30
-
31
- @dataclass
32
- class PoissonTestFunction:
33
- """
34
- Test function for GPs utilising the Poisson likelihood. Function taken from
35
- https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
36
-
37
- Attributes:
38
- search_space (ContinuousSearchSpace): Search space for the function.
39
- """
40
-
41
- search_space = ContinuousSearchSpace(
42
- lower_bounds=jnp.array([-2.0]),
43
- upper_bounds=jnp.array([2.0]),
44
- )
45
-
46
- def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
47
- """
48
- Generate a toy dataset from the test function.
49
-
50
- Args:
51
- num_points (int): Number of points to sample.
52
- key (KeyArray): JAX PRNG key.
53
-
54
- Returns:
55
- Dataset: Dataset of points sampled from the test function.
56
- """
57
- X = self.search_space.sample(num_points=num_points, key=key)
58
- y = self.evaluate(X)
59
- return Dataset(X=X, y=y)
60
-
61
- def generate_test_points(
62
- self, num_points: int, key: KeyArray
63
- ) -> Float[Array, "N D"]:
64
- """
65
- Generate test points from the search space of the test function.
66
-
67
- Args:
68
- num_points (int): Number of points to sample.
69
- key (KeyArray): JAX PRNG key.
70
-
71
- Returns:
72
- Float[Array, 'N D']: Test points sampled from the search space.
73
- """
74
- return self.search_space.sample(num_points=num_points, key=key)
75
-
76
- @abstractmethod
77
- def evaluate(self, x: Float[Array, "N 1"]) -> Int[Array, "N 1"]:
78
- """
79
- Evaluate the test function at a set of points. Function taken from
80
- https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
81
-
82
- Args:
83
- x (Float[Array, 'N D']): Points to evaluate the test function at.
84
-
85
- Returns:
86
- Float[Array, 'N 1']: Values of the test function at the points.
87
- """
88
- key = jr.key(42)
89
- f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
90
- return jr.poisson(key, jnp.exp(f(x)))
@@ -1,37 +0,0 @@
1
- # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from gpjax.decision_making.utility_functions.base import (
16
- AbstractSinglePointUtilityFunctionBuilder,
17
- AbstractUtilityFunctionBuilder,
18
- SinglePointUtilityFunction,
19
- UtilityFunction,
20
- )
21
- from gpjax.decision_making.utility_functions.expected_improvement import (
22
- ExpectedImprovement,
23
- )
24
- from gpjax.decision_making.utility_functions.probability_of_improvement import (
25
- ProbabilityOfImprovement,
26
- )
27
- from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling
28
-
29
- __all__ = [
30
- "UtilityFunction",
31
- "AbstractUtilityFunctionBuilder",
32
- "AbstractSinglePointUtilityFunctionBuilder",
33
- "ExpectedImprovement",
34
- "SinglePointUtilityFunction",
35
- "ThompsonSampling",
36
- "ProbabilityOfImprovement",
37
- ]
@@ -1,106 +0,0 @@
1
- # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from abc import (
16
- ABC,
17
- abstractmethod,
18
- )
19
- from dataclasses import dataclass
20
-
21
- from beartype.typing import (
22
- Callable,
23
- Mapping,
24
- )
25
-
26
- from gpjax.dataset import Dataset
27
- from gpjax.decision_making.utils import OBJECTIVE
28
- from gpjax.gps import AbstractPosterior
29
- from gpjax.typing import (
30
- Array,
31
- Float,
32
- KeyArray,
33
- )
34
-
35
- SinglePointUtilityFunction = Callable[[Float[Array, "N D"]], Float[Array, "N 1"]]
36
- """
37
- Type alias for utility functions which don't support batching, and instead characterise
38
- the utility of querying a single point, rather than a batch of points. They take an array of points of shape $[N, D]$
39
- and return the value of the utility function at each point in an array of shape $[N, 1]$.
40
- """
41
-
42
-
43
- UtilityFunction = SinglePointUtilityFunction
44
- """
45
- Type alias for all utility functions. Currently we only support
46
- `SinglePointUtilityFunction`s, but in future may support batched utility functions too.
47
- Note that `UtilityFunction`s are *maximised* in order to decide which point, or batch of points, to query next.
48
- """
49
-
50
-
51
- @dataclass
52
- class AbstractSinglePointUtilityFunctionBuilder(ABC):
53
- """
54
- Abstract class for building utility functions which don't support batches. As such,
55
- they characterise the utility of querying a single point next.
56
- """
57
-
58
- def check_objective_present(
59
- self,
60
- posteriors: Mapping[str, AbstractPosterior],
61
- datasets: Mapping[str, Dataset],
62
- ) -> None:
63
- """
64
- Check that the objective posterior and dataset are present in the posteriors and
65
- datasets.
66
-
67
- Args:
68
- posteriors: dictionary of posteriors to be used to form the utility function.
69
- datasets: dictionary of datasets which may be used to form the utility function.
70
-
71
- Raises:
72
- ValueError: If the objective posterior or dataset are not present in the
73
- posteriors or datasets.
74
- """
75
- if OBJECTIVE not in posteriors.keys():
76
- raise ValueError("Objective posterior not found in posteriors")
77
- elif OBJECTIVE not in datasets.keys():
78
- raise ValueError("Objective dataset not found in datasets")
79
-
80
- @abstractmethod
81
- def build_utility_function(
82
- self,
83
- posteriors: Mapping[str, AbstractPosterior],
84
- datasets: Mapping[str, Dataset],
85
- key: KeyArray,
86
- ) -> SinglePointUtilityFunction:
87
- """
88
- Build a `UtilityFunction` from a set of posteriors and datasets.
89
-
90
- Args:
91
- posteriors: dictionary of posteriors to be used to form the utility function.
92
- datasets: dictionary of datasets which may be used to form the utility function.
93
- key: JAX PRNG key used for random number generation.
94
-
95
- Returns:
96
- SinglePointUtilityFunction: Utility function to be *maximised* in order to
97
- decide which point to query next.
98
- """
99
- raise NotImplementedError
100
-
101
-
102
- AbstractUtilityFunctionBuilder = AbstractSinglePointUtilityFunctionBuilder
103
- """
104
- Type alias for utility function builders. For now this only include single point utility
105
- function builders, but in the future we may support batched utility function builders.
106
- """
@@ -1,112 +0,0 @@
1
- # Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from dataclasses import dataclass
16
- from functools import partial
17
-
18
- from beartype.typing import Mapping
19
- import jax.numpy as jnp
20
- import tensorflow_probability.substrates.jax as tfp
21
-
22
- from gpjax.dataset import Dataset
23
- from gpjax.decision_making.utility_functions.base import (
24
- AbstractSinglePointUtilityFunctionBuilder,
25
- SinglePointUtilityFunction,
26
- )
27
- from gpjax.decision_making.utils import (
28
- OBJECTIVE,
29
- get_best_latent_observation_val,
30
- )
31
- from gpjax.gps import ConjugatePosterior
32
- from gpjax.typing import (
33
- Array,
34
- Float,
35
- KeyArray,
36
- )
37
-
38
-
39
- @dataclass
40
- class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
41
- """
42
- Expected Improvement acquisition function as introduced by [Močkus,
43
- 1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55). The "best"
44
- incumbent value is defined as the lowest posterior mean value evaluated at the the
45
- previously observed points. This enables the acquisition function to be utilised with noisy observations.
46
- """
47
-
48
- def build_utility_function(
49
- self,
50
- posteriors: Mapping[str, ConjugatePosterior],
51
- datasets: Mapping[str, Dataset],
52
- key: KeyArray,
53
- ) -> SinglePointUtilityFunction:
54
- r"""
55
- Build the Expected Improvement acquisition function. This computes the expected
56
- improvement over the "best" of the previously observed points, utilising the
57
- posterior distribution of the surrogate model. For posterior distribution
58
- $`f(\cdot)`$, and best incumbent value $`\eta`$, this is defined
59
- as:
60
- ```math
61
- \alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
62
- ```
63
-
64
- Args:
65
- posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
66
- used to form the utility function. One posteriors must correspond to the
67
- `OBJECTIVE` key, as we utilise the objective posterior to form the utility
68
- function.
69
- datasets (Mapping[str, Dataset]): Dictionary of datasets used to form the
70
- utility function. Keys in `datasets` should correspond to keys in
71
- `posteriors`. One of the datasets must correspond to the `OBJECTIVE` key.
72
- key (KeyArray): JAX PRNG key used for random number generation.
73
-
74
- Returns:
75
- SinglePointUtilityFunction: The Expected Improvement acquisition function to
76
- to be *maximised* in order to decide which point to query next.
77
- """
78
- self.check_objective_present(posteriors, datasets)
79
- objective_posterior = posteriors[OBJECTIVE]
80
- objective_dataset = datasets[OBJECTIVE]
81
-
82
- if not isinstance(objective_posterior, ConjugatePosterior):
83
- raise ValueError(
84
- "Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
85
- )
86
-
87
- if (
88
- objective_dataset.X is None
89
- or objective_dataset.n == 0
90
- or objective_dataset.y is None
91
- ):
92
- raise ValueError("Objective dataset must contain at least one item")
93
-
94
- eta = get_best_latent_observation_val(objective_posterior, objective_dataset)
95
- return partial(
96
- _expected_improvement, objective_posterior, objective_dataset, eta
97
- )
98
-
99
-
100
- def _expected_improvement(
101
- objective_posterior: ConjugatePosterior,
102
- objective_dataset: Dataset,
103
- eta: Float[Array, ""],
104
- x: Float[Array, "N D"],
105
- ) -> Float[Array, "N 1"]:
106
- latent_dist = objective_posterior(x, objective_dataset)
107
- mean = latent_dist.mean()
108
- var = latent_dist.variance()
109
- normal = tfp.distributions.Normal(mean, jnp.sqrt(var))
110
- return jnp.expand_dims(
111
- ((eta - mean) * normal.cdf(eta) + var * normal.prob(eta)), -1
112
- )
@@ -1,125 +0,0 @@
1
- # Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from dataclasses import dataclass
16
-
17
- from beartype.typing import Mapping
18
- from jaxtyping import Num
19
- import tensorflow_probability.substrates.jax as tfp
20
-
21
- from gpjax.dataset import Dataset
22
- from gpjax.decision_making.utility_functions.base import (
23
- AbstractSinglePointUtilityFunctionBuilder,
24
- SinglePointUtilityFunction,
25
- )
26
- from gpjax.decision_making.utils import (
27
- OBJECTIVE,
28
- get_best_latent_observation_val,
29
- )
30
- from gpjax.gps import ConjugatePosterior
31
- from gpjax.typing import (
32
- Array,
33
- KeyArray,
34
- )
35
-
36
-
37
- @dataclass
38
- class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder):
39
- r"""
40
- An acquisition function which returns the probability of improvement
41
- of the objective function over the best observed value.
42
-
43
- More precisely, given a predictive posterior distribution of the objective
44
- function $`f`$, the probability of improvement at a test point $`x`$ is defined as:
45
- $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$
46
- where $`x_{\text{best}}`$ is the minimiser of the posterior mean
47
- at previously observed values (to handle noisy observations).
48
-
49
- The probability of improvement can be easily computed using the
50
- cumulative distribution function of the standard normal distribution $`\Phi`$:
51
- $$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$
52
- where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the
53
- predictive distribution of the objective function at $`x`$.
54
-
55
- References
56
- ----------
57
- [1] Kushner, H. J. (1964).
58
- A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise.
59
- Journal of Basic Engineering, 86(1), 97-106.
60
-
61
- [2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016).
62
- Taking the human out of the loop: A review of Bayesian optimization.
63
- Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218
64
- """
65
-
66
- def build_utility_function(
67
- self,
68
- posteriors: Mapping[str, ConjugatePosterior],
69
- datasets: Mapping[str, Dataset],
70
- key: KeyArray,
71
- ) -> SinglePointUtilityFunction:
72
- """
73
- Constructs the probability of improvement utility function
74
- using the predictive posterior of the objective function.
75
-
76
- Args:
77
- posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be
78
- used to form the utility function. One of the posteriors must correspond
79
- to the `OBJECTIVE` key, as we sample from the objective posterior to form
80
- the utility function.
81
- datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
82
- to form the utility function. Keys in `datasets` should correspond to
83
- keys in `posteriors`. One of the datasets must correspond
84
- to the `OBJECTIVE` key.
85
- key (KeyArray): JAX PRNG key used for random number generation. Since
86
- the probability of improvement is computed deterministically
87
- from the predictive posterior, the key is not used.
88
-
89
- Returns:
90
- SinglePointUtilityFunction: the probability of improvement utility function.
91
- """
92
- self.check_objective_present(posteriors, datasets)
93
-
94
- objective_posterior = posteriors[OBJECTIVE]
95
- if not isinstance(objective_posterior, ConjugatePosterior):
96
- raise ValueError(
97
- "Objective posterior must be a ConjugatePosterior to compute the Probability of Improvement using a Gaussian CDF."
98
- )
99
-
100
- objective_dataset = datasets[OBJECTIVE]
101
- if (
102
- objective_dataset.X is None
103
- or objective_dataset.n == 0
104
- or objective_dataset.y is None
105
- ):
106
- raise ValueError(
107
- "Objective dataset must be non-empty to compute the "
108
- "Probability of Improvement (since we need a "
109
- "`best_y` value)."
110
- )
111
-
112
- def probability_of_improvement(x_test: Num[Array, "N D"]):
113
- best_y = get_best_latent_observation_val(
114
- objective_posterior, objective_dataset
115
- )
116
- predictive_dist = objective_posterior.predict(x_test, objective_dataset)
117
-
118
- normal_dist = tfp.distributions.Normal(
119
- loc=predictive_dist.mean(),
120
- scale=predictive_dist.stddev(),
121
- )
122
-
123
- return normal_dist.cdf(best_y).reshape(-1, 1)
124
-
125
- return probability_of_improvement