gpjax 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,112 +0,0 @@
1
- # Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from dataclasses import dataclass
16
- from functools import partial
17
-
18
- from beartype.typing import Mapping
19
- import jax.numpy as jnp
20
- import tensorflow_probability.substrates.jax as tfp
21
-
22
- from gpjax.dataset import Dataset
23
- from gpjax.decision_making.utility_functions.base import (
24
- AbstractSinglePointUtilityFunctionBuilder,
25
- SinglePointUtilityFunction,
26
- )
27
- from gpjax.decision_making.utils import (
28
- OBJECTIVE,
29
- get_best_latent_observation_val,
30
- )
31
- from gpjax.gps import ConjugatePosterior
32
- from gpjax.typing import (
33
- Array,
34
- Float,
35
- KeyArray,
36
- )
37
-
38
-
39
- @dataclass
40
- class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
41
- """
42
- Expected Improvement acquisition function as introduced by [Močkus,
43
- 1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55). The "best"
44
- incumbent value is defined as the lowest posterior mean value evaluated at the
45
- previously observed points. This enables the acquisition function to be utilised with noisy observations.
46
- """
47
-
48
- def build_utility_function(
49
- self,
50
- posteriors: Mapping[str, ConjugatePosterior],
51
- datasets: Mapping[str, Dataset],
52
- key: KeyArray,
53
- ) -> SinglePointUtilityFunction:
54
- r"""
55
- Build the Expected Improvement acquisition function. This computes the expected
56
- improvement over the "best" of the previously observed points, utilising the
57
- posterior distribution of the surrogate model. For posterior distribution
58
- $`f(\cdot)`$, and best incumbent value $`\eta`$, this is defined
59
- as:
60
- ```math
61
- \alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
62
- ```
63
-
64
- Args:
65
- posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to be
66
- used to form the utility function. One of the posteriors must correspond to the
67
- `OBJECTIVE` key, as we utilise the objective posterior to form the utility
68
- function.
69
- datasets (Mapping[str, Dataset]): Dictionary of datasets used to form the
70
- utility function. Keys in `datasets` should correspond to keys in
71
- `posteriors`. One of the datasets must correspond to the `OBJECTIVE` key.
72
- key (KeyArray): JAX PRNG key used for random number generation.
73
-
74
- Returns:
75
- SinglePointUtilityFunction: The Expected Improvement acquisition function to
76
- be *maximised* in order to decide which point to query next.
77
- """
78
- self.check_objective_present(posteriors, datasets)
79
- objective_posterior = posteriors[OBJECTIVE]
80
- objective_dataset = datasets[OBJECTIVE]
81
-
82
- if not isinstance(objective_posterior, ConjugatePosterior):
83
- raise ValueError(
84
- "Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
85
- )
86
-
87
- if (
88
- objective_dataset.X is None
89
- or objective_dataset.n == 0
90
- or objective_dataset.y is None
91
- ):
92
- raise ValueError("Objective dataset must contain at least one item")
93
-
94
- eta = get_best_latent_observation_val(objective_posterior, objective_dataset)
95
- return partial(
96
- _expected_improvement, objective_posterior, objective_dataset, eta
97
- )
98
-
99
-
100
- def _expected_improvement(
101
- objective_posterior: ConjugatePosterior,
102
- objective_dataset: Dataset,
103
- eta: Float[Array, ""],
104
- x: Float[Array, "N D"],
105
- ) -> Float[Array, "N 1"]:
106
- latent_dist = objective_posterior(x, objective_dataset)
107
- mean = latent_dist.mean()
108
- var = latent_dist.variance()
109
- normal = tfp.distributions.Normal(mean, jnp.sqrt(var))
110
- return jnp.expand_dims(
111
- ((eta - mean) * normal.cdf(eta) + var * normal.prob(eta)), -1
112
- )
@@ -1,125 +0,0 @@
1
- # Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from dataclasses import dataclass
16
-
17
- from beartype.typing import Mapping
18
- from jaxtyping import Num
19
- import tensorflow_probability.substrates.jax as tfp
20
-
21
- from gpjax.dataset import Dataset
22
- from gpjax.decision_making.utility_functions.base import (
23
- AbstractSinglePointUtilityFunctionBuilder,
24
- SinglePointUtilityFunction,
25
- )
26
- from gpjax.decision_making.utils import (
27
- OBJECTIVE,
28
- get_best_latent_observation_val,
29
- )
30
- from gpjax.gps import ConjugatePosterior
31
- from gpjax.typing import (
32
- Array,
33
- KeyArray,
34
- )
35
-
36
-
37
- @dataclass
38
- class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder):
39
- r"""
40
- An acquisition function which returns the probability of improvement
41
- of the objective function over the best observed value.
42
-
43
- More precisely, given a predictive posterior distribution of the objective
44
- function $`f`$, the probability of improvement at a test point $`x`$ is defined as:
45
- $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$
46
- where $`x_{\text{best}}`$ is the minimiser of the posterior mean
47
- at previously observed values (to handle noisy observations).
48
-
49
- The probability of improvement can be easily computed using the
50
- cumulative distribution function of the standard normal distribution $`\Phi`$:
51
- $$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$
52
- where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the
53
- predictive distribution of the objective function at $`x`$.
54
-
55
- References
56
- ----------
57
- [1] Kushner, H. J. (1964).
58
- A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise.
59
- Journal of Basic Engineering, 86(1), 97-106.
60
-
61
- [2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016).
62
- Taking the human out of the loop: A review of Bayesian optimization.
63
- Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218
64
- """
65
-
66
- def build_utility_function(
67
- self,
68
- posteriors: Mapping[str, ConjugatePosterior],
69
- datasets: Mapping[str, Dataset],
70
- key: KeyArray,
71
- ) -> SinglePointUtilityFunction:
72
- """
73
- Constructs the probability of improvement utility function
74
- using the predictive posterior of the objective function.
75
-
76
- Args:
77
- posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to be
78
- used to form the utility function. One of the posteriors must correspond
79
- to the `OBJECTIVE` key, as we use the objective posterior's predictive distribution to form
80
- the utility function.
81
- datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
82
- to form the utility function. Keys in `datasets` should correspond to
83
- keys in `posteriors`. One of the datasets must correspond
84
- to the `OBJECTIVE` key.
85
- key (KeyArray): JAX PRNG key used for random number generation. Since
86
- the probability of improvement is computed deterministically
87
- from the predictive posterior, the key is not used.
88
-
89
- Returns:
90
- SinglePointUtilityFunction: the probability of improvement utility function.
91
- """
92
- self.check_objective_present(posteriors, datasets)
93
-
94
- objective_posterior = posteriors[OBJECTIVE]
95
- if not isinstance(objective_posterior, ConjugatePosterior):
96
- raise ValueError(
97
- "Objective posterior must be a ConjugatePosterior to compute the Probability of Improvement using a Gaussian CDF."
98
- )
99
-
100
- objective_dataset = datasets[OBJECTIVE]
101
- if (
102
- objective_dataset.X is None
103
- or objective_dataset.n == 0
104
- or objective_dataset.y is None
105
- ):
106
- raise ValueError(
107
- "Objective dataset must be non-empty to compute the "
108
- "Probability of Improvement (since we need a "
109
- "`best_y` value)."
110
- )
111
-
112
- def probability_of_improvement(x_test: Num[Array, "N D"]):
113
- best_y = get_best_latent_observation_val(
114
- objective_posterior, objective_dataset
115
- )
116
- predictive_dist = objective_posterior.predict(x_test, objective_dataset)
117
-
118
- normal_dist = tfp.distributions.Normal(
119
- loc=predictive_dist.mean(),
120
- scale=predictive_dist.stddev(),
121
- )
122
-
123
- return normal_dist.cdf(best_y).reshape(-1, 1)
124
-
125
- return probability_of_improvement
@@ -1,101 +0,0 @@
1
- # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from dataclasses import dataclass
16
-
17
- from beartype.typing import Mapping
18
-
19
- from gpjax.dataset import Dataset
20
- from gpjax.decision_making.utility_functions.base import (
21
- AbstractSinglePointUtilityFunctionBuilder,
22
- SinglePointUtilityFunction,
23
- )
24
- from gpjax.decision_making.utils import OBJECTIVE
25
- from gpjax.gps import ConjugatePosterior
26
- from gpjax.typing import KeyArray
27
-
28
-
29
- @dataclass
30
- class ThompsonSampling(AbstractSinglePointUtilityFunctionBuilder):
31
- """
32
- Form a utility function by drawing an approximate sample from the posterior,
33
- using decoupled sampling as introduced in [Wilson et al.
34
- (2020)](https://arxiv.org/abs/2002.09309). Note that we return the *negative* of the
35
- sample as the utility function, as utility functions are *maximised*.
36
-
37
- Note that this is a single batch utility function, as it doesn't support classical
38
- batching. However, Thompson sampling can be used in a batched setting by drawing a
39
- batch of different samples from the GP posterior. This can be done by calling
40
- `build_utility_function` with different keys, an example of which can be seen in the
41
- `ask` method of the `UtilityDrivenDecisionMaker` class. The samples can then be
42
- optimised sequentially.
43
-
44
- Attributes:
45
- num_features (int): The number of random Fourier features to use when drawing
46
- the approximate sample from the posterior. Defaults to 100.
47
- """
48
-
49
- num_features: int = 100
50
-
51
- def __post_init__(self):
52
- if self.num_features <= 0:
53
- raise ValueError(
54
- "The number of random Fourier features must be a positive integer."
55
- )
56
-
57
- def build_utility_function(
58
- self,
59
- posteriors: Mapping[str, ConjugatePosterior],
60
- datasets: Mapping[str, Dataset],
61
- key: KeyArray,
62
- ) -> SinglePointUtilityFunction:
63
- """
64
- Draw an approximate sample from the posterior of the objective model and return
65
- the *negative* of this sample as a utility function, as utility functions
66
- are *maximised*.
67
-
68
- Args:
69
- posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
70
- be used to form the utility function. One of the posteriors must correspond
71
- to the `OBJECTIVE` key, as we sample from the objective posterior to form
72
- the utility function.
73
- datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
74
- to form the utility function. Keys in `datasets` should correspond to
75
- keys in `posteriors`. One of the datasets must correspond
76
- to the `OBJECTIVE` key.
77
- key (KeyArray): JAX PRNG key used for random number generation. This can be
78
- changed to draw different samples.
79
-
80
- Returns:
81
- SinglePointUtilityFunction: An approximate sample from the objective model
82
- posterior to be *maximised* in order to decide which point to query
83
- next.
84
- """
85
- self.check_objective_present(posteriors, datasets)
86
-
87
- objective_posterior = posteriors[OBJECTIVE]
88
- if not isinstance(objective_posterior, ConjugatePosterior):
89
- raise ValueError(
90
- "Objective posterior must be a ConjugatePosterior to draw an approximate sample."
91
- )
92
-
93
- objective_dataset = datasets[OBJECTIVE]
94
- thompson_sample = objective_posterior.sample_approx(
95
- num_samples=1,
96
- train_data=objective_dataset,
97
- key=key,
98
- num_features=self.num_features,
99
- )
100
-
101
- return lambda x: -1.0 * thompson_sample(x) # Utility functions are *maximised*
@@ -1,157 +0,0 @@
1
- # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from abc import (
16
- ABC,
17
- abstractmethod,
18
- )
19
- from dataclasses import dataclass
20
-
21
- import jax.numpy as jnp
22
- import jax.random as jr
23
- from jaxopt import ScipyBoundedMinimize
24
-
25
- from gpjax.decision_making.search_space import (
26
- AbstractSearchSpace,
27
- ContinuousSearchSpace,
28
- )
29
- from gpjax.decision_making.utility_functions import SinglePointUtilityFunction
30
- from gpjax.typing import (
31
- Array,
32
- Float,
33
- KeyArray,
34
- ScalarFloat,
35
- )
36
-
37
-
38
- def _get_discrete_maximizer(
39
- query_points: Float[Array, "N D"], utility_function: SinglePointUtilityFunction
40
- ) -> Float[Array, "1 D"]:
41
- """Get the point which maximises the utility function evaluated at a given set of points.
42
-
43
- Args:
44
- query_points: set of points at which to evaluate the utility function, as an array
45
- of shape `[n_points, n_dims]`.
46
- utility_function: the single point utility function to be evaluated at `query_points`.
47
-
48
- Returns:
49
- Array of shape `[1, n_dims]` representing the point which maximises the utility function.
50
- """
51
- utility_function_values = utility_function(query_points)
52
- max_utility_function_value_idx = jnp.argmax(
53
- utility_function_values, axis=0, keepdims=True
54
- )
55
- best_sample_point = jnp.take_along_axis(
56
- query_points, max_utility_function_value_idx, axis=0
57
- )
58
- return best_sample_point
59
-
60
-
61
- @dataclass
62
- class AbstractSinglePointUtilityMaximizer(ABC):
63
- """Abstract base class for single point utility function maximizers."""
64
-
65
- @abstractmethod
66
- def maximize(
67
- self,
68
- utility_function: SinglePointUtilityFunction,
69
- search_space: AbstractSearchSpace,
70
- key: KeyArray,
71
- ) -> Float[Array, "1 D"]:
72
- """Maximize the given utility function over the search space provided.
73
-
74
- Args:
75
- utility_function: utility function to be maximized.
76
- search_space: search space over which to maximize the utility function.
77
- key: JAX PRNG key.
78
-
79
- Returns:
80
- Float[Array, "1 D"]: Point at which the utility function is maximized.
81
- """
82
- raise NotImplementedError
83
-
84
-
85
- @dataclass
86
- class ContinuousSinglePointUtilityMaximizer(AbstractSinglePointUtilityMaximizer):
87
- """The `ContinuousSinglePointUtilityMaximizer` class is used to maximize utility
88
- functions over the continuous domain with L-BFGS-B. First we sample the utility
89
- function at `num_initial_samples` points from the search space, and then we run
90
- L-BFGS-B from the best of these initial points. We run this process `num_restarts`
91
- number of times, each time sampling a different random set of
92
- `num_initial_samples` initial points.
93
- """
94
-
95
- num_initial_samples: int
96
- num_restarts: int
97
-
98
- def __post_init__(self):
99
- if self.num_initial_samples < 1:
100
- raise ValueError(
101
- f"num_initial_samples must be greater than 0, got {self.num_initial_samples}."
102
- )
103
- elif self.num_restarts < 1:
104
- raise ValueError(
105
- f"num_restarts must be greater than 0, got {self.num_restarts}."
106
- )
107
-
108
- def maximize(
109
- self,
110
- utility_function: SinglePointUtilityFunction,
111
- search_space: ContinuousSearchSpace,
112
- key: KeyArray,
113
- ) -> Float[Array, "1 D"]:
114
- max_observed_utility_function_value = None
115
- maximizer = None
116
-
117
- for _ in range(self.num_restarts):
118
- key, _ = jr.split(key)
119
- initial_sample_points = search_space.sample(
120
- self.num_initial_samples, key=key
121
- )
122
- best_initial_sample_point = _get_discrete_maximizer(
123
- initial_sample_points, utility_function
124
- )
125
-
126
- def _scalar_utility_function(x: Float[Array, "1 D"]) -> ScalarFloat:
127
- """
128
- The Jaxopt minimizer requires a function which returns a scalar. It calls the
129
- utility function with one point at a time, so the utility function
130
- returns an array of shape [1, 1], so we index to return a scalar. Note that
131
- we also return the negative of the utility function - this is because
132
- utility functions should be *maximized* but the Jaxopt minimizer
133
- minimizes functions.
134
- """
135
- return -utility_function(x)[0][0]
136
-
137
- lbfgsb = ScipyBoundedMinimize(
138
- fun=_scalar_utility_function, method="l-bfgs-b"
139
- )
140
- bounds = (search_space.lower_bounds, search_space.upper_bounds)
141
- optimized_point = lbfgsb.run(
142
- best_initial_sample_point, bounds=bounds
143
- ).params
144
- optimized_utility_function_value = _scalar_utility_function(optimized_point)
145
- if (max_observed_utility_function_value is None) or (
146
- optimized_utility_function_value > max_observed_utility_function_value
147
- ):
148
- max_observed_utility_function_value = optimized_utility_function_value
149
- maximizer = optimized_point
150
- return maximizer
151
-
152
-
153
- AbstractUtilityMaximizer = AbstractSinglePointUtilityMaximizer
154
- """
155
- Type alias for a utility maximizer. Currently we only support single point utility
156
- functions, but in future may support batched utility functions.
157
- """
@@ -1,64 +0,0 @@
1
- # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from beartype.typing import (
16
- Callable,
17
- Dict,
18
- Final,
19
- )
20
- import jax.numpy as jnp
21
-
22
- from gpjax.dataset import Dataset
23
- from gpjax.gps import AbstractPosterior
24
- from gpjax.typing import (
25
- Array,
26
- Float,
27
- )
28
-
29
- OBJECTIVE: Final[str] = "OBJECTIVE"
30
- """
31
- Tag for the objective dataset/function in standard utility functions.
32
- """
33
-
34
-
35
- FunctionEvaluator = Callable[[Float[Array, "N D"]], Dict[str, Dataset]]
36
- """
37
- Type alias for function evaluators, which take an array of points of shape $[N, D]$
38
- and evaluate a set of functions at each point, returning a mapping from function tags
39
- to datasets of the evaluated points. This is the same as the `Observer` in Trieste:
40
- https://github.com/secondmind-labs/trieste/blob/develop/trieste/observer.py
41
- """
42
-
43
-
44
- def build_function_evaluator(
45
- functions: Dict[str, Callable[[Float[Array, "N D"]], Float[Array, "N 1"]]],
46
- ) -> FunctionEvaluator:
47
- """
48
- Takes a dictionary of functions and returns a `FunctionEvaluator` which can be
49
- used to evaluate each of the functions at a supplied set of points and return a
50
- dictionary of datasets storing the evaluated points.
51
- """
52
- return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()}
53
-
54
-
55
- def get_best_latent_observation_val(
56
- posterior: AbstractPosterior, dataset: Dataset
57
- ) -> Float[Array, ""]:
58
- """
59
- Takes a posterior and dataset and returns the best (latent) function value in the
60
- dataset, corresponding to the minimum of the posterior mean value evaluated at
61
- locations in the dataset. In the noiseless case, this corresponds to the minimum
62
- value in the dataset.
63
- """
64
- return jnp.min(posterior(dataset.X, dataset).mean())