gpjax 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -3
- gpjax/citation.py +0 -43
- gpjax/distributions.py +3 -1
- {gpjax-0.9.4.dist-info → gpjax-0.9.5.dist-info}/METADATA +3 -4
- {gpjax-0.9.4.dist-info → gpjax-0.9.5.dist-info}/RECORD +7 -21
- gpjax-0.9.5.dist-info/licenses/LICENSE.txt +19 -0
- gpjax/decision_making/__init__.py +0 -63
- gpjax/decision_making/decision_maker.py +0 -302
- gpjax/decision_making/posterior_handler.py +0 -152
- gpjax/decision_making/search_space.py +0 -96
- gpjax/decision_making/test_functions/__init__.py +0 -31
- gpjax/decision_making/test_functions/continuous_functions.py +0 -169
- gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -90
- gpjax/decision_making/utility_functions/__init__.py +0 -37
- gpjax/decision_making/utility_functions/base.py +0 -106
- gpjax/decision_making/utility_functions/expected_improvement.py +0 -112
- gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -125
- gpjax/decision_making/utility_functions/thompson_sampling.py +0 -101
- gpjax/decision_making/utility_maximizer.py +0 -157
- gpjax/decision_making/utils.py +0 -64
- gpjax-0.9.4.dist-info/licenses/LICENSE +0 -201
- {gpjax-0.9.4.dist-info → gpjax-0.9.5.dist-info}/WHEEL +0 -0
|
@@ -1,112 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from dataclasses import dataclass
|
|
16
|
-
from functools import partial
|
|
17
|
-
|
|
18
|
-
from beartype.typing import Mapping
|
|
19
|
-
import jax.numpy as jnp
|
|
20
|
-
import tensorflow_probability.substrates.jax as tfp
|
|
21
|
-
|
|
22
|
-
from gpjax.dataset import Dataset
|
|
23
|
-
from gpjax.decision_making.utility_functions.base import (
|
|
24
|
-
AbstractSinglePointUtilityFunctionBuilder,
|
|
25
|
-
SinglePointUtilityFunction,
|
|
26
|
-
)
|
|
27
|
-
from gpjax.decision_making.utils import (
|
|
28
|
-
OBJECTIVE,
|
|
29
|
-
get_best_latent_observation_val,
|
|
30
|
-
)
|
|
31
|
-
from gpjax.gps import ConjugatePosterior
|
|
32
|
-
from gpjax.typing import (
|
|
33
|
-
Array,
|
|
34
|
-
Float,
|
|
35
|
-
KeyArray,
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@dataclass
class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
    """
    Builder for the Expected Improvement acquisition function introduced by [Močkus,
    1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55).

    The "best" incumbent value is taken to be the lowest posterior mean evaluated at
    the previously observed points, which lets the acquisition function be used with
    noisy observations.
    """

    def build_utility_function(
        self,
        posteriors: Mapping[str, ConjugatePosterior],
        datasets: Mapping[str, Dataset],
        key: KeyArray,
    ) -> SinglePointUtilityFunction:
        r"""
        Build the Expected Improvement acquisition function.

        The returned function computes the expected improvement over the "best" of
        the previously observed points under the surrogate model's posterior. For a
        posterior distribution $`f(\cdot)`$ and best incumbent value $`\eta`$ it is
        defined as:
        ```math
        \alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
        ```

        Args:
            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors
                used to form the utility function. One of the posteriors must be
                stored under the `OBJECTIVE` key, since the objective posterior is
                the one the utility function is built from.
            datasets (Mapping[str, Dataset]): Dictionary of datasets used to form the
                utility function. Keys in `datasets` should correspond to keys in
                `posteriors`. One of the datasets must be stored under the
                `OBJECTIVE` key.
            key (KeyArray): JAX PRNG key. It is not referenced by this builder.

        Returns:
            SinglePointUtilityFunction: The Expected Improvement acquisition
                function, to be *maximised* in order to decide which point to query
                next.

        Raises:
            ValueError: If the objective posterior is not a `ConjugatePosterior`, or
                if the objective dataset has no `X`, no `y`, or zero points.
        """
        self.check_objective_present(posteriors, datasets)
        posterior = posteriors[OBJECTIVE]
        data = datasets[OBJECTIVE]

        if not isinstance(posterior, ConjugatePosterior):
            raise ValueError(
                "Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
            )

        # Short-circuit order matters: `X` is tested for None before `n` is read.
        dataset_is_empty = data.X is None or data.n == 0 or data.y is None
        if dataset_is_empty:
            raise ValueError("Objective dataset must contain at least one item")

        # The incumbent is fixed at build time and bound into the returned callable.
        incumbent = get_best_latent_observation_val(posterior, data)
        return partial(_expected_improvement, posterior, data, incumbent)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def _expected_improvement(
    objective_posterior: ConjugatePosterior,
    objective_dataset: Dataset,
    eta: Float[Array, ""],
    x: Float[Array, "N D"],
) -> Float[Array, "N 1"]:
    """
    Evaluate the Expected Improvement (for minimisation) at the query points `x`.

    Args:
        objective_posterior: Posterior of the objective model, called with
            `(x, objective_dataset)` to obtain the distribution at `x`.
        objective_dataset: Dataset the posterior is conditioned on.
        eta: Incumbent value against which improvement is measured.
        x: Points at which to evaluate the utility.

    Returns:
        The expected improvement at each point, with a trailing singleton axis
        appended.
    """
    predictive = objective_posterior(x, objective_dataset)
    mu = predictive.mean()
    sigma_sq = predictive.variance()
    marginal = tfp.distributions.Normal(mu, jnp.sqrt(sigma_sq))

    # (eta - mu) * Phi(z) + sigma * phi(z), with z = (eta - mu) / sigma.
    # `marginal.prob(eta)` is phi(z) / sigma, hence the multiplication by the variance.
    improvement_term = (eta - mu) * marginal.cdf(eta)
    density_term = sigma_sq * marginal.prob(eta)
    return jnp.expand_dims(improvement_term + density_term, -1)
|
|
@@ -1,125 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from dataclasses import dataclass
|
|
16
|
-
|
|
17
|
-
from beartype.typing import Mapping
|
|
18
|
-
from jaxtyping import Num
|
|
19
|
-
import tensorflow_probability.substrates.jax as tfp
|
|
20
|
-
|
|
21
|
-
from gpjax.dataset import Dataset
|
|
22
|
-
from gpjax.decision_making.utility_functions.base import (
|
|
23
|
-
AbstractSinglePointUtilityFunctionBuilder,
|
|
24
|
-
SinglePointUtilityFunction,
|
|
25
|
-
)
|
|
26
|
-
from gpjax.decision_making.utils import (
|
|
27
|
-
OBJECTIVE,
|
|
28
|
-
get_best_latent_observation_val,
|
|
29
|
-
)
|
|
30
|
-
from gpjax.gps import ConjugatePosterior
|
|
31
|
-
from gpjax.typing import (
|
|
32
|
-
Array,
|
|
33
|
-
KeyArray,
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@dataclass
class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder):
    r"""
    An acquisition function which returns the probability of improvement
    of the objective function over the best observed value.

    More precisely, given a predictive posterior distribution of the objective
    function $`f`$, the probability of improvement at a test point $`x`$ is defined as:
    $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$
    where $`x_{\text{best}}`$ is the minimiser of the posterior mean
    at previously observed values (to handle noisy observations).

    The probability of improvement can be easily computed using the
    cumulative distribution function of the standard normal distribution $`\Phi`$:
    $$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$
    where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the
    predictive distribution of the objective function at $`x`$.

    References
    ----------
    [1] Kushner, H. J. (1964).
    A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise.
    Journal of Basic Engineering, 86(1), 97-106.

    [2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016).
    Taking the human out of the loop: A review of Bayesian optimization.
    Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218
    """

    def build_utility_function(
        self,
        posteriors: Mapping[str, ConjugatePosterior],
        datasets: Mapping[str, Dataset],
        key: KeyArray,
    ) -> SinglePointUtilityFunction:
        """
        Constructs the probability of improvement utility function
        using the predictive posterior of the objective function.

        Args:
            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors
                to be used to form the utility function. One of the posteriors must
                correspond to the `OBJECTIVE` key, as the objective posterior is
                used to form the utility function.
            datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
                to form the utility function. Keys in `datasets` should correspond to
                keys in `posteriors`. One of the datasets must correspond
                to the `OBJECTIVE` key.
            key (KeyArray): JAX PRNG key used for random number generation. Since
                the probability of improvement is computed deterministically
                from the predictive posterior, the key is not used.

        Returns:
            SinglePointUtilityFunction: the probability of improvement utility function.

        Raises:
            ValueError: If the objective posterior is not a `ConjugatePosterior`, or
                if the objective dataset has no `X`, no `y`, or zero points.
        """
        self.check_objective_present(posteriors, datasets)

        objective_posterior = posteriors[OBJECTIVE]
        if not isinstance(objective_posterior, ConjugatePosterior):
            raise ValueError(
                "Objective posterior must be a ConjugatePosterior to compute the Probability of Improvement using a Gaussian CDF."
            )

        objective_dataset = datasets[OBJECTIVE]
        if (
            objective_dataset.X is None
            or objective_dataset.n == 0
            or objective_dataset.y is None
        ):
            raise ValueError(
                "Objective dataset must be non-empty to compute the "
                "Probability of Improvement (since we need a "
                "`best_y` value)."
            )

        # The incumbent depends only on the posterior and the training data, not on
        # the query points, so compute it once here instead of on every evaluation
        # of the utility function.
        best_y = get_best_latent_observation_val(
            objective_posterior, objective_dataset
        )

        def probability_of_improvement(x_test: Num[Array, "N D"]):
            predictive_dist = objective_posterior.predict(x_test, objective_dataset)

            normal_dist = tfp.distributions.Normal(
                loc=predictive_dist.mean(),
                scale=predictive_dist.stddev(),
            )

            # CDF at the incumbent = probability the objective falls below it;
            # reshaped into a single column.
            return normal_dist.cdf(best_y).reshape(-1, 1)

        return probability_of_improvement
|
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from dataclasses import dataclass
|
|
16
|
-
|
|
17
|
-
from beartype.typing import Mapping
|
|
18
|
-
|
|
19
|
-
from gpjax.dataset import Dataset
|
|
20
|
-
from gpjax.decision_making.utility_functions.base import (
|
|
21
|
-
AbstractSinglePointUtilityFunctionBuilder,
|
|
22
|
-
SinglePointUtilityFunction,
|
|
23
|
-
)
|
|
24
|
-
from gpjax.decision_making.utils import OBJECTIVE
|
|
25
|
-
from gpjax.gps import ConjugatePosterior
|
|
26
|
-
from gpjax.typing import KeyArray
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@dataclass
class ThompsonSampling(AbstractSinglePointUtilityFunctionBuilder):
    """
    Build a utility function from an approximate posterior sample, drawn with
    decoupled sampling as introduced in [Wilson et. al.
    (2020)](https://arxiv.org/abs/2002.09309). The *negative* of the sample is
    returned as the utility function, because utility functions are *maximised*.

    This is a single batch utility function, as it doesn't support classical
    batching. Thompson sampling can still be used in a batched setting by drawing a
    batch of different samples from the GP posterior, i.e. by calling
    `build_utility_function` with different keys; an example can be seen in the
    `ask` method of the `UtilityDrivenDecisionMaker` class. The samples can then be
    optimised sequentially.

    Attributes:
        num_features (int): The number of random Fourier features to use when drawing
            the approximate sample from the posterior. Defaults to 100.
    """

    num_features: int = 100

    def __post_init__(self):
        """Reject a non-positive number of random Fourier features."""
        if self.num_features <= 0:
            raise ValueError(
                "The number of random Fourier features must be a positive integer."
            )

    def build_utility_function(
        self,
        posteriors: Mapping[str, ConjugatePosterior],
        datasets: Mapping[str, Dataset],
        key: KeyArray,
    ) -> SinglePointUtilityFunction:
        """
        Draw an approximate sample from the objective model's posterior and return
        its *negative* as a utility function, since utility functions are
        *maximised*.

        Args:
            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
                be used to form the utility function. One of the posteriors must
                correspond to the `OBJECTIVE` key, as we sample from the objective
                posterior to form the utility function.
            datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
                to form the utility function. Keys in `datasets` should correspond to
                keys in `posteriors`. One of the datasets must correspond
                to the `OBJECTIVE` key.
            key (KeyArray): JAX PRNG key used for random number generation. This can
                be changed to draw different samples.

        Returns:
            SinglePointUtilityFunction: The negated approximate sample from the
                objective model posterior, to be *maximised* in order to decide
                which point to query next.

        Raises:
            ValueError: If the objective posterior is not a `ConjugatePosterior`.
        """
        self.check_objective_present(posteriors, datasets)

        posterior = posteriors[OBJECTIVE]
        if not isinstance(posterior, ConjugatePosterior):
            raise ValueError(
                "Objective posterior must be a ConjugatePosterior to draw an approximate sample."
            )

        thompson_sample = posterior.sample_approx(
            num_samples=1,
            train_data=datasets[OBJECTIVE],
            key=key,
            num_features=self.num_features,
        )

        def negated_sample(x):
            # Utility functions are *maximised*, so flip the sign of the sample.
            return -1.0 * thompson_sample(x)

        return negated_sample
|
|
@@ -1,157 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from abc import (
|
|
16
|
-
ABC,
|
|
17
|
-
abstractmethod,
|
|
18
|
-
)
|
|
19
|
-
from dataclasses import dataclass
|
|
20
|
-
|
|
21
|
-
import jax.numpy as jnp
|
|
22
|
-
import jax.random as jr
|
|
23
|
-
from jaxopt import ScipyBoundedMinimize
|
|
24
|
-
|
|
25
|
-
from gpjax.decision_making.search_space import (
|
|
26
|
-
AbstractSearchSpace,
|
|
27
|
-
ContinuousSearchSpace,
|
|
28
|
-
)
|
|
29
|
-
from gpjax.decision_making.utility_functions import SinglePointUtilityFunction
|
|
30
|
-
from gpjax.typing import (
|
|
31
|
-
Array,
|
|
32
|
-
Float,
|
|
33
|
-
KeyArray,
|
|
34
|
-
ScalarFloat,
|
|
35
|
-
)
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def _get_discrete_maximizer(
    query_points: Float[Array, "N D"], utility_function: SinglePointUtilityFunction
) -> Float[Array, "1 D"]:
    """Pick, from a finite set of candidates, the point with the highest utility.

    Args:
        query_points: candidate points at which the utility function is evaluated,
            as an array of shape `[n_points, n_dims]`.
        utility_function: the single point utility function evaluated at
            `query_points`.

    Returns:
        Array of shape `[1, n_dims]` holding the candidate that maximises the
        utility function.
    """
    utilities = utility_function(query_points)
    # `keepdims=True` keeps the index two-dimensional so it can be used directly
    # with `take_along_axis` on the candidate array.
    best_row = jnp.argmax(utilities, axis=0, keepdims=True)
    return jnp.take_along_axis(query_points, best_row, axis=0)
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
@dataclass
class AbstractSinglePointUtilityMaximizer(ABC):
    """Interface implemented by maximizers of single point utility functions."""

    @abstractmethod
    def maximize(
        self,
        utility_function: SinglePointUtilityFunction,
        search_space: AbstractSearchSpace,
        key: KeyArray,
    ) -> Float[Array, "1 D"]:
        """Find the point in the search space which maximizes the utility function.

        Args:
            utility_function: utility function to be maximized.
            search_space: search space over which the utility function is maximized.
            key: JAX PRNG key.

        Returns:
            Float[Array, "1 D"]: Point at which the utility function is maximized.
        """
        raise NotImplementedError
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
@dataclass
class ContinuousSinglePointUtilityMaximizer(AbstractSinglePointUtilityMaximizer):
    """The `ContinuousSinglePointUtilityMaximizer` class is used to maximize utility
    functions over the continuous domain with L-BFGS-B. First we sample the utility
    function at `num_initial_samples` points from the search space, and then we run
    L-BFGS-B from the best of these initial points. We run this process `num_restarts`
    number of times, each time sampling a different random set of
    `num_initial_samples` initial points, and return the optimized point with the
    highest utility across all restarts.

    Attributes:
        num_initial_samples (int): Number of random candidates drawn per restart
            to choose the L-BFGS-B starting point. Must be at least 1.
        num_restarts (int): Number of sample-then-optimize rounds. Must be at
            least 1.
    """

    num_initial_samples: int
    num_restarts: int

    def __post_init__(self):
        """Validate that both counts are at least 1."""
        if self.num_initial_samples < 1:
            raise ValueError(
                f"num_initial_samples must be greater than 0, got {self.num_initial_samples}."
            )
        elif self.num_restarts < 1:
            raise ValueError(
                f"num_restarts must be greater than 0, got {self.num_restarts}."
            )

    def maximize(
        self,
        utility_function: SinglePointUtilityFunction,
        search_space: ContinuousSearchSpace,
        key: KeyArray,
    ) -> Float[Array, "1 D"]:
        """Maximize the utility function over a continuous search space.

        Args:
            utility_function: utility function to be maximized.
            search_space: continuous search space providing `sample`,
                `lower_bounds` and `upper_bounds`.
            key: JAX PRNG key used to draw the initial sample points.

        Returns:
            Float[Array, "1 D"]: The optimized point with the highest utility found
                across all restarts.
        """

        def _negated_scalar_utility(x: Float[Array, "1 D"]) -> ScalarFloat:
            """
            The Jaxopt minimizer requires a function which returns a scalar. It calls
            the utility function with one point at a time, so the utility function
            returns an array of shape [1, 1], which we index to obtain a scalar. The
            sign is flipped because utility functions should be *maximized* but the
            Jaxopt minimizer minimizes functions.
            """
            return -utility_function(x)[0][0]

        # None of these depend on the restart, so build them once.
        lbfgsb = ScipyBoundedMinimize(fun=_negated_scalar_utility, method="l-bfgs-b")
        bounds = (search_space.lower_bounds, search_space.upper_bounds)

        best_utility = None
        maximizer = None

        for _ in range(self.num_restarts):
            # Sample with a dedicated subkey and carry `key` forward, so that no key
            # is both consumed by sampling and split again.
            key, subkey = jr.split(key)
            initial_sample_points = search_space.sample(
                self.num_initial_samples, key=subkey
            )
            best_initial_sample_point = _get_discrete_maximizer(
                initial_sample_points, utility_function
            )

            optimized_point = lbfgsb.run(
                best_initial_sample_point, bounds=bounds
            ).params

            # Undo the sign flip: restarts must be ranked by the true utility.
            # Comparing the negated value with `>` would keep the *worst* restart.
            optimized_utility = -_negated_scalar_utility(optimized_point)
            if (best_utility is None) or (optimized_utility > best_utility):
                best_utility = optimized_utility
                maximizer = optimized_point
        return maximizer
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
AbstractUtilityMaximizer = AbstractSinglePointUtilityMaximizer
"""
Type alias for a utility maximizer. Only single point utility functions are
supported at present; batched utility functions may be supported in future.
"""
|
gpjax/decision_making/utils.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from beartype.typing import (
|
|
16
|
-
Callable,
|
|
17
|
-
Dict,
|
|
18
|
-
Final,
|
|
19
|
-
)
|
|
20
|
-
import jax.numpy as jnp
|
|
21
|
-
|
|
22
|
-
from gpjax.dataset import Dataset
|
|
23
|
-
from gpjax.gps import AbstractPosterior
|
|
24
|
-
from gpjax.typing import (
|
|
25
|
-
Array,
|
|
26
|
-
Float,
|
|
27
|
-
)
|
|
28
|
-
|
|
29
|
-
OBJECTIVE: Final[str] = "OBJECTIVE"
"""
Tag identifying the objective dataset/function in standard utility functions.
"""


FunctionEvaluator = Callable[[Float[Array, "N D"]], Dict[str, Dataset]]
"""
Type alias for function evaluators: callables which take an array of points of shape
$[N, D]$, evaluate a set of functions at each point, and return a mapping from
function tags to datasets of the evaluated points. This is the same as the `Observer`
in Trieste:
https://github.com/secondmind-labs/trieste/blob/develop/trieste/observer.py
"""
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def build_function_evaluator(
    functions: Dict[str, Callable[[Float[Array, "N D"]], Float[Array, "N 1"]]],
) -> FunctionEvaluator:
    """
    Wrap a dictionary of functions into a single `FunctionEvaluator`.

    The returned evaluator takes a set of points, applies every function in
    `functions` to them, and returns a dictionary with the same tags whose values
    are datasets pairing the points with the corresponding function values.
    """

    def evaluate(x):
        evaluated = {}
        for tag, f in functions.items():
            evaluated[tag] = Dataset(x, f(x))
        return evaluated

    return evaluate
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def get_best_latent_observation_val(
    posterior: AbstractPosterior, dataset: Dataset
) -> Float[Array, ""]:
    """
    Return the best (latent) function value in the dataset.

    "Best" means the minimum of the posterior mean evaluated at the input locations
    stored in `dataset`. In the noiseless case this corresponds to the minimum value
    in the dataset.
    """
    latent_at_observations = posterior(dataset.X, dataset)
    posterior_means = latent_at_observations.mean()
    return jnp.min(posterior_means)
|