ennbo 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
enn/turbo/proposal.py ADDED
@@ -0,0 +1,133 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Callable
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from numpy.random import Generator
8
+
9
+ from enn.enn import EpistemicNearestNeighbors
10
+ from enn.enn.enn_params import ENNParams
11
+
12
+ from .turbo_gp import TurboGP
13
+
14
+ from .turbo_utils import gp_thompson_sample
15
+
16
+
17
def mk_enn(
    x_obs_list: list[float] | list[list[float]],
    y_obs_list: list[float] | list[list[float]],
    *,
    yvar_obs_list: list[float] | None = None,
    k: int,
    num_fit_samples: int | None = None,
    num_fit_candidates: int | None = None,
    rng: Generator | Any | None = None,
    params_warm_start: ENNParams | Any | None = None,
) -> tuple[EpistemicNearestNeighbors | None, ENNParams | None]:
    """Build an ENN model from the observations and pick its parameters.

    Args:
        x_obs_list: Observed inputs; converted to a float array.
        y_obs_list: Observed outputs; flattened into a single column.
        yvar_obs_list: Optional observation variances. ``None`` or an empty
            list means no variances are handed to the model.
        k: Neighbor count stored in the returned parameters.
        num_fit_samples: Forwarded to ``enn_fit``. Fitting only happens when
            both this and ``rng`` are given.
        num_fit_candidates: Forwarded to ``enn_fit``; 30 is used when ``None``.
        rng: Random generator forwarded to ``enn_fit``.
        params_warm_start: Previous parameters forwarded to ``enn_fit``.

    Returns:
        ``(model, params)``. Both are ``None`` when there are no observations
        or the constructed model reports a length of zero.
    """
    import numpy as np

    from enn.enn import EpistemicNearestNeighbors
    from enn.enn.enn_params import ENNParams

    nothing = (None, None)
    if len(x_obs_list) == 0:
        return nothing
    targets = np.asarray(y_obs_list, dtype=float)
    if targets.size == 0:
        return nothing

    y_column = targets.reshape(-1, 1)
    yvar_column = None
    if yvar_obs_list is not None and len(yvar_obs_list) > 0:
        yvar_column = np.asarray(yvar_obs_list, dtype=float).reshape(-1, 1)
    inputs = np.asarray(x_obs_list, dtype=float)

    model = EpistemicNearestNeighbors(inputs, y_column, yvar_column)
    if len(model) == 0:
        return nothing

    # Without both a sample budget and an rng there is nothing to fit with,
    # so fall back to fixed parameters.
    if num_fit_samples is None or rng is None:
        fixed = ENNParams(k=k, epi_var_scale=1.0, ale_homoscedastic_scale=0.0)
        return model, fixed

    from enn.enn.enn_fit import enn_fit

    default_fit_candidates = 30
    fitted = enn_fit(
        model,
        k=k,
        num_fit_candidates=(
            default_fit_candidates
            if num_fit_candidates is None
            else num_fit_candidates
        ),
        num_fit_samples=num_fit_samples,
        rng=rng,
        params_warm_start=params_warm_start,
    )
    return model, fitted
72
+
73
+
74
def select_uniform(
    x_cand: np.ndarray,
    num_arms: int,
    num_dim: int,
    rng: Generator | Any,
    from_unit_fn: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
    """Pick ``num_arms`` distinct candidate rows uniformly at random.

    Args:
        x_cand: Candidate matrix; must be 2-D with ``num_dim`` columns.
        num_arms: Number of rows to draw (without replacement).
        num_dim: Expected number of columns of ``x_cand``.
        rng: Object providing ``choice(n, size=..., replace=False)``.
        from_unit_fn: Applied to the selected rows before returning.

    Returns:
        ``from_unit_fn`` applied to the selected rows.

    Raises:
        ValueError: With ``x_cand.shape`` if the shape check fails, or with
            ``(num_candidates, num_arms)`` if there are too few candidates.
    """
    shape_ok = x_cand.ndim == 2 and x_cand.shape[1] == num_dim
    if not shape_ok:
        raise ValueError(x_cand.shape)

    num_cand = x_cand.shape[0]
    if num_arms > num_cand:
        raise ValueError((num_cand, num_arms))

    chosen = rng.choice(num_cand, size=num_arms, replace=False)
    return from_unit_fn(x_cand[chosen])
87
+
88
+
89
def select_gp_thompson(
    x_cand: np.ndarray,
    num_arms: int,
    x_obs_list: list[float] | list[list[float]],
    y_obs_list: list[float] | list[list[float]],
    num_dim: int,
    gp_num_steps: int,
    rng: Generator | Any,
    gp_y_mean: float,
    gp_y_std: float,
    select_sobol_fn: Callable[[np.ndarray, int], np.ndarray],
    from_unit_fn: Callable[[np.ndarray], np.ndarray],
    *,
    model: TurboGP | None = None,
    new_gp_y_mean: float | None = None,
    new_gp_y_std: float | None = None,
) -> tuple[np.ndarray, float, float, TurboGP | None]:
    """Select arms by GP Thompson sampling, falling back to ``select_sobol_fn``.

    When ``model`` is not supplied, a GP is fitted with ``fit_gp`` on the
    observations. If there are no observations, or fitting yields no model,
    the arms come from ``select_sobol_fn`` and the incoming ``gp_y_mean`` /
    ``gp_y_std`` are passed back unchanged with ``None`` as the model.

    Returns:
        ``(arms, y_mean, y_std, model)`` where ``arms`` on the GP path is
        ``from_unit_fn`` applied to the sampled candidate rows.

    Raises:
        ValueError: With ``(num_candidates, num_arms)`` when a model is
            available but there are fewer candidates than arms.
    """
    from .turbo_utils import fit_gp

    def _sobol_result():
        # No usable GP: keep the caller's normalization constants as they are.
        return select_sobol_fn(x_cand, num_arms), gp_y_mean, gp_y_std, None

    if len(x_obs_list) == 0:
        return _sobol_result()

    gp, y_mean, y_std = model, new_gp_y_mean, new_gp_y_std
    if gp is None:
        gp, _likelihood, y_mean, y_std = fit_gp(
            x_obs_list,
            y_obs_list,
            num_dim,
            num_steps=gp_num_steps,
        )
        if gp is None:
            return _sobol_result()

    # Missing normalization constants default to the caller-provided ones.
    y_mean = gp_y_mean if y_mean is None else y_mean
    y_std = gp_y_std if y_std is None else y_std

    num_cand = x_cand.shape[0]
    if num_cand < num_arms:
        raise ValueError((num_cand, num_arms))

    picked = gp_thompson_sample(gp, x_cand, num_arms, rng, y_mean, y_std)
    return from_unit_fn(x_cand[picked]), y_mean, y_std, gp
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+
6
+
7
@dataclass(frozen=True)
class TurboConfig:
    """Immutable settings for the TuRBO optimizer.

    Validation in ``__post_init__`` rejects unknown ``acq_type`` values and
    requires ``num_fit_samples`` for every acquisition type except
    ``"pareto"``.
    """

    # Neighbor count; ``None`` leaves the choice of a default to the consumer.
    k: int | None = None
    # NOTE(review): presumably the number of candidate points per ask — confirm.
    num_candidates: int | None = None
    # NOTE(review): presumably the size of the initial design — confirm.
    num_init: int | None = None
    var_scale: float = 1.0

    # Experimental
    trailing_obs: int | None = None
    num_fit_samples: int | None = None
    num_fit_candidates: int | None = None
    acq_type: Literal["thompson", "pareto", "ucb"] = "pareto"
    local_only: bool = False

    def __post_init__(self) -> None:
        """Validate the acquisition type and its fitting requirements.

        Raises:
            ValueError: If ``acq_type`` is not a supported value, or if a
                non-pareto acquisition is requested without
                ``num_fit_samples``.
        """
        supported = ("thompson", "pareto", "ucb")
        if self.acq_type not in supported:
            raise ValueError(
                f"acq_type must be 'thompson', 'pareto', or 'ucb', got {self.acq_type!r}"
            )

        # Pareto acquisition is the only type that works well without hyperparameter fitting
        fitting_missing = self.num_fit_samples is None
        if fitting_missing and self.acq_type != "pareto":
            raise ValueError(f"num_fit_samples required for acq_type={self.acq_type!r}")
@@ -0,0 +1,176 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Callable
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from numpy.random import Generator
8
+
9
+ from .base_turbo_impl import BaseTurboImpl
10
+ from .turbo_config import TurboConfig
11
+
12
+
13
class TurboENNImpl(BaseTurboImpl):
    """TuRBO mode implementation that uses an ENN model as the surrogate.

    The model and its parameters are rebuilt on every ``prepare_ask`` call and
    cached on the instance for the other methods to use.
    """

    def __init__(self, config: TurboConfig) -> None:
        super().__init__(config)
        self._enn: Any | None = None
        self._fitted_params: Any | None = None
        # Number of observations the cached model was built from.
        self._fitted_n_obs: int = 0

    def get_x_center(
        self,
        x_obs_list: list,
        y_obs_list: list,
        rng: Generator,
    ) -> np.ndarray | None:
        """Return the observed point to use as the trust-region center.

        Among the (at most ``k``) observations with the largest observed y,
        the one with the highest ENN posterior mean is returned; ties are
        resolved by ``argmax_random_tie``. Without a cached model the base
        class implementation is used, and ``None`` is returned when there are
        no observations.

        Raises:
            RuntimeError: If the cached model was built from a different
                number of observations than ``x_obs_list`` holds.
        """
        import numpy as np

        from .turbo_utils import argmax_random_tie

        if len(y_obs_list) == 0:
            return None
        if self._enn is None or self._fitted_params is None:
            return super().get_x_center(x_obs_list, y_obs_list, rng)
        if self._fitted_n_obs != len(x_obs_list):
            raise RuntimeError(
                f"ENN fitted on {self._fitted_n_obs} obs but get_x_center called with {len(x_obs_list)}"
            )

        y_values = np.asarray(y_obs_list, dtype=float)
        x_values = np.asarray(x_obs_list, dtype=float)

        k = 10 if self._config.k is None else self._config.k
        shortlist_size = min(k, len(y_values))
        # Negate so that argpartition puts the largest y values first.
        shortlist = np.argpartition(-y_values, shortlist_size - 1)[:shortlist_size]

        x_shortlist = x_values[shortlist]
        mean = self._enn.posterior(x_shortlist, params=self._fitted_params).mu[:, 0]
        return x_shortlist[argmax_random_tie(mean, rng=rng)]

    def needs_tr_list(self) -> bool:
        """Always ``True`` for this implementation."""
        return True

    def handle_restart(
        self,
        x_obs_list: list,
        y_obs_list: list,
        yvar_obs_list: list,
        init_idx: int,
        num_init: int,
    ) -> tuple[bool, int]:
        """Empty all three observation lists in place and return ``(True, 0)``.

        ``init_idx`` and ``num_init`` are not used here.
        NOTE(review): the tuple presumably means (restart handled, new init
        index) — confirm against the caller.
        """
        for observations in (x_obs_list, y_obs_list, yvar_obs_list):
            observations.clear()
        return True, 0

    def prepare_ask(
        self,
        x_obs_list: list,
        y_obs_list: list,
        yvar_obs_list: list,
        num_dim: int,
        gp_num_steps: int,
        rng: Any | None = None,
    ) -> tuple[Any, float | None, float | None, np.ndarray | None]:
        """Rebuild the ENN model from the current observations.

        The previously cached parameters are passed to ``mk_enn`` as a warm
        start. ``num_dim`` and ``gp_num_steps`` are not used here.

        Returns:
            Always ``(None, None, None, None)``; the model is kept on the
            instance rather than returned.
        """
        from .proposal import mk_enn

        cfg = self._config
        neighbors = 10 if cfg.k is None else cfg.k
        self._enn, self._fitted_params = mk_enn(
            x_obs_list,
            y_obs_list,
            yvar_obs_list=yvar_obs_list,
            k=neighbors,
            num_fit_samples=cfg.num_fit_samples,
            num_fit_candidates=cfg.num_fit_candidates,
            rng=rng,
            params_warm_start=self._fitted_params,
        )
        self._fitted_n_obs = len(x_obs_list)
        return None, None, None, None

    def select_candidates(
        self,
        x_cand: np.ndarray,
        num_arms: int,
        num_dim: int,
        rng: Generator,
        fallback_fn: Callable[[np.ndarray, int], np.ndarray],
        from_unit_fn: Callable[[np.ndarray], np.ndarray],
    ) -> np.ndarray:
        """Choose ``num_arms`` candidates according to the configured acquisition.

        Without a cached model the result of ``fallback_fn`` is returned as
        is. Otherwise the ENN posterior over ``x_cand`` is scored by the
        ``"pareto"``, ``"ucb"`` (mean plus standard error) or ``"thompson"``
        (one posterior sample) rule, and ``from_unit_fn`` is applied to the
        chosen rows.

        Raises:
            ValueError: If the configured acquisition type is not recognized.
        """
        import numpy as np

        from enn.enn.enn_params import ENNParams

        cfg = self._config
        acq_type = cfg.acq_type
        k = cfg.k
        var_scale = cfg.var_scale

        if self._enn is None:
            return fallback_fn(x_cand, num_arms)

        params = self._fitted_params
        if params is None:
            params = ENNParams(
                k=10 if k is None else k,
                epi_var_scale=var_scale,
                ale_homoscedastic_scale=0.0,
            )

        posterior = self._enn.posterior(x_cand, params=params)
        mu = posterior.mu[:, 0]
        se = posterior.se[:, 0]

        def _top_scoring_rows(scores):
            # Scores are permuted before argpartition, presumably so that the
            # choice among tied scores does not depend on candidate order.
            order = rng.permutation(len(scores))
            winners = np.argpartition(-scores[order], num_arms - 1)[:num_arms]
            return x_cand[order[winners]]

        if acq_type == "pareto":
            from enn.enn.enn_util import arms_from_pareto_fronts

            x_arms = arms_from_pareto_fronts(x_cand, mu, se, num_arms, rng)
        elif acq_type == "ucb":
            x_arms = _top_scoring_rows(mu + se)
        elif acq_type == "thompson":
            draws = posterior.sample(num_samples=1, rng=rng)
            x_arms = _top_scoring_rows(draws[:, 0, 0])
        else:
            raise ValueError(f"Unknown acq_type: {acq_type}")

        return from_unit_fn(x_arms)

    def estimate_y(self, x_unit: np.ndarray, y_observed: np.ndarray) -> np.ndarray:
        """Return the ENN posterior mean at ``x_unit``.

        ``y_observed`` is returned unchanged when no model or parameters are
        cached.
        """
        have_model = self._enn is not None and self._fitted_params is not None
        if not have_model:
            return y_observed
        return self._enn.posterior(x_unit, params=self._fitted_params).mu[:, 0]

    def get_mu_sigma(self, x_unit: np.ndarray) -> tuple[np.ndarray, np.ndarray] | None:
        """Return the posterior mean and standard error at ``x_unit``.

        The posterior is requested with ``observation_noise=False``. Returns
        ``None`` when no model is cached. If no parameters are cached, default
        ones are built from the configured ``k`` (10 if unset) and
        ``var_scale``.
        """
        if self._enn is None:
            return None
        k = 10 if self._config.k is None else self._config.k
        from enn.enn.enn_params import ENNParams

        params = self._fitted_params
        if params is None:
            params = ENNParams(
                k=k, epi_var_scale=self._config.var_scale, ale_homoscedastic_scale=0.0
            )
        posterior = self._enn.posterior(x_unit, params=params, observation_noise=False)
        return posterior.mu[:, 0], posterior.se[:, 0]
enn/turbo/turbo_gp.py ADDED
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from .turbo_gp_base import TurboGPBase
4
+
5
+
6
class TurboGP(TurboGPBase):
    """GP with a constant mean and a scaled Matern-5/2 ARD kernel.

    The likelihood is supplied by the caller and handed to the base class.
    """

    def __init__(
        self,
        train_x,
        train_y,
        likelihood,
        lengthscale_constraint,
        outputscale_constraint,
        ard_dims: int,
    ) -> None:
        """Set up the mean and covariance modules.

        Args:
            train_x: Training inputs passed to the base class.
            train_y: Training targets passed to the base class.
            likelihood: Likelihood passed to the base class.
            lengthscale_constraint: Constraint for the Matern lengthscales.
            outputscale_constraint: Constraint for the kernel output scale.
            ard_dims: Number of ARD dimensions of the Matern kernel.
        """
        # gpytorch is imported lazily, only when a model is actually built.
        from gpytorch import kernels, means

        super().__init__(train_x, train_y, likelihood)
        self.mean_module = means.ConstantMean()
        self.covar_module = kernels.ScaleKernel(
            kernels.MaternKernel(
                nu=2.5,
                ard_num_dims=ard_dims,
                lengthscale_constraint=lengthscale_constraint,
            ),
            outputscale_constraint=outputscale_constraint,
        )
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ if TYPE_CHECKING:
6
+ from gpytorch.distributions import MultivariateNormal
7
+
8
+
9
def _get_exact_gp_base():
    """Return gpytorch's ``ExactGP`` class, importing gpytorch on call."""
    import gpytorch.models as gp_models

    return gp_models.ExactGP
13
+
14
+
15
class TurboGPBase(_get_exact_gp_base()):
    """Shared base for the TuRBO GP models.

    Subclasses are expected to assign ``mean_module`` and ``covar_module``;
    this class only declares them and wires them together in ``forward``.
    """

    # Declared for type checkers; the values are assigned by subclasses.
    mean_module: Any
    covar_module: Any

    def forward(self, x) -> MultivariateNormal:
        """Return the multivariate normal given by the mean and covariance at ``x``."""
        from gpytorch import distributions

        # The mean module is evaluated before the covariance module.
        mean_at_x = self.mean_module(x)
        return distributions.MultivariateNormal(mean_at_x, self.covar_module(x))

    def posterior(self, x) -> MultivariateNormal:
        """Evaluate the model at ``x`` by calling the model itself."""
        return self(x)
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from .turbo_gp_base import TurboGPBase
4
+
5
+
6
class TurboGPNoisy(TurboGPBase):
    """GP with per-observation noise, a constant mean and a Matern-5/2 ARD kernel.

    Unlike a caller-supplied likelihood, this class builds its own
    ``FixedNoiseGaussianLikelihood`` from ``train_y_var``.
    """

    def __init__(
        self,
        train_x,
        train_y,
        train_y_var,
        lengthscale_constraint,
        outputscale_constraint,
        ard_dims: int,
        *,
        learn_additional_noise: bool = True,
    ) -> None:
        """Build the fixed-noise likelihood and the mean/covariance modules.

        Args:
            train_x: Training inputs passed to the base class.
            train_y: Training targets passed to the base class.
            train_y_var: Noise values given to the fixed-noise likelihood.
            lengthscale_constraint: Constraint for the Matern lengthscales.
            outputscale_constraint: Constraint for the kernel output scale.
            ard_dims: Number of ARD dimensions of the Matern kernel.
            learn_additional_noise: Forwarded to the likelihood.
        """
        # gpytorch is imported lazily, only when a model is actually built.
        from gpytorch import kernels, likelihoods, means

        noise_model = likelihoods.FixedNoiseGaussianLikelihood(
            noise=train_y_var,
            learn_additional_noise=learn_additional_noise,
        )
        super().__init__(train_x, train_y, noise_model)
        self.mean_module = means.ConstantMean()
        self.covar_module = kernels.ScaleKernel(
            kernels.MaternKernel(
                nu=2.5,
                ard_num_dims=ard_dims,
                lengthscale_constraint=lengthscale_constraint,
            ),
            outputscale_constraint=outputscale_constraint,
        )
@@ -0,0 +1,10 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum, auto
4
+
5
+
6
class TurboMode(Enum):
    """Available optimizer modes.

    Member values are assigned automatically in declaration order.
    NOTE(review): each mode presumably maps to one implementation class
    (e.g. TURBO_ONE -> GP-based, TURBO_ENN -> ENN-based) — confirm against
    the code that dispatches on this enum.
    """

    TURBO_ONE = auto()
    TURBO_ZERO = auto()
    TURBO_ENN = auto()
    LHD_ONLY = auto()
@@ -0,0 +1,67 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Callable, Protocol
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from numpy.random import Generator
8
+
9
+
10
class TurboModeImpl(Protocol):
    """Structural interface that every TuRBO mode implementation provides."""

    def get_x_center(
        self,
        x_obs_list: list,
        y_obs_list: list,
        rng: Generator,
    ) -> np.ndarray | None:
        """Return the point to center the trust region on, or ``None``."""
        ...

    def needs_tr_list(self) -> bool:
        """Report whether this mode needs the trust-region list."""
        ...

    def create_trust_region(self, num_dim: int, num_arms: int) -> Any:
        """Create the trust-region state for the given problem size."""
        ...

    def try_early_ask(
        self,
        num_arms: int,
        x_obs_list: list,
        draw_initial_fn: Callable[[int], np.ndarray],
        get_init_lhd_points_fn: Callable[[int], np.ndarray | None],
    ) -> np.ndarray | None:
        """Return arms directly, or ``None`` if no early answer is available."""
        ...

    def handle_restart(
        self,
        x_obs_list: list,
        y_obs_list: list,
        yvar_obs_list: list,
        init_idx: int,
        num_init: int,
    ) -> tuple[bool, int]:
        """React to a restart; may modify the observation lists in place.

        NOTE(review): the returned pair presumably means (restart handled,
        new init index) — confirm against the caller.
        """
        ...

    def prepare_ask(
        self,
        x_obs_list: list,
        y_obs_list: list,
        yvar_obs_list: list,
        num_dim: int,
        gp_num_steps: int,
        rng: Generator | Any | None = None,
    ) -> tuple[Any, float | None, float | None, np.ndarray | None]:
        """Update the surrogate from the observations before candidates are chosen.

        NOTE(review): the 4-tuple presumably holds (model, y mean, y std,
        per-dimension weights), any of which may be ``None`` — confirm.
        """
        ...

    def select_candidates(
        self,
        x_cand: np.ndarray,
        num_arms: int,
        num_dim: int,
        rng: Generator,
        fallback_fn: Callable[[np.ndarray, int], np.ndarray],
        from_unit_fn: Callable[[np.ndarray], np.ndarray],
    ) -> np.ndarray:
        """Pick ``num_arms`` points from the candidate set ``x_cand``."""
        ...

    def update_trust_region(
        self,
        tr_state: Any,
        y_obs_list: list,
        x_center: np.ndarray | None = None,
        k: int | None = None,
    ) -> None:
        """Update ``tr_state`` using the observed values."""
        ...

    def estimate_y(self, x_unit: np.ndarray, y_observed: np.ndarray) -> np.ndarray:
        """Return estimated y values at ``x_unit``."""
        ...
@@ -0,0 +1,163 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Callable
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from numpy.random import Generator
8
+
9
+ from .base_turbo_impl import BaseTurboImpl
10
+ from .turbo_config import TurboConfig
11
+ from .turbo_utils import gp_thompson_sample
12
+
13
+
14
class TurboOneImpl(BaseTurboImpl):
    """TuRBO mode implementation that uses a fitted GP as the surrogate.

    The GP works on standardized targets; ``_gp_y_mean`` and ``_gp_y_std``
    hold the constants used to map its predictions back to the original
    scale.
    """

    def __init__(self, config: TurboConfig) -> None:
        super().__init__(config)
        self._gp_model: Any | None = None
        self._gp_y_mean: float = 0.0
        self._gp_y_std: float = 1.0
        # Number of observations the cached GP was fitted on.
        self._fitted_n_obs: int = 0

    def get_x_center(
        self,
        x_obs_list: list,
        y_obs_list: list,
        rng: Generator,
    ) -> np.ndarray | None:
        """Return the observed point with the highest GP posterior mean.

        Ties are resolved by ``argmax_random_tie``. Without a cached GP the
        base class implementation is used, and ``None`` is returned when
        there are no observations.

        Raises:
            RuntimeError: If the cached GP was fitted on a different number
                of observations than ``x_obs_list`` holds.
        """
        import numpy as np
        import torch

        from .turbo_utils import argmax_random_tie

        if len(y_obs_list) == 0:
            return None
        if self._gp_model is None:
            return super().get_x_center(x_obs_list, y_obs_list, rng)
        if self._fitted_n_obs != len(x_obs_list):
            raise RuntimeError(
                f"GP fitted on {self._fitted_n_obs} obs but get_x_center called with {len(x_obs_list)}"
            )

        x_values = np.asarray(x_obs_list, dtype=float)
        x_tensor = torch.as_tensor(x_values, dtype=torch.float64)
        with torch.no_grad():
            mean = self._gp_model.posterior(x_tensor).mean.cpu().numpy().ravel()

        return x_values[argmax_random_tie(mean, rng=rng)]

    def needs_tr_list(self) -> bool:
        """Always ``True`` for this implementation."""
        return True

    def try_early_ask(
        self,
        num_arms: int,
        x_obs_list: list,
        draw_initial_fn: Callable[[int], np.ndarray],
        get_init_lhd_points_fn: Callable[[int], np.ndarray | None],
    ) -> np.ndarray | None:
        """Return initial design points while there are no observations yet.

        Once observations exist, ``None`` is returned. ``draw_initial_fn`` is
        not used here.
        """
        if len(x_obs_list) != 0:
            return None
        return get_init_lhd_points_fn(num_arms)

    def handle_restart(
        self,
        x_obs_list: list,
        y_obs_list: list,
        yvar_obs_list: list,
        init_idx: int,
        num_init: int,
    ) -> tuple[bool, int]:
        """Empty all three observation lists in place and return ``(True, 0)``.

        ``init_idx`` and ``num_init`` are not used here.
        NOTE(review): the tuple presumably means (restart handled, new init
        index) — confirm against the caller.
        """
        for observations in (x_obs_list, y_obs_list, yvar_obs_list):
            observations.clear()
        return True, 0

    def prepare_ask(
        self,
        x_obs_list: list,
        y_obs_list: list,
        yvar_obs_list: list,
        num_dim: int,
        gp_num_steps: int,
        rng: Any | None = None,
    ) -> tuple[Any, float | None, float | None, np.ndarray | None]:
        """Fit the GP on the current observations and derive dimension weights.

        With no observations nothing is fitted and four ``None`` values are
        returned. The cached y mean/std are only overwritten when ``fit_gp``
        reports a value. ``rng`` is not used here.

        Returns:
            ``(model, fitted_y_mean, fitted_y_std, weights)`` where
            ``weights`` are the kernel lengthscales divided by their mean and
            then by their geometric mean, or ``None`` if there is no model.
        """
        import numpy as np

        from .turbo_utils import fit_gp

        if len(x_obs_list) == 0:
            return None, None, None, None

        self._gp_model, _likelihood, fitted_mean, fitted_std = fit_gp(
            x_obs_list,
            y_obs_list,
            num_dim,
            yvar_obs_list=yvar_obs_list or None,
            num_steps=gp_num_steps,
        )
        self._fitted_n_obs = len(x_obs_list)
        if fitted_mean is not None:
            self._gp_y_mean = fitted_mean
        if fitted_std is not None:
            self._gp_y_std = fitted_std

        weights = None
        if self._gp_model is not None:
            lengthscale = self._gp_model.covar_module.base_kernel.lengthscale
            raw = lengthscale.cpu().detach().numpy().ravel()
            # First line helps stabilize second line.
            rescaled = raw / raw.mean()
            weights = rescaled / np.prod(np.power(rescaled, 1.0 / len(rescaled)))
        return self._gp_model, fitted_mean, fitted_std, weights

    def select_candidates(
        self,
        x_cand: np.ndarray,
        num_arms: int,
        num_dim: int,
        rng: Generator,
        fallback_fn: Callable[[np.ndarray, int], np.ndarray],
        from_unit_fn: Callable[[np.ndarray], np.ndarray],
    ) -> np.ndarray:
        """Choose ``num_arms`` candidates by GP Thompson sampling.

        Without a cached GP the result of ``fallback_fn`` is returned as is;
        otherwise ``from_unit_fn`` is applied to the sampled candidate rows.
        """
        gp = self._gp_model
        if gp is None:
            return fallback_fn(x_cand, num_arms)

        picked = gp_thompson_sample(
            gp, x_cand, num_arms, rng, self._gp_y_mean, self._gp_y_std
        )
        return from_unit_fn(x_cand[picked])

    def estimate_y(self, x_unit: np.ndarray, y_observed: np.ndarray) -> np.ndarray:
        """Return the GP posterior mean at ``x_unit`` on the original y scale.

        ``y_observed`` is returned unchanged when no GP is cached.
        """
        import torch

        if self._gp_model is None:
            return y_observed
        x_tensor = torch.as_tensor(x_unit, dtype=torch.float64)
        with torch.no_grad():
            standardized = self._gp_model.posterior(x_tensor).mean.cpu().numpy().ravel()
        return self._gp_y_mean + self._gp_y_std * standardized

    def get_mu_sigma(self, x_unit: np.ndarray) -> tuple[np.ndarray, np.ndarray] | None:
        """Return the GP posterior mean and standard deviation at ``x_unit``.

        Both are mapped back to the original y scale. Returns ``None`` when
        no GP is cached.
        """
        import torch

        if self._gp_model is None:
            return None
        x_tensor = torch.as_tensor(x_unit, dtype=torch.float64)
        with torch.no_grad():
            prediction = self._gp_model.posterior(x_tensor)
            mean_std = prediction.mean.cpu().numpy().ravel()
            sigma_std = prediction.variance.cpu().numpy().ravel() ** 0.5
        return (
            self._gp_y_mean + self._gp_y_std * mean_std,
            self._gp_y_std * sigma_std,
        )