ennbo 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/enn/enn.py +71 -31
- enn/enn/enn_fit.py +26 -24
- enn/enn/enn_normal.py +3 -2
- enn/enn/enn_params.py +13 -0
- enn/enn/enn_util.py +40 -12
- enn/turbo/base_turbo_impl.py +53 -7
- enn/turbo/lhd_only_impl.py +7 -0
- enn/turbo/morbo_trust_region.py +189 -0
- enn/turbo/no_trust_region.py +65 -0
- enn/turbo/proposal.py +11 -2
- enn/turbo/turbo_config.py +48 -4
- enn/turbo/turbo_enn_impl.py +46 -21
- enn/turbo/turbo_gp.py +9 -1
- enn/turbo/turbo_mode_impl.py +11 -2
- enn/turbo/turbo_one_impl.py +163 -24
- enn/turbo/turbo_optimizer.py +246 -58
- enn/turbo/turbo_trust_region.py +8 -10
- enn/turbo/turbo_utils.py +116 -26
- enn/turbo/turbo_zero_impl.py +5 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.2.dist-info}/METADATA +5 -4
- ennbo-0.1.2.dist-info/RECORD +29 -0
- ennbo-0.1.0.dist-info/RECORD +0 -27
- {ennbo-0.1.0.dist-info → ennbo-0.1.2.dist-info}/WHEEL +0 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
import numpy as np
|
|
7
|
+
from numpy.random import Generator
|
|
8
|
+
from scipy.stats._qmc import QMCEngine
|
|
9
|
+
|
|
10
|
+
from .turbo_trust_region import TurboTrustRegion
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MorboTrustRegion:
    """Multi-objective trust region that scalarizes metrics before deferring to TuRBO.

    Wraps a single ``TurboTrustRegion`` and feeds it scalar scores produced by
    an augmented Chebyshev scalarization of min-max normalized metrics::

        score = min_j(w_j * z_j) + alpha * sum_j(w_j * z_j)

    The weight vector ``w`` is drawn once, at construction, from a Dirichlet
    distribution with all concentrations equal to 1; ``alpha`` is fixed at 0.05.
    Normalization ranges are the per-metric min/max of the observations most
    recently passed to ``update_xy``.
    """

    def __init__(
        self,
        num_dim: int,
        num_arms: int,
        num_metrics: int,
        *,
        rng: Generator,
    ) -> None:
        import numpy as np

        # The wrapped region is built before our own validation, so any checks
        # it performs on num_dim / num_arms run first.
        self._tr = TurboTrustRegion(num_dim=num_dim, num_arms=num_arms)
        self._num_dim = int(num_dim)
        self._num_arms = int(num_arms)
        self._num_metrics = int(num_metrics)
        if self._num_metrics < 1:
            raise ValueError(self._num_metrics)

        # Dirichlet(1, ..., 1): weights are uniform on the probability simplex.
        concentration = np.ones(self._num_metrics, dtype=float)
        self._weights = np.asarray(rng.dirichlet(concentration), dtype=float)
        self._alpha = 0.05

        # Normalization ranges; None until update_xy has seen observations.
        self._y_min: np.ndarray | Any | None = None
        self._y_max: np.ndarray | Any | None = None

    @property
    def num_dim(self) -> int:
        """Dimensionality of the search space."""
        return self._num_dim

    @property
    def num_arms(self) -> int:
        """Number of arms configured for the wrapped trust region."""
        return self._num_arms

    @property
    def num_metrics(self) -> int:
        """Number of objective columns expected in ``y``."""
        return self._num_metrics

    @property
    def weights(self) -> np.ndarray:
        """Scalarization weights (the internal array itself, not a copy)."""
        return self._weights

    @property
    def length(self) -> float:
        """Current side length reported by the wrapped trust region."""
        return float(self._tr.length)

    def update(self, values: np.ndarray | Any) -> None:
        """Unsupported: a scalar update lacks the per-metric data needed here."""
        raise NotImplementedError(
            "Use update_xy(x_obs, y_obs) with multi-objective observations."
        )

    def update_xy(
        self, x_obs: np.ndarray | Any, y_obs: np.ndarray | Any, *, k: Any = None
    ) -> None:  # noqa: ARG002
        """Refresh normalization ranges and advance the wrapped trust region.

        ``x_obs`` must be (n, num_dim) and ``y_obs`` (n, num_metrics). Both hold
        the full observation history. Rows beyond the wrapped region's
        ``prev_num_obs`` are the ones not yet seen.

        - Empty history: ranges are cleared and the wrapped region is restarted.
        - First call with data: every row is scalarized with the fresh ranges
          and passed to ``TurboTrustRegion.update``.
        - Later calls: every row is rescored against the ranges of the
          previously-seen rows only. The wrapped region's ``best_value`` is
          reset to the best rescored old row. The rescored values are then
          forwarded if any new rows arrived.

        ``k`` is accepted for interface compatibility and ignored.

        Raises:
            ValueError: on shape mismatches, or if the history got shorter.
            RuntimeError: if scalarization returns an unexpected shape, or if
                the wrapped region's ``best_value`` is not finite on a
                follow-up call.
        """
        import numpy as np

        xs = np.asarray(x_obs, dtype=float)
        ys = np.asarray(y_obs, dtype=float)

        if xs.ndim != 2 or xs.shape[1] != self._num_dim:
            raise ValueError(xs.shape)
        if ys.ndim != 2 or ys.shape[0] != xs.shape[0]:
            raise ValueError((xs.shape, ys.shape))
        if ys.shape[1] != self._num_metrics:
            raise ValueError((ys.shape, self._num_metrics))

        num_obs = int(xs.shape[0])
        if not num_obs:
            # No data at all: forget the ranges and reset the wrapped region.
            self._y_min = None
            self._y_max = None
            self._tr.restart()
            return

        seen = int(self._tr.prev_num_obs)
        if num_obs < seen:
            raise ValueError((num_obs, seen))

        # Ranges used by `scalarize` always span the full history.
        self._y_min = ys.min(axis=0)
        self._y_max = ys.max(axis=0)

        if seen == 0:
            fresh = np.asarray(self.scalarize(ys, clip=True), dtype=float)
            if fresh.shape != (num_obs,):
                raise RuntimeError((fresh.shape, num_obs))
            self._tr.update(fresh)
            return

        if not np.isfinite(self._tr.best_value):
            raise RuntimeError(self._tr.best_value)

        # Rescore everything against the ranges of the already-seen rows, so
        # the incumbent and the new points are compared on one scale.
        seen_rows = ys[:seen]
        rescored = np.asarray(
            self._scalarize_with_ranges(
                ys,
                y_min=seen_rows.min(axis=0),
                y_max=seen_rows.max(axis=0),
                clip=True,
            ),
            dtype=float,
        )
        if rescored.shape != (num_obs,):
            raise RuntimeError((rescored.shape, num_obs))

        self._tr.best_value = float(np.max(rescored[:seen]))
        if seen < num_obs:
            self._tr.update(rescored)

    def scalarize(self, y: np.ndarray | Any, *, clip: bool) -> np.ndarray:
        """Scalarize rows of ``y`` using the ranges from the latest ``update_xy``.

        Raises:
            ValueError: if ``y`` is not 2-D with ``num_metrics`` columns.
            RuntimeError: if no observations have been recorded yet.
        """
        import numpy as np

        arr = np.asarray(y, dtype=float)
        if arr.ndim != 2 or arr.shape[1] != self._num_metrics:
            raise ValueError(arr.shape)
        lo, hi = self._y_min, self._y_max
        if lo is None or hi is None:
            raise RuntimeError("scalarize called before any observations")
        return self._scalarize_with_ranges(arr, y_min=lo, y_max=hi, clip=clip)

    def _scalarize_with_ranges(
        self,
        y: np.ndarray | Any,
        *,
        y_min: np.ndarray,
        y_max: np.ndarray,
        clip: bool,
    ) -> np.ndarray:
        """Min-max normalize each metric, then apply augmented Chebyshev scoring.

        A metric whose range is empty or inverted (``y_max <= y_min``) is given
        the constant normalized value 0.5. With ``clip`` the normalized values
        are clamped to [0, 1] before weighting. Returns one score per row.

        Raises:
            ValueError: if ``y`` or the range vectors have the wrong shape.
        """
        import numpy as np

        arr = np.asarray(y, dtype=float)
        if arr.ndim != 2 or arr.shape[1] != self._num_metrics:
            raise ValueError(arr.shape)
        lo = np.asarray(y_min, dtype=float).reshape(-1)
        hi = np.asarray(y_max, dtype=float).reshape(-1)
        expected = (self._num_metrics,)
        if lo.shape != expected or hi.shape != expected:
            raise ValueError((lo.shape, hi.shape, self._num_metrics))

        span = hi - lo
        flat = span <= 0.0
        # Divide by 1.0 on flat metrics to avoid 0/0, then overwrite them.
        scaled = (arr - lo[np.newaxis, :]) / np.where(flat, 1.0, span)[np.newaxis, :]
        normalized = np.where(flat, 0.5, scaled)
        if clip:
            normalized = np.clip(normalized, 0.0, 1.0)

        weighted = normalized * self._weights.reshape(1, -1)
        return weighted.min(axis=1) + self._alpha * weighted.sum(axis=1)

    def needs_restart(self) -> bool:
        """Delegate to the wrapped trust region."""
        return self._tr.needs_restart()

    def restart(self) -> None:
        """Clear normalization ranges and restart the wrapped trust region."""
        self._y_min = self._y_max = None
        self._tr.restart()

    def validate_request(self, num_arms: int, *, is_fallback: bool = False) -> None:
        """Delegate arm-count validation to the wrapped trust region."""
        return self._tr.validate_request(num_arms, is_fallback=is_fallback)

    def compute_bounds_1d(
        self, x_center: np.ndarray | Any, lengthscales: np.ndarray | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Delegate bound computation to the wrapped trust region."""
        return self._tr.compute_bounds_1d(x_center, lengthscales)

    def generate_candidates(
        self,
        x_center: np.ndarray,
        lengthscales: np.ndarray | None,
        num_candidates: int,
        rng: Generator,
        sobol_engine: QMCEngine,
    ) -> np.ndarray:
        """Delegate candidate generation to the wrapped trust region."""
        return self._tr.generate_candidates(
            x_center, lengthscales, num_candidates, rng, sobol_engine
        )
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy.random import Generator
|
|
9
|
+
from scipy.stats._qmc import QMCEngine
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class NoTrustRegion:
    """Trust-region stand-in that never shrinks, grows, or restarts.

    It exposes the same methods as the other trust-region classes, but its
    bounds always span the whole unit cube and it keeps no state beyond its
    configuration fields. Presumably selected for ``tr_type="none"``; verify
    against the caller.
    """

    num_dim: int
    num_arms: int
    length: float = 1.0

    def update(self, values: np.ndarray | Any) -> None:
        """Ignore observed values; there is no region to resize."""
        return None

    def needs_restart(self) -> bool:
        """Never request a restart."""
        return False

    def restart(self) -> None:
        """No-op: there is no state to reset."""
        return None

    def validate_request(self, num_arms: int, *, is_fallback: bool = False) -> None:
        """Check a requested arm count against the configured ``num_arms``.

        A fallback request may ask for at most the configured number of arms.
        A regular request must match it exactly.

        Raises:
            ValueError: if the request violates the rule above.
        """
        configured = self.num_arms
        if is_fallback:
            if num_arms > configured:
                raise ValueError(
                    f"num_arms {num_arms} > configured num_arms {configured}"
                )
        elif num_arms != configured:
            raise ValueError(
                f"num_arms {num_arms} != configured num_arms {configured}"
            )

    def compute_bounds_1d(
        self, x_center: np.ndarray | Any, lengthscales: np.ndarray | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return float bounds of 0 and 1 shaped like ``x_center``.

        ``lengthscales`` is accepted for interface compatibility and ignored.
        """
        import numpy as np

        return (
            np.zeros_like(x_center, dtype=float),
            np.ones_like(x_center, dtype=float),
        )

    def generate_candidates(
        self,
        x_center: np.ndarray,
        lengthscales: np.ndarray | None,
        num_candidates: int,
        rng: Generator,
        sobol_engine: QMCEngine,
    ) -> np.ndarray:
        """Generate candidates via the shared helper using full-cube bounds."""
        from .turbo_utils import generate_trust_region_candidates

        return generate_trust_region_candidates(
            x_center,
            lengthscales,
            num_candidates,
            compute_bounds_1d=self.compute_bounds_1d,
            rng=rng,
            sobol_engine=sobol_engine,
        )
|
enn/turbo/proposal.py
CHANGED
|
@@ -22,6 +22,7 @@ def mk_enn(
|
|
|
22
22
|
k: int,
|
|
23
23
|
num_fit_samples: int | None = None,
|
|
24
24
|
num_fit_candidates: int | None = None,
|
|
25
|
+
scale_x: bool = False,
|
|
25
26
|
rng: Generator | Any | None = None,
|
|
26
27
|
params_warm_start: ENNParams | Any | None = None,
|
|
27
28
|
) -> tuple[EpistemicNearestNeighbors | None, ENNParams | None]:
|
|
@@ -36,10 +37,17 @@ def mk_enn(
|
|
|
36
37
|
if y_obs_array.size == 0:
|
|
37
38
|
return None, None
|
|
38
39
|
|
|
39
|
-
|
|
40
|
+
# Preserve multi-metric shape if present, otherwise reshape to (n, 1)
|
|
41
|
+
if y_obs_array.ndim == 1:
|
|
42
|
+
y = y_obs_array.reshape(-1, 1)
|
|
43
|
+
else:
|
|
44
|
+
y = y_obs_array
|
|
40
45
|
if yvar_obs_list is not None and len(yvar_obs_list) > 0:
|
|
41
46
|
yvar_array = np.asarray(yvar_obs_list, dtype=float)
|
|
42
|
-
|
|
47
|
+
if yvar_array.ndim == 1:
|
|
48
|
+
yvar = yvar_array.reshape(-1, 1)
|
|
49
|
+
else:
|
|
50
|
+
yvar = yvar_array
|
|
43
51
|
else:
|
|
44
52
|
yvar = None
|
|
45
53
|
x_obs_array = np.asarray(x_obs_list, dtype=float)
|
|
@@ -47,6 +55,7 @@ def mk_enn(
|
|
|
47
55
|
x_obs_array,
|
|
48
56
|
y,
|
|
49
57
|
yvar,
|
|
58
|
+
scale_x=scale_x,
|
|
50
59
|
)
|
|
51
60
|
if len(enn_model) == 0:
|
|
52
61
|
return None, None
|
enn/turbo/turbo_config.py
CHANGED
|
@@ -9,20 +9,64 @@ class TurboConfig:
|
|
|
9
9
|
k: int | None = None
|
|
10
10
|
num_candidates: int | None = None
|
|
11
11
|
num_init: int | None = None
|
|
12
|
-
var_scale: float = 1.0
|
|
13
12
|
|
|
14
13
|
# Experimental
|
|
15
14
|
trailing_obs: int | None = None
|
|
15
|
+
tr_type: Literal["turbo", "morbo", "none"] = "turbo"
|
|
16
|
+
num_metrics: int | None = None
|
|
17
|
+
|
|
18
|
+
def __post_init__(self) -> None:
|
|
19
|
+
if self.tr_type not in ["turbo", "morbo", "none"]:
|
|
20
|
+
raise ValueError(
|
|
21
|
+
f"tr_type must be 'turbo', 'morbo', or 'none', got {self.tr_type!r}"
|
|
22
|
+
)
|
|
23
|
+
if self.num_metrics is not None and self.num_metrics < 1:
|
|
24
|
+
raise ValueError(f"num_metrics must be >= 1, got {self.num_metrics}")
|
|
25
|
+
if self.tr_type == "turbo":
|
|
26
|
+
if self.num_metrics is not None and self.num_metrics != 1:
|
|
27
|
+
raise ValueError(
|
|
28
|
+
f"num_metrics must be 1 for tr_type='turbo', got {self.num_metrics}"
|
|
29
|
+
)
|
|
30
|
+
if self.tr_type == "none":
|
|
31
|
+
if self.num_metrics is not None and self.num_metrics != 1:
|
|
32
|
+
raise ValueError(
|
|
33
|
+
f"num_metrics must be 1 for tr_type='none', got {self.num_metrics}"
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True)
class TurboOneConfig(TurboConfig):
    """Configuration type that adds no fields or validation beyond TurboConfig.

    The distinct type presumably lets the optimizer select the TuRBO-1
    implementation; verify against the caller.
    """
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
class TurboZeroConfig(TurboConfig):
    """Configuration type that adds no fields or validation beyond TurboConfig.

    The distinct type presumably lets the optimizer select the TuRBO-0
    implementation; verify against the caller.
    """
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True)
class LHDOnlyConfig(TurboConfig):
    """Configuration type that adds no fields or validation beyond TurboConfig.

    The distinct type presumably lets the optimizer select the LHD-only
    (space-filling design) implementation; verify against the caller.
    """
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True)
|
|
53
|
+
class TurboENNConfig(TurboConfig):
|
|
54
|
+
acq_type: Literal["thompson", "pareto", "ucb"] = "pareto"
|
|
16
55
|
num_fit_samples: int | None = None
|
|
17
56
|
num_fit_candidates: int | None = None
|
|
18
|
-
|
|
19
|
-
local_only: bool = False
|
|
57
|
+
scale_x: bool = False
|
|
20
58
|
|
|
21
59
|
def __post_init__(self) -> None:
|
|
60
|
+
super().__post_init__()
|
|
22
61
|
if self.acq_type not in ["thompson", "pareto", "ucb"]:
|
|
23
62
|
raise ValueError(
|
|
24
63
|
f"acq_type must be 'thompson', 'pareto', or 'ucb', got {self.acq_type!r}"
|
|
25
64
|
)
|
|
26
|
-
# Pareto acquisition is the only type that works well without hyperparameter fitting
|
|
27
65
|
if self.num_fit_samples is None and self.acq_type != "pareto":
|
|
28
66
|
raise ValueError(f"num_fit_samples required for acq_type={self.acq_type!r}")
|
|
67
|
+
if self.num_fit_samples is not None and int(self.num_fit_samples) <= 0:
|
|
68
|
+
raise ValueError(f"num_fit_samples must be > 0, got {self.num_fit_samples}")
|
|
69
|
+
if self.num_fit_candidates is not None and int(self.num_fit_candidates) <= 0:
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f"num_fit_candidates must be > 0, got {self.num_fit_candidates}"
|
|
72
|
+
)
|
enn/turbo/turbo_enn_impl.py
CHANGED
|
@@ -7,11 +7,11 @@ if TYPE_CHECKING:
|
|
|
7
7
|
from numpy.random import Generator
|
|
8
8
|
|
|
9
9
|
from .base_turbo_impl import BaseTurboImpl
|
|
10
|
-
from .turbo_config import
|
|
10
|
+
from .turbo_config import TurboENNConfig
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class TurboENNImpl(BaseTurboImpl):
|
|
14
|
-
def __init__(self, config:
|
|
14
|
+
def __init__(self, config: TurboENNConfig) -> None:
|
|
15
15
|
super().__init__(config)
|
|
16
16
|
self._enn: Any | None = None
|
|
17
17
|
self._fitted_params: Any | None = None
|
|
@@ -22,6 +22,7 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
22
22
|
x_obs_list: list,
|
|
23
23
|
y_obs_list: list,
|
|
24
24
|
rng: Generator,
|
|
25
|
+
tr_state: Any = None,
|
|
25
26
|
) -> np.ndarray | None:
|
|
26
27
|
import numpy as np
|
|
27
28
|
|
|
@@ -30,7 +31,7 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
30
31
|
if len(y_obs_list) == 0:
|
|
31
32
|
return None
|
|
32
33
|
if self._enn is None or self._fitted_params is None:
|
|
33
|
-
return super().get_x_center(x_obs_list, y_obs_list, rng)
|
|
34
|
+
return super().get_x_center(x_obs_list, y_obs_list, rng, tr_state)
|
|
34
35
|
if self._fitted_n_obs != len(x_obs_list):
|
|
35
36
|
raise RuntimeError(
|
|
36
37
|
f"ENN fitted on {self._fitted_n_obs} obs but get_x_center called with {len(x_obs_list)}"
|
|
@@ -38,17 +39,40 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
38
39
|
|
|
39
40
|
y_array = np.asarray(y_obs_list, dtype=float)
|
|
40
41
|
x_array = np.asarray(x_obs_list, dtype=float)
|
|
41
|
-
|
|
42
42
|
k = self._config.k if self._config.k is not None else 10
|
|
43
|
-
num_top = min(k, len(y_array))
|
|
44
|
-
top_indices = np.argpartition(-y_array, num_top - 1)[:num_top]
|
|
45
43
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
44
|
+
# For morbo: top-k per metric → union → scalarize mu
|
|
45
|
+
if self._config.tr_type == "morbo" and tr_state is not None:
|
|
46
|
+
if y_array.ndim == 1:
|
|
47
|
+
y_array = y_array.reshape(-1, tr_state.num_metrics)
|
|
48
|
+
num_metrics = y_array.shape[1]
|
|
49
|
+
|
|
50
|
+
# Find top-k indices for each metric and take union
|
|
51
|
+
union_indices = set()
|
|
52
|
+
for m in range(num_metrics):
|
|
53
|
+
num_top = min(k, len(y_array))
|
|
54
|
+
top_m = np.argpartition(-y_array[:, m], num_top - 1)[:num_top]
|
|
55
|
+
union_indices.update(top_m.tolist())
|
|
56
|
+
union_indices = np.array(sorted(union_indices), dtype=int)
|
|
57
|
+
|
|
58
|
+
x_union = x_array[union_indices]
|
|
59
|
+
posterior = self._enn.posterior(x_union, params=self._fitted_params)
|
|
60
|
+
mu = posterior.mu # (len(union), num_metrics)
|
|
61
|
+
|
|
62
|
+
scalarized = tr_state.scalarize(mu, clip=False)
|
|
63
|
+
best_idx_in_union = argmax_random_tie(scalarized, rng=rng)
|
|
64
|
+
return x_union[best_idx_in_union]
|
|
65
|
+
else:
|
|
66
|
+
# Single-objective: original logic
|
|
67
|
+
num_top = min(k, len(y_array))
|
|
68
|
+
top_indices = np.argpartition(-y_array, num_top - 1)[:num_top]
|
|
69
|
+
|
|
70
|
+
x_top = x_array[top_indices]
|
|
71
|
+
posterior = self._enn.posterior(x_top, params=self._fitted_params)
|
|
72
|
+
mu = posterior.mu[:, 0]
|
|
49
73
|
|
|
50
|
-
|
|
51
|
-
|
|
74
|
+
best_idx_in_top = argmax_random_tie(mu, rng=rng)
|
|
75
|
+
return x_top[best_idx_in_top]
|
|
52
76
|
|
|
53
77
|
def needs_tr_list(self) -> bool:
|
|
54
78
|
return True
|
|
@@ -85,6 +109,7 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
85
109
|
k=k,
|
|
86
110
|
num_fit_samples=self._config.num_fit_samples,
|
|
87
111
|
num_fit_candidates=self._config.num_fit_candidates,
|
|
112
|
+
scale_x=self._config.scale_x,
|
|
88
113
|
rng=rng,
|
|
89
114
|
params_warm_start=self._fitted_params,
|
|
90
115
|
)
|
|
@@ -99,6 +124,7 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
99
124
|
rng: Generator,
|
|
100
125
|
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
101
126
|
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
127
|
+
tr_state: Any = None, # noqa: ARG002
|
|
102
128
|
) -> np.ndarray:
|
|
103
129
|
import numpy as np
|
|
104
130
|
|
|
@@ -106,7 +132,6 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
106
132
|
|
|
107
133
|
acq_type = self._config.acq_type
|
|
108
134
|
k = self._config.k
|
|
109
|
-
var_scale = self._config.var_scale
|
|
110
135
|
|
|
111
136
|
if self._enn is None:
|
|
112
137
|
return fallback_fn(x_cand, num_arms)
|
|
@@ -115,9 +140,7 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
115
140
|
params = self._fitted_params
|
|
116
141
|
else:
|
|
117
142
|
k_val = k if k is not None else 10
|
|
118
|
-
params = ENNParams(
|
|
119
|
-
k=k_val, epi_var_scale=var_scale, ale_homoscedastic_scale=0.0
|
|
120
|
-
)
|
|
143
|
+
params = ENNParams(k=k_val, epi_var_scale=1.0, ale_homoscedastic_scale=0.0)
|
|
121
144
|
|
|
122
145
|
posterior = self._enn.posterior(x_cand, params=params)
|
|
123
146
|
mu = posterior.mu[:, 0]
|
|
@@ -155,6 +178,9 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
155
178
|
if self._enn is None or self._fitted_params is None:
|
|
156
179
|
return y_observed
|
|
157
180
|
posterior = self._enn.posterior(x_unit, params=self._fitted_params)
|
|
181
|
+
# For multi-metric (morbo), return full mu; for single-metric, return 1D
|
|
182
|
+
if posterior.mu.shape[1] > 1:
|
|
183
|
+
return posterior.mu
|
|
158
184
|
return posterior.mu[:, 0]
|
|
159
185
|
|
|
160
186
|
def get_mu_sigma(self, x_unit: np.ndarray) -> tuple[np.ndarray, np.ndarray] | None:
|
|
@@ -166,11 +192,10 @@ class TurboENNImpl(BaseTurboImpl):
|
|
|
166
192
|
params = (
|
|
167
193
|
self._fitted_params
|
|
168
194
|
if self._fitted_params is not None
|
|
169
|
-
else ENNParams(
|
|
170
|
-
k=k, epi_var_scale=self._config.var_scale, ale_homoscedastic_scale=0.0
|
|
171
|
-
)
|
|
195
|
+
else ENNParams(k=k, epi_var_scale=1.0, ale_homoscedastic_scale=0.0)
|
|
172
196
|
)
|
|
173
197
|
posterior = self._enn.posterior(x_unit, params=params, observation_noise=False)
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
198
|
+
# For multi-metric (morbo), return full mu/sigma; for single-metric, return 1D
|
|
199
|
+
if posterior.mu.shape[1] > 1:
|
|
200
|
+
return posterior.mu, posterior.se
|
|
201
|
+
return posterior.mu[:, 0], posterior.se[:, 0]
|
enn/turbo/turbo_gp.py
CHANGED
|
@@ -13,17 +13,25 @@ class TurboGP(TurboGPBase):
|
|
|
13
13
|
outputscale_constraint,
|
|
14
14
|
ard_dims: int,
|
|
15
15
|
) -> None:
|
|
16
|
+
import torch
|
|
16
17
|
from gpytorch.kernels import MaternKernel, ScaleKernel
|
|
17
18
|
from gpytorch.means import ConstantMean
|
|
18
19
|
|
|
19
20
|
super().__init__(train_x, train_y, likelihood)
|
|
20
|
-
|
|
21
|
+
batch_shape = (
|
|
22
|
+
torch.Size(train_y.shape[:-1])
|
|
23
|
+
if getattr(train_y, "ndim", 0) > 1
|
|
24
|
+
else torch.Size()
|
|
25
|
+
)
|
|
26
|
+
self.mean_module = ConstantMean(batch_shape=batch_shape)
|
|
21
27
|
base_kernel = MaternKernel(
|
|
22
28
|
nu=2.5,
|
|
23
29
|
ard_num_dims=ard_dims,
|
|
30
|
+
batch_shape=batch_shape,
|
|
24
31
|
lengthscale_constraint=lengthscale_constraint,
|
|
25
32
|
)
|
|
26
33
|
self.covar_module = ScaleKernel(
|
|
27
34
|
base_kernel,
|
|
35
|
+
batch_shape=batch_shape,
|
|
28
36
|
outputscale_constraint=outputscale_constraint,
|
|
29
37
|
)
|
enn/turbo/turbo_mode_impl.py
CHANGED
|
@@ -13,18 +13,25 @@ class TurboModeImpl(Protocol):
|
|
|
13
13
|
x_obs_list: list,
|
|
14
14
|
y_obs_list: list,
|
|
15
15
|
rng: Generator,
|
|
16
|
+
tr_state: Any = None,
|
|
16
17
|
) -> np.ndarray | None: ...
|
|
17
18
|
|
|
18
19
|
def needs_tr_list(self) -> bool: ...
|
|
19
20
|
|
|
20
|
-
def create_trust_region(
|
|
21
|
+
def create_trust_region(
|
|
22
|
+
self,
|
|
23
|
+
num_dim: int,
|
|
24
|
+
num_arms: int,
|
|
25
|
+
rng: Generator,
|
|
26
|
+
num_metrics: int | None = None,
|
|
27
|
+
) -> Any: ...
|
|
21
28
|
|
|
22
29
|
def try_early_ask(
|
|
23
30
|
self,
|
|
24
31
|
num_arms: int,
|
|
25
32
|
x_obs_list: list,
|
|
26
33
|
draw_initial_fn: Callable[[int], np.ndarray],
|
|
27
|
-
get_init_lhd_points_fn: Callable[[int], np.ndarray
|
|
34
|
+
get_init_lhd_points_fn: Callable[[int], np.ndarray],
|
|
28
35
|
) -> np.ndarray | None: ...
|
|
29
36
|
|
|
30
37
|
def handle_restart(
|
|
@@ -54,11 +61,13 @@ class TurboModeImpl(Protocol):
|
|
|
54
61
|
rng: Generator,
|
|
55
62
|
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
56
63
|
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
64
|
+
tr_state: Any = None,
|
|
57
65
|
) -> np.ndarray: ...
|
|
58
66
|
|
|
59
67
|
def update_trust_region(
|
|
60
68
|
self,
|
|
61
69
|
tr_state: Any,
|
|
70
|
+
x_obs_list: list,
|
|
62
71
|
y_obs_list: list,
|
|
63
72
|
x_center: np.ndarray | None = None,
|
|
64
73
|
k: int | None = None,
|