synbo 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
synbo/__init__.py ADDED
@@ -0,0 +1,22 @@
+ """Modern Reaction Optimization Framework.
+
+ A sophisticated framework for optimizing chemical reactions using
+ Bayesian Optimization with modern Python practices and rich output.
+ """
+
+ from __future__ import annotations
+
+ __version__ = "0.1.0"
+ __author__ = "Zhenzhi Tan"
+ __email__ = "zhenzhi-tan@outlook.com"
+
+ # Core classes
+ from .synbo import ReactionOptimizer
+ from .initialize import Initializer
+ from .optimize import Optimizer
+
+ __all__ = [
+     "ReactionOptimizer",
+     "Initializer",
+     "Optimizer",
+ ]
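
For reference, the only usage grounded in this diff is importing the names exported above; ReactionOptimizer's constructor and methods are not shown here.

# Example (not part of the package): the public import surface of synbo 0.1.0
import synbo
from synbo import ReactionOptimizer, Initializer, Optimizer

print(synbo.__version__)  # "0.1.0"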
synbo/algorithm/acq_function.py ADDED
@@ -0,0 +1,337 @@
+ import numpy as np
+ import torch
+ from torch import Tensor
+ from botorch.acquisition.monte_carlo import qExpectedImprovement, qUpperConfidenceBound
+ from botorch.acquisition.multi_objective.logei import qLogNoisyExpectedHypervolumeImprovement
+ from botorch.acquisition.objective import GenericMCObjective
+ from botorch.models import ModelListGP
+ from botorch.sampling.normal import SobolQMCNormalSampler
+ from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
+
+ from synbo.utils.logger import console
+
+
+ class BaseAcquisitionFunction:
+     """Base class for acquisition functions."""
+
+     def __init__(self, model: ModelListGP, sampler: SobolQMCNormalSampler):
+         self.model = model
+         self.sampler = sampler
+         self.acquisition_function = None
+         self.console = console
+
+     @property
+     def acq_func(self):
+         return self.acquisition_function
+
+     def _split_batch_eval_acqf(self, acq_func, X: Tensor, max_batch_size: int) -> Tensor:
+         """Evaluate an acquisition function in batches to avoid running out of memory."""
+         acq_values_list = []
+         with torch.no_grad():
+             for X_batch in X.split(max_batch_size):
+                 acq_values_list.append(acq_func(X_batch))
+         return torch.cat(acq_values_list, dim=0)
+
+     def optimize_discrete(
+         self,
+         acq_func,
+         q: int,
+         choices: Tensor,
+         unique: bool = True,
+         max_batch_size: int = 1024,
+         progress: object = None,
+         task: object = None,
+         min_distance: float = 1e-6,
+         exclude_points: Tensor = None,
+     ) -> tuple[Tensor, Tensor]:
+         """
+         Sequential greedy optimization for batch selection (EDBO-style).
+
+         This method selects q candidates one at a time using the same greedy
+         loop as BoTorch's optimize_acqf_discrete: after each pick, the selected
+         point is marked as pending so that later picks are conditioned on the
+         earlier ones. This is the strategy used by EDBO+.
+
+         Args:
+             acq_func: Acquisition function to optimize (must support q-batch evaluation)
+             q: Number of candidates to select
+             choices: Candidate points tensor (n, D)
+             unique: Whether to ensure unique selections (default: True)
+             max_batch_size: Maximum batch size for evaluation (default: 1024)
+             progress: Not used in the EDBO implementation (kept for API compatibility)
+             task: Not used in the EDBO implementation (kept for API compatibility)
+             min_distance: Not used in the EDBO implementation (kept for API compatibility)
+             exclude_points: Not used in the EDBO implementation (kept for API compatibility)
+
+         Returns:
+             tuple[Tensor, Tensor]:
+                 - Selected candidates (q, D)
+                 - Acquisition values for selected candidates (q,)
+         """
+         acq_result = self._new_optimize_acqf_discrete(
+             acq_function=acq_func,
+             choices=choices,
+             q=q,
+             unique=unique,
+             max_batch_size=max_batch_size,
+         )
+
+         selected_candidates = acq_result[0]  # (q, D)
+         acquisition_values = acq_result[1]  # (q,)
+
+         return selected_candidates, acquisition_values
+
+     def _new_optimize_acqf_discrete(self, acq_function, q, choices, max_batch_size, unique):
+         """Re-implementation of BoTorch's optimize_acqf_discrete over a fixed choice set."""
+         choices_batched = choices.unsqueeze(-2)
+         if q > 1:
+             candidate_list, acq_value_list = [], []
+             base_X_pending = acq_function.X_pending
+             for _ in range(q):
+                 acq_values = self._split_batch_eval_acqf(
+                     acq_func=acq_function,
+                     X=choices_batched,
+                     max_batch_size=max_batch_size,
+                 )
+                 best_idx = torch.argmax(acq_values)
+                 candidate_list.append(choices_batched[best_idx])
+                 acq_value_list.append(acq_values[best_idx])
+                 # Mark the selections so far as pending so the next pick is conditioned on them.
+                 candidates = torch.cat(candidate_list, dim=-2)
+                 acq_function.set_X_pending(
+                     torch.cat([base_X_pending, candidates], dim=-2) if base_X_pending is not None else candidates
+                 )
+                 # Remove the chosen point from the choice set when enforcing uniqueness.
+                 if unique:
+                     choices_batched = torch.cat([choices_batched[:best_idx], choices_batched[best_idx + 1 :]])
+
+             # Reset the acquisition function to its previous X_pending state.
+             acq_function.set_X_pending(base_X_pending)
+             return candidates, torch.stack(acq_value_list)
+
+         acq_values = self._split_batch_eval_acqf(acq_func=acq_function, X=choices_batched, max_batch_size=max_batch_size)
+         best_idx = torch.argmax(acq_values)
+         return choices_batched[best_idx], acq_values[best_idx]
+
+
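The helper above mirrors BoTorch's stock optimize_acqf_discrete. As a point of comparison, here is a minimal single-objective sketch using the stock function on a toy GP (all data and settings below are made up for illustration, assuming a recent BoTorch):

# Example (not part of the package): greedy discrete batch selection with BoTorch
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.optim import optimize_acqf_discrete
from botorch.sampling.normal import SobolQMCNormalSampler
from gpytorch.mlls import ExactMarginalLogLikelihood

torch.manual_seed(0)
train_X = torch.rand(10, 3, dtype=torch.double)   # toy training inputs
train_Y = train_X.sum(dim=-1, keepdim=True)       # toy objective values
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

sampler = SobolQMCNormalSampler(sample_shape=torch.Size([64]))
acqf = qExpectedImprovement(model=model, best_f=train_Y.max(), sampler=sampler)

choices = torch.rand(100, 3, dtype=torch.double)  # discrete candidate set
# Greedy sequential selection of q=3 points, the same loop re-implemented above.
candidates, acq_values = optimize_acqf_discrete(
    acq_function=acqf, q=3, choices=choices, max_batch_size=32, unique=True
)
print(candidates.shape, acq_values.shape)  # torch.Size([3, 3]) torch.Size([3])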
+ class EHVIAcquisitionFunction(BaseAcquisitionFunction):
+     """Noisy expected hypervolume improvement acquisition function (qLogNEHVI)."""
+
+     def __init__(
+         self,
+         model: ModelListGP,
+         sampler: SobolQMCNormalSampler,
+         ref_point: torch.Tensor,
+         partitioning,
+         train_x,
+     ):
+         super().__init__(model, sampler)
+         self.ref_point = ref_point
+         # Kept for API compatibility: qLogNEHVI builds its own box decomposition
+         # from X_baseline, so the partitioning is not passed through.
+         self.partitioning = partitioning
+         self.acquisition_function = qLogNoisyExpectedHypervolumeImprovement(
+             model=model,
+             sampler=sampler,
+             ref_point=ref_point,
+             alpha=0.0,
+             incremental_nehvi=True,
+             X_baseline=train_x,
+             prune_baseline=True,
+         )
+
+         self.console.log(f"EHVI reference point: {self.ref_point}")
+
+     def optimize_acqf_discrete(
+         self,
+         q: int,
+         choices: Tensor,
+         max_batch_size: int = 1024,
+         unique: bool = True,
+         progress: object = None,
+         task: object = None,
+         min_distance: float = 1e-6,
+         exclude_points: Tensor = None,
+     ) -> tuple[Tensor, Tensor]:
+         return self.optimize_discrete(
+             acq_func=self.acquisition_function,
+             q=q,
+             choices=choices,
+             max_batch_size=max_batch_size,
+             unique=unique,
+             progress=progress,
+             task=task,
+             min_distance=min_distance,
+             exclude_points=exclude_points,
+         )
+
+
+ class UCBAcquisitionFunction(BaseAcquisitionFunction):
+     """Upper confidence bound acquisition function (qUCB)."""
+
+     def __init__(
+         self,
+         model: ModelListGP,
+         sampler: SobolQMCNormalSampler,
+         beta: float = 2.0,
+         weights: Tensor = None,
+     ):
+         super().__init__(model, sampler)
+         self.beta = beta
+         objective = None
+         if weights is not None:
+             # Scalarize multi-output samples with a fixed weight vector.
+             objective = GenericMCObjective(lambda Z, X=None: Z @ weights)
+
+         self.acquisition_function = qUpperConfidenceBound(
+             model=model,
+             beta=beta,
+             sampler=sampler,
+             objective=objective,
+         )
+
+     def optimize_acqf_discrete(
+         self,
+         q: int,
+         choices: Tensor,
+         max_batch_size: int = 1024,
+         unique: bool = True,
+         maximum_metrics: bool = True,
+         progress: object = None,
+         task: object = None,
+         min_distance: float = 1e-6,
+         exclude_points: Tensor = None,
+     ) -> tuple[Tensor, Tensor]:
+         # Delegate to the shared discrete optimization logic in the base class.
+         return self.optimize_discrete(
+             acq_func=self.acquisition_function,
+             q=q,
+             choices=choices,
+             max_batch_size=max_batch_size,
+             unique=unique,
+             progress=progress,
+             task=task,
+             min_distance=min_distance,
+             exclude_points=exclude_points,
+         )
+
+
+ class ParEGOAcquisitionFunction(BaseAcquisitionFunction):
+     """
+     ParEGO: scalarize the objectives with a random Chebyshev weighting,
+     then apply expected improvement to the scalarized objective.
+     """
+
+     def __init__(
+         self,
+         model: ModelListGP,
+         sampler: SobolQMCNormalSampler,
+         X_baseline: Tensor,
+         num_objectives: int,
+     ):
+         super().__init__(model, sampler)
+
+         # Random positive weights on the simplex.
+         weights = torch.randn(num_objectives).abs()
+         weights /= weights.sum()
+
+         with torch.no_grad():
+             posterior = model.posterior(X_baseline)
+             Y_baseline = posterior.mean
+
+         weights = weights.to(Y_baseline.device)
+
+         objective = self._get_chebyshev_objective(weights=weights, Y=Y_baseline)
+
+         scalarized_Y = objective(Y_baseline)
+         best_f = scalarized_Y.max()
+
+         self.acquisition_function = qExpectedImprovement(
+             model=model,
+             best_f=best_f,
+             sampler=sampler,
+             objective=objective,
+         )
+
+     def optimize_acqf_discrete(
+         self,
+         q: int,
+         choices: Tensor,
+         max_batch_size: int = 1024,
+         unique: bool = True,
+         maximum_metrics: bool = True,
+         progress: object = None,
+         task: object = None,
+         min_distance: float = 1e-6,
+         exclude_points: Tensor = None,
+     ) -> tuple[Tensor, Tensor]:
+         return self.optimize_discrete(
+             acq_func=self.acquisition_function,
+             q=q,
+             choices=choices,
+             max_batch_size=max_batch_size,
+             unique=unique,
+             progress=progress,
+             task=task,
+             min_distance=min_distance,
+             exclude_points=exclude_points,
+         )
+
+     @staticmethod
+     def _get_chebyshev_objective(weights: Tensor, Y: Tensor) -> GenericMCObjective:
+         """Wrap a Chebyshev scalarization of Y as a Monte Carlo objective."""
+         scalarization_fn = get_chebyshev_scalarization(weights=weights, Y=Y)
+         return GenericMCObjective(scalarization_fn)
+
+
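A minimal standalone illustration of the scalarization above, on made-up data (not part of the package; assumes a recent BoTorch):

# Example: random Chebyshev weights over two objectives, as drawn in
# ParEGOAcquisitionFunction.__init__.
import torch
from botorch.acquisition.objective import GenericMCObjective
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization

torch.manual_seed(0)
Y_baseline = torch.rand(8, 2, dtype=torch.double)  # stand-in posterior means
weights = torch.randn(2, dtype=torch.double).abs()
weights /= weights.sum()

objective = GenericMCObjective(get_chebyshev_scalarization(weights=weights, Y=Y_baseline))
scalarized = objective(Y_baseline)  # shape (8,): one scalar per point
best_f = scalarized.max()           # incumbent value handed to qExpectedImprovement
print(best_f)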
+ class ParetoFrontCalculator:
+     """Class for calculating Pareto fronts."""
+
+     @staticmethod
+     def calculate_target_function(points: np.ndarray, progress: object, task: object) -> torch.Tensor:
+         """
+         Calculate the Pareto front for points in arbitrary dimensions,
+         assuming every dimension is to be maximized.
+
+         Args:
+             points: numpy array of shape (n_points, n_dimensions)
+             progress: rich Progress instance used to report per-point progress
+             task: task handle within the Progress instance
+
+         Returns:
+             torch.Tensor of Pareto-optimal points
+         """
+         if len(points) == 0:
+             return torch.empty(0, dtype=torch.double)
+         pareto_front = [points[0]]  # Initialize list of Pareto optimal points
+         for point in points[1:]:
+             progress.update(task, advance=1)
+             is_pareto = True
+             to_remove = []
+             # Compare with all points in the current Pareto front
+             for i, pf_point in enumerate(pareto_front):
+                 # Check if the current point dominates an existing Pareto point
+                 if np.all(point >= pf_point) and np.any(point > pf_point):
+                     to_remove.append(i)
+                 # Check if any existing Pareto point dominates the current point
+                 elif np.all(point <= pf_point) and np.any(point < pf_point):
+                     is_pareto = False
+                     break
+
+             # Remove dominated points from the Pareto front
+             for i in reversed(to_remove):
+                 pareto_front.pop(i)
+
+             # Add the current point if it is Pareto optimal
+             if is_pareto:
+                 pareto_front.append(point)
+         return torch.tensor(np.array(pareto_front))
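
The dominance rule above can be sanity-checked in isolation; a condensed variant on made-up 2-D data (not part of the package):

# Example: maximizing both dimensions, (3, 3) dominates (2, 1), while (3, 3)
# and (1, 4) are mutually non-dominated, so the front is [(3, 3), (1, 4)].
import numpy as np

points = np.array([[2.0, 1.0], [3.0, 3.0], [1.0, 4.0]])
front = [points[0]]
for p in points[1:]:
    # Drop existing front members dominated by p.
    dominated = [i for i, f in enumerate(front) if np.all(p >= f) and np.any(p > f)]
    for i in reversed(dominated):
        front.pop(i)
    # Keep p only if no remaining front member dominates it.
    if not any(np.all(f >= p) and np.any(f > p) for f in front):
        front.append(p)
print(np.array(front))  # [[3. 3.] [1. 4.]]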
@@ -0,0 +1,257 @@
+ import warnings
+ from typing import List, Tuple
+
+ import numpy as np
+ import torch
+ from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn, TimeRemainingColumn
+
+ from botorch.models import ModelListGP
+ from botorch.utils.multi_objective.box_decompositions import NondominatedPartitioning
+ from botorch.sampling.normal import SobolQMCNormalSampler
+ from linear_operator.utils.cholesky import NumericalWarning
+
+ from synbo.utils.util_func import compute_hvi
+ from synbo.utils.logger import console
+ from synbo.algorithm.sg_model import (
+     BNNEnsembleSurrogateModel,
+     BayesianLinearSurrogateModel,
+     GPSurrogateModel,
+     RFSurrogateModel,
+     SklearnModelWrapper,
+ )
+ from synbo.algorithm.acq_function import (
+     EHVIAcquisitionFunction,
+     ParEGOAcquisitionFunction,
+     UCBAcquisitionFunction,
+     ParetoFrontCalculator,
+ )
+
+ warnings.filterwarnings("ignore", category=NumericalWarning)
+
+
+ class DefaultBO:
+     def __init__(
+         self,
+         random_seed: int = 42,
+         surrogate_model: str = "GP",
+         acq_func: str = "EHVI",
+         device: torch.device = torch.device("cpu"),
+         accuracy: str = "medium",
+     ):
+         self.random_seed = random_seed
+         self.console = console
+
+         if accuracy == "medium":
+             self.mc_num_samples, self.max_batch_size = 256, 1024
+         else:
+             # Only "medium" is implemented; fail fast rather than leaving the
+             # sampler settings undefined.
+             raise ValueError(f"Unknown accuracy setting: {accuracy}")
+         self.device = device
+
+         if surrogate_model == "GP":
+             self.surrogate_model_class = GPSurrogateModel
+         elif surrogate_model == "RF":
+             self.surrogate_model_class = RFSurrogateModel
+         elif surrogate_model == "BNN":
+             self.surrogate_model_class = BNNEnsembleSurrogateModel
+         elif surrogate_model == "BayesianLinear":
+             self.surrogate_model_class = BayesianLinearSurrogateModel
+         else:
+             raise ValueError(f"Unknown surrogate model: {surrogate_model}")
+
+         if acq_func == "EHVI":
+             self.acquisition_function_class = EHVIAcquisitionFunction
+         elif acq_func == "UCB":
+             self.acquisition_function_class = UCBAcquisitionFunction
+         elif acq_func == "ParEGO":
+             self.acquisition_function_class = ParEGOAcquisitionFunction
+         else:
+             raise ValueError(f"Unknown acquisition function: {acq_func}")
+
+         self.target_evaluator = ParetoFrontCalculator()
+
+     def optimize(
+         self,
+         training_X: np.ndarray,
+         training_y: np.ndarray,
+         candidate_X: np.ndarray,
+         opt_metric_settings: List[dict],
+         batch_size: int,
+         total_name_arr: np.ndarray = None,
+         condition_types: List[str] = None,
+         total_desc_arr: np.ndarray = None,
+     ) -> Tuple[List[np.ndarray], List[str], np.ndarray, np.ndarray]:
+
+         training_X_t = torch.tensor(training_X).double().to(device=self.device)
+         training_y_t = torch.tensor(training_y).double()
+         training_y_t = self._weight_y(training_y_t, opt_metric_settings).to(device=self.device)
+         candidate_X_t = torch.tensor(candidate_X).double().to(device=self.device)
+
+         with Progress(
+             TextColumn("[bold cyan]{task.description}"),
+             BarColumn(bar_width=None),
+             MofNCompleteColumn(),
+             TimeRemainingColumn(),
+             console=self.console,
+         ) as progress:
+             num_models = training_y_t.shape[1]
+             task_train = progress.add_task(description="Training surrogate models", total=num_models, start=True)
+
+             models = []
+             for i in range(num_models):
+                 progress.log(f"Fitting model {i + 1}/{num_models}...", style="yellow")
+
+                 # Instantiate the model, passing random_seed where the class supports it.
+                 if self.surrogate_model_class in [RFSurrogateModel, BNNEnsembleSurrogateModel]:
+                     model_i = self.surrogate_model_class(device=self.device, num_dims=training_X_t.shape[1], random_seed=self.random_seed)
+                 else:
+                     model_i = self.surrogate_model_class(device=self.device, num_dims=training_X_t.shape[1])
+
+                 if isinstance(model_i, GPSurrogateModel):
+                     model_i.fit(training_X_t, training_y_t[:, i].unsqueeze(-1))
+                     models.append(model_i.model)
+                 else:
+                     wrapper = SklearnModelWrapper(model_i)
+                     wrapper.fit_surrogate(training_X_t, training_y_t[:, i].unsqueeze(-1))
+                     models.append(wrapper)
+
+                 progress.update(task_train, advance=1)
+
+             self.global_model = ModelListGP(*models)
+
+             task_pareto = progress.add_task(description="Calculating Pareto frontiers", total=len(training_y) - 1)
+             training_y_np = training_y_t.cpu().numpy()
+             self.pareto_y = self.target_evaluator.calculate_target_function(training_y_np, progress, task_pareto).to(device=self.device)
+
+             y_min = training_y_t.min(dim=0).values
+             y_max = training_y_t.max(dim=0).values
+             y_range = y_max - y_min
+
+             # Use the per-objective minimum as the reference point, nudged down
+             # when an objective is constant so the hypervolume stays well defined.
+             ref_point_values = []
+             for i in range(len(opt_metric_settings)):
+                 ref_val = y_min[i] if y_range[i] > 0 else y_min[i] - 0.1
+                 ref_point_values.append(ref_val)
+
+             self.ref_point = torch.tensor(ref_point_values, dtype=torch.double, device=self.device)
+
+             sampler = SobolQMCNormalSampler(sample_shape=torch.Size([self.mc_num_samples]), seed=self.random_seed)
+             if self.acquisition_function_class == EHVIAcquisitionFunction:
+                 partitioning = NondominatedPartitioning(ref_point=self.ref_point, Y=self.pareto_y)
+                 acq_func = self.acquisition_function_class(
+                     model=self.global_model,
+                     sampler=sampler,
+                     ref_point=self.ref_point,
+                     partitioning=partitioning,
+                     train_x=training_X_t,
+                 )
+             elif self.acquisition_function_class == UCBAcquisitionFunction:
+                 weights = torch.tensor([d["metric_weight"] for d in opt_metric_settings], dtype=torch.double, device=self.device)
+                 acq_func = self.acquisition_function_class(
+                     model=self.global_model,
+                     sampler=sampler,
+                     beta=2.0,
+                     weights=weights,
+                 )
+             elif self.acquisition_function_class == ParEGOAcquisitionFunction:
+                 acq_func = self.acquisition_function_class(
+                     model=self.global_model, sampler=sampler, X_baseline=training_X_t, num_objectives=len(opt_metric_settings)
+                 )
+             else:
+                 raise ValueError(f"Unknown acquisition function class: {self.acquisition_function_class}")
+
+             task_acq_opt = progress.add_task(description="Optimizing acquisition function", total=batch_size)
+             self.acq_result, self.acq_value = acq_func.optimize_acqf_discrete(
+                 q=batch_size,
+                 choices=candidate_X_t,
+                 max_batch_size=self.max_batch_size,
+                 unique=True,
+                 exclude_points=training_X_t,
+                 min_distance=1e-6,
+                 progress=progress,
+                 task=task_acq_opt,
+             )
+
+         best_samples = [res.cpu().numpy() for res in self.acq_result]
+
+         recommend_type = self._get_exploit_or_explore()
+         pred_mean, pred_std = self._get_predictions()
+
+         # Restore the sign of metrics optimized as "min" (objectives are maximized internally).
+         for i, d in enumerate(opt_metric_settings):
+             if d["opt_direct"] == "min":
+                 pred_mean[:, i] = -pred_mean[:, i]
+
+         pred_mean = self._unweight_y(torch.tensor(pred_mean), opt_metric_settings).numpy()
+         pred_std = self._unweight_y(torch.tensor(pred_std), opt_metric_settings).numpy()
+         return best_samples, recommend_type, pred_mean, pred_std
+
+     def _get_exploit_or_explore(self) -> List[str]:
+         with torch.no_grad():
+             posterior = self.global_model.posterior(self.acq_result)
+             pred_mean = posterior.mean
+
+         # posterior.mean may come back as (n_points, n_outputs) or with a
+         # leading batch dimension of 1; normalize it to 2-D.
+         if pred_mean.dim() == 3:
+             # Shape: (1, n_points, n_outputs) -> (n_points, n_outputs)
+             pred_mean = pred_mean.squeeze(0)
+         if pred_mean.dim() != 2:
+             raise ValueError(f"Unexpected pred_mean shape: {pred_mean.shape}")
+
+         # Ratio of realized HVI at the posterior mean to the expected HVI:
+         # a high ratio marks a recommendation as exploitation.
+         hvi_values = torch.tensor([compute_hvi(pred_mean[i], self.pareto_y, self.ref_point) for i in range(pred_mean.shape[0])])
+         ehvi_values = self.acq_value.to(device="cpu")
+         exploit_scores = hvi_values / (ehvi_values + 1e-6)
+         explore_scores = 1 - exploit_scores
+
+         for i in range(self.acq_result.shape[0]):
+             self.console.log(
+                 f"Point {i}: "
+                 f"EHVI = {ehvi_values[i]:.3f}, "
+                 f"HVI = {hvi_values[i]:.3f}, "
+                 f"Exploit Score = {exploit_scores[i]:.3f}, "
+                 f"Explore Score = {explore_scores[i]:.3f}"
+             )
+         return ["exploit" if exploit_scores[i] > explore_scores[i] else "explore" for i in range(self.acq_result.shape[0])]
+
+     def _get_predictions(self) -> Tuple[np.ndarray, np.ndarray]:
+         with torch.no_grad():
+             posterior = self.global_model.posterior(self.acq_result)
+             pred_mean = posterior.mean
+             pred_var = posterior.variance
+
+         # Normalize posterior shapes to (n_points, n_outputs), as above.
+         if pred_mean.dim() == 3:
+             # Shape: (1, n_points, n_outputs) -> (n_points, n_outputs)
+             pred_mean = pred_mean.squeeze(0)
+             pred_var = pred_var.squeeze(0)
+         if pred_mean.dim() != 2:
+             raise ValueError(f"Unexpected pred_mean shape: {pred_mean.shape}")
+
+         pred_mean = pred_mean.cpu().numpy()
+         pred_var = pred_var.cpu().numpy()
+         pred_std = np.sqrt(pred_var)
+         return pred_mean, pred_std
+
+     def _weight_y(self, training_y: torch.Tensor, opt_metric_settings: List[dict]) -> torch.Tensor:
+         """Scale each objective column by its configured metric weight."""
+         weights = torch.tensor([d["metric_weight"] for d in opt_metric_settings])
+         return training_y * weights
+
+     def _unweight_y(self, training_y: torch.Tensor, opt_metric_settings: List[dict]) -> torch.Tensor:
+         """Invert the scaling applied by _weight_y."""
+         weights = torch.tensor([d["metric_weight"] for d in opt_metric_settings])
+         return training_y / weights
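
A hedged end-to-end sketch of DefaultBO.optimize as defined above, on synthetic data. The diff does not reveal which module hosts DefaultBO, so the import path below is hypothetical; the opt_metric_settings keys ("metric_weight", "opt_direct") are taken from the code.

# Example (not part of the package); the import path is a guess.
import numpy as np
import torch
from synbo.algorithm.default_bo import DefaultBO  # HYPOTHETICAL module path

rng = np.random.default_rng(0)
training_X = rng.random((20, 5))    # 20 evaluated conditions, 5 descriptors
training_y = rng.random((20, 2))    # 2 objectives, assumed oriented for maximization
candidate_X = rng.random((200, 5))  # unevaluated candidate conditions

opt_metric_settings = [
    {"metric_weight": 1.0, "opt_direct": "max"},
    {"metric_weight": 0.5, "opt_direct": "max"},
]

bo = DefaultBO(random_seed=42, surrogate_model="GP", acq_func="EHVI",
               device=torch.device("cpu"), accuracy="medium")
best_samples, recommend_type, pred_mean, pred_std = bo.optimize(
    training_X=training_X,
    training_y=training_y,
    candidate_X=candidate_X,
    opt_metric_settings=opt_metric_settings,
    batch_size=4,
)
print(recommend_type)  # e.g. ["exploit", "explore", ...] for the 4 suggestions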