ennbo 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/enn/enn.py +71 -31
- enn/enn/enn_fit.py +26 -24
- enn/enn/enn_normal.py +3 -2
- enn/enn/enn_params.py +13 -0
- enn/enn/enn_util.py +40 -12
- enn/turbo/base_turbo_impl.py +53 -7
- enn/turbo/lhd_only_impl.py +7 -0
- enn/turbo/morbo_trust_region.py +189 -0
- enn/turbo/no_trust_region.py +65 -0
- enn/turbo/proposal.py +11 -2
- enn/turbo/turbo_config.py +48 -4
- enn/turbo/turbo_enn_impl.py +46 -21
- enn/turbo/turbo_gp.py +9 -1
- enn/turbo/turbo_mode_impl.py +11 -2
- enn/turbo/turbo_one_impl.py +163 -24
- enn/turbo/turbo_optimizer.py +246 -58
- enn/turbo/turbo_trust_region.py +8 -10
- enn/turbo/turbo_utils.py +116 -26
- enn/turbo/turbo_zero_impl.py +5 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.2.dist-info}/METADATA +5 -4
- ennbo-0.1.2.dist-info/RECORD +29 -0
- ennbo-0.1.0.dist-info/RECORD +0 -27
- {ennbo-0.1.0.dist-info → ennbo-0.1.2.dist-info}/WHEEL +0 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.2.dist-info}/licenses/LICENSE +0 -0
enn/enn/enn.py
CHANGED
|
@@ -12,12 +12,16 @@ if TYPE_CHECKING:
|
|
|
12
12
|
class EpistemicNearestNeighbors:
|
|
13
13
|
def __init__(
|
|
14
14
|
self,
|
|
15
|
-
train_x: np.ndarray
|
|
16
|
-
train_y: np.ndarray
|
|
17
|
-
train_yvar: np.ndarray |
|
|
15
|
+
train_x: np.ndarray,
|
|
16
|
+
train_y: np.ndarray,
|
|
17
|
+
train_yvar: np.ndarray | None = None,
|
|
18
|
+
*,
|
|
19
|
+
scale_x: bool = False,
|
|
18
20
|
) -> None:
|
|
19
21
|
import numpy as np
|
|
20
22
|
|
|
23
|
+
train_x = np.asarray(train_x, dtype=float)
|
|
24
|
+
train_y = np.asarray(train_y, dtype=float)
|
|
21
25
|
if train_x.ndim != 2:
|
|
22
26
|
raise ValueError(train_x.shape)
|
|
23
27
|
if train_y.ndim != 2:
|
|
@@ -25,22 +29,41 @@ class EpistemicNearestNeighbors:
|
|
|
25
29
|
if train_x.shape[0] != train_y.shape[0]:
|
|
26
30
|
raise ValueError((train_x.shape, train_y.shape))
|
|
27
31
|
if train_yvar is not None:
|
|
32
|
+
train_yvar = np.asarray(train_yvar, dtype=float)
|
|
28
33
|
if train_yvar.ndim != 2:
|
|
29
34
|
raise ValueError(train_yvar.shape)
|
|
30
35
|
if train_y.shape != train_yvar.shape:
|
|
31
36
|
raise ValueError((train_y.shape, train_yvar.shape))
|
|
32
|
-
|
|
33
|
-
self.
|
|
34
|
-
self.
|
|
35
|
-
|
|
36
|
-
)
|
|
37
|
+
|
|
38
|
+
self._train_x = train_x
|
|
39
|
+
self._train_y = train_y
|
|
40
|
+
self._train_yvar = train_yvar
|
|
37
41
|
self._num_obs, self._num_dim = self._train_x.shape
|
|
38
42
|
_, self._num_metrics = self._train_y.shape
|
|
39
43
|
self._eps_var = 1e-9
|
|
44
|
+
self._scale_x = bool(scale_x)
|
|
45
|
+
if self._scale_x:
|
|
46
|
+
if len(self._train_x) < 2:
|
|
47
|
+
x_scale = np.ones((1, self._num_dim), dtype=float)
|
|
48
|
+
else:
|
|
49
|
+
x_scale = np.std(self._train_x, axis=0, keepdims=True).astype(float)
|
|
50
|
+
x_scale = np.where(
|
|
51
|
+
np.isfinite(x_scale) & (x_scale > 1e-12),
|
|
52
|
+
x_scale,
|
|
53
|
+
1.0,
|
|
54
|
+
)
|
|
55
|
+
self._x_scale = x_scale
|
|
56
|
+
self._train_x_scaled = self._train_x / self._x_scale
|
|
57
|
+
else:
|
|
58
|
+
self._x_scale = np.ones((1, self._num_dim), dtype=float)
|
|
59
|
+
self._train_x_scaled = self._train_x
|
|
40
60
|
if len(self._train_y) < 2:
|
|
41
61
|
self._y_scale = np.ones(shape=(1, self._num_metrics), dtype=float)
|
|
42
62
|
else:
|
|
43
|
-
|
|
63
|
+
y_scale = np.std(self._train_y, axis=0, keepdims=True).astype(float)
|
|
64
|
+
self._y_scale = np.where(
|
|
65
|
+
np.isfinite(y_scale) & (y_scale > 0.0), y_scale, 1.0
|
|
66
|
+
)
|
|
44
67
|
|
|
45
68
|
self._index: Any | None = None
|
|
46
69
|
self._build_index()
|
|
@@ -70,14 +93,42 @@ class EpistemicNearestNeighbors:
|
|
|
70
93
|
|
|
71
94
|
if self._num_obs == 0:
|
|
72
95
|
return
|
|
73
|
-
x_f32 = self.
|
|
96
|
+
x_f32 = self._train_x_scaled.astype(np.float32, copy=False)
|
|
74
97
|
index = faiss.IndexFlatL2(self._num_dim)
|
|
75
98
|
index.add(x_f32)
|
|
76
99
|
self._index = index
|
|
77
100
|
|
|
101
|
+
def _search_index(
|
|
102
|
+
self,
|
|
103
|
+
x: np.ndarray,
|
|
104
|
+
*,
|
|
105
|
+
search_k: int,
|
|
106
|
+
exclude_nearest: bool,
|
|
107
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
108
|
+
import numpy as np
|
|
109
|
+
|
|
110
|
+
search_k = int(search_k)
|
|
111
|
+
if search_k <= 0:
|
|
112
|
+
raise ValueError(search_k)
|
|
113
|
+
x = np.asarray(x, dtype=float)
|
|
114
|
+
if x.ndim != 2 or x.shape[1] != self._num_dim:
|
|
115
|
+
raise ValueError(x.shape)
|
|
116
|
+
if self._index is None:
|
|
117
|
+
raise RuntimeError("index is not initialized")
|
|
118
|
+
|
|
119
|
+
x_scaled = x / self._x_scale if self._scale_x else x
|
|
120
|
+
x_f32 = x_scaled.astype(np.float32, copy=False)
|
|
121
|
+
dist2s_full, idx_full = self._index.search(x_f32, search_k)
|
|
122
|
+
dist2s_full = dist2s_full.astype(float)
|
|
123
|
+
idx_full = idx_full.astype(int)
|
|
124
|
+
if exclude_nearest:
|
|
125
|
+
dist2s_full = dist2s_full[:, 1:]
|
|
126
|
+
idx_full = idx_full[:, 1:]
|
|
127
|
+
return dist2s_full, idx_full
|
|
128
|
+
|
|
78
129
|
def posterior(
|
|
79
130
|
self,
|
|
80
|
-
x: np.ndarray
|
|
131
|
+
x: np.ndarray,
|
|
81
132
|
*,
|
|
82
133
|
params: ENNParams,
|
|
83
134
|
exclude_nearest: bool = False,
|
|
@@ -97,7 +148,7 @@ class EpistemicNearestNeighbors:
|
|
|
97
148
|
|
|
98
149
|
def batch_posterior(
|
|
99
150
|
self,
|
|
100
|
-
x: np.ndarray
|
|
151
|
+
x: np.ndarray,
|
|
101
152
|
paramss: list[ENNParams],
|
|
102
153
|
*,
|
|
103
154
|
exclude_nearest: bool = False,
|
|
@@ -107,6 +158,7 @@ class EpistemicNearestNeighbors:
|
|
|
107
158
|
|
|
108
159
|
from .enn_normal import ENNNormal
|
|
109
160
|
|
|
161
|
+
x = np.asarray(x, dtype=float)
|
|
110
162
|
if x.ndim != 2:
|
|
111
163
|
raise ValueError(x.shape)
|
|
112
164
|
if x.shape[1] != self._num_dim:
|
|
@@ -126,15 +178,9 @@ class EpistemicNearestNeighbors:
|
|
|
126
178
|
search_k = int(min(max_k + 1, len(self)))
|
|
127
179
|
else:
|
|
128
180
|
search_k = int(min(max_k, len(self)))
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
dist2s_full, idx_full = self._index.search(x_f32, search_k)
|
|
133
|
-
dist2s_full = dist2s_full.astype(float)
|
|
134
|
-
idx_full = idx_full.astype(int)
|
|
135
|
-
if exclude_nearest:
|
|
136
|
-
dist2s_full = dist2s_full[:, 1:]
|
|
137
|
-
idx_full = idx_full[:, 1:]
|
|
181
|
+
dist2s_full, idx_full = self._search_index(
|
|
182
|
+
x, search_k=search_k, exclude_nearest=exclude_nearest
|
|
183
|
+
)
|
|
138
184
|
mu_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
|
|
139
185
|
se_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
|
|
140
186
|
available_k = search_k - 1 if exclude_nearest else search_k
|
|
@@ -178,7 +224,7 @@ class EpistemicNearestNeighbors:
|
|
|
178
224
|
|
|
179
225
|
def neighbors(
|
|
180
226
|
self,
|
|
181
|
-
x: np.ndarray
|
|
227
|
+
x: np.ndarray,
|
|
182
228
|
k: int,
|
|
183
229
|
*,
|
|
184
230
|
exclude_nearest: bool = False,
|
|
@@ -210,15 +256,9 @@ class EpistemicNearestNeighbors:
|
|
|
210
256
|
search_k = int(min(k, len(self)))
|
|
211
257
|
if search_k == 0:
|
|
212
258
|
return []
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
dist2s_full, idx_full = self._index.search(x_f32, search_k)
|
|
217
|
-
dist2s_full = dist2s_full.astype(float)
|
|
218
|
-
idx_full = idx_full.astype(int)
|
|
219
|
-
if exclude_nearest:
|
|
220
|
-
dist2s_full = dist2s_full[:, 1:]
|
|
221
|
-
idx_full = idx_full[:, 1:]
|
|
259
|
+
dist2s_full, idx_full = self._search_index(
|
|
260
|
+
x, search_k=search_k, exclude_nearest=exclude_nearest
|
|
261
|
+
)
|
|
222
262
|
actual_k = min(k, len(idx_full[0]))
|
|
223
263
|
idx = idx_full[0, :actual_k]
|
|
224
264
|
result = []
|
enn/enn/enn_fit.py
CHANGED
|
@@ -9,8 +9,6 @@ if TYPE_CHECKING:
|
|
|
9
9
|
from .enn import EpistemicNearestNeighbors
|
|
10
10
|
from .enn_params import ENNParams
|
|
11
11
|
|
|
12
|
-
from .enn_util import standardize_y
|
|
13
|
-
|
|
14
12
|
|
|
15
13
|
def subsample_loglik(
|
|
16
14
|
model: EpistemicNearestNeighbors | Any,
|
|
@@ -23,17 +21,21 @@ def subsample_loglik(
|
|
|
23
21
|
) -> list[float]:
|
|
24
22
|
import numpy as np
|
|
25
23
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
if
|
|
31
|
-
|
|
24
|
+
x_array = np.asarray(x, dtype=float)
|
|
25
|
+
if x_array.ndim != 2:
|
|
26
|
+
raise ValueError(x_array.shape)
|
|
27
|
+
y_array = np.asarray(y, dtype=float)
|
|
28
|
+
if y_array.ndim == 1:
|
|
29
|
+
y_array = y_array.reshape(-1, 1)
|
|
30
|
+
if y_array.ndim != 2:
|
|
31
|
+
raise ValueError(y_array.shape)
|
|
32
|
+
if x_array.shape[0] != y_array.shape[0]:
|
|
33
|
+
raise ValueError((x_array.shape, y_array.shape))
|
|
32
34
|
if P <= 0:
|
|
33
35
|
raise ValueError(P)
|
|
34
36
|
if len(paramss) == 0:
|
|
35
37
|
raise ValueError("paramss must be non-empty")
|
|
36
|
-
n =
|
|
38
|
+
n = x_array.shape[0]
|
|
37
39
|
if n == 0:
|
|
38
40
|
return [0.0] * len(paramss)
|
|
39
41
|
if len(model) <= 1:
|
|
@@ -43,8 +45,8 @@ def subsample_loglik(
|
|
|
43
45
|
indices = np.arange(n, dtype=int)
|
|
44
46
|
else:
|
|
45
47
|
indices = rng.permutation(n)[:P_actual]
|
|
46
|
-
x_selected =
|
|
47
|
-
y_selected =
|
|
48
|
+
x_selected = x_array[indices]
|
|
49
|
+
y_selected = y_array[indices]
|
|
48
50
|
if not np.isfinite(y_selected).all():
|
|
49
51
|
return [0.0] * len(paramss)
|
|
50
52
|
post_batch = model.batch_posterior(
|
|
@@ -52,16 +54,22 @@ def subsample_loglik(
|
|
|
52
54
|
)
|
|
53
55
|
mu_batch = post_batch.mu
|
|
54
56
|
se_batch = post_batch.se
|
|
55
|
-
if mu_batch.shape[2] == 1:
|
|
56
|
-
mu_batch = mu_batch[:, :, 0]
|
|
57
|
-
se_batch = se_batch[:, :, 0]
|
|
58
57
|
num_params = len(paramss)
|
|
59
|
-
|
|
58
|
+
num_outputs = y_selected.shape[1]
|
|
59
|
+
if mu_batch.shape != (num_params, P_actual, num_outputs) or se_batch.shape != (
|
|
60
60
|
num_params,
|
|
61
61
|
P_actual,
|
|
62
|
+
num_outputs,
|
|
62
63
|
):
|
|
63
|
-
raise ValueError(
|
|
64
|
-
|
|
64
|
+
raise ValueError(
|
|
65
|
+
(
|
|
66
|
+
mu_batch.shape,
|
|
67
|
+
se_batch.shape,
|
|
68
|
+
(num_params, P_actual, num_outputs),
|
|
69
|
+
)
|
|
70
|
+
)
|
|
71
|
+
y_std = np.std(y_array, axis=0, keepdims=True).astype(float)
|
|
72
|
+
y_std = np.where(np.isfinite(y_std) & (y_std > 0.0), y_std, 1.0)
|
|
65
73
|
y_scaled = y_selected / y_std
|
|
66
74
|
mu_scaled = mu_batch / y_std
|
|
67
75
|
se_scaled = se_batch / y_std
|
|
@@ -100,12 +108,6 @@ def enn_fit(
|
|
|
100
108
|
|
|
101
109
|
train_x = model.train_x
|
|
102
110
|
train_y = model.train_y
|
|
103
|
-
train_yvar = model.train_yvar
|
|
104
|
-
if train_y.shape[1] != 1:
|
|
105
|
-
raise ValueError(train_y.shape)
|
|
106
|
-
if train_yvar is not None and train_yvar.shape[1] != 1:
|
|
107
|
-
raise ValueError(train_yvar.shape)
|
|
108
|
-
y = train_y[:, 0]
|
|
109
111
|
log_min = -3.0
|
|
110
112
|
log_max = 3.0
|
|
111
113
|
epi_var_scale_log_values = rng.uniform(log_min, log_max, size=num_fit_candidates)
|
|
@@ -135,7 +137,7 @@ def enn_fit(
|
|
|
135
137
|
import numpy as np
|
|
136
138
|
|
|
137
139
|
logliks = subsample_loglik(
|
|
138
|
-
model, train_x,
|
|
140
|
+
model, train_x, train_y, paramss=paramss, P=num_fit_samples, rng=rng
|
|
139
141
|
)
|
|
140
142
|
if len(logliks) == 0:
|
|
141
143
|
return paramss[0]
|
enn/enn/enn_normal.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
|
|
5
5
|
|
|
6
6
|
if TYPE_CHECKING:
|
|
7
7
|
import numpy as np
|
|
8
|
+
from numpy.random import Generator
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
@dataclass
|
|
@@ -15,8 +16,8 @@ class ENNNormal:
|
|
|
15
16
|
def sample(
|
|
16
17
|
self,
|
|
17
18
|
num_samples: int,
|
|
18
|
-
rng,
|
|
19
|
-
clip=None,
|
|
19
|
+
rng: Generator,
|
|
20
|
+
clip: float | None = None,
|
|
20
21
|
) -> np.ndarray:
|
|
21
22
|
import numpy as np
|
|
22
23
|
|
enn/enn/enn_params.py
CHANGED
|
@@ -8,3 +8,16 @@ class ENNParams:
|
|
|
8
8
|
k: int
|
|
9
9
|
epi_var_scale: float
|
|
10
10
|
ale_homoscedastic_scale: float
|
|
11
|
+
|
|
12
|
+
def __post_init__(self) -> None:
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
k = int(self.k)
|
|
16
|
+
if k <= 0:
|
|
17
|
+
raise ValueError(f"k must be > 0, got {k}")
|
|
18
|
+
epi_var_scale = float(self.epi_var_scale)
|
|
19
|
+
if not np.isfinite(epi_var_scale) or epi_var_scale < 0.0:
|
|
20
|
+
raise ValueError(f"epi_var_scale must be >= 0, got {epi_var_scale}")
|
|
21
|
+
ale_scale = float(self.ale_homoscedastic_scale)
|
|
22
|
+
if not np.isfinite(ale_scale) or ale_scale < 0.0:
|
|
23
|
+
raise ValueError(f"ale_homoscedastic_scale must be >= 0, got {ale_scale}")
|
enn/enn/enn_util.py
CHANGED
|
@@ -67,7 +67,6 @@ def arms_from_pareto_fronts(
|
|
|
67
67
|
rng: Generator | Any,
|
|
68
68
|
) -> np.ndarray:
|
|
69
69
|
import numpy as np
|
|
70
|
-
from nds import ndomsort
|
|
71
70
|
|
|
72
71
|
if x_cand.ndim != 2:
|
|
73
72
|
raise ValueError(x_cand.shape)
|
|
@@ -75,22 +74,51 @@ def arms_from_pareto_fronts(
|
|
|
75
74
|
raise ValueError((mu.shape, se.shape))
|
|
76
75
|
if mu.size != x_cand.shape[0]:
|
|
77
76
|
raise ValueError((mu.size, x_cand.shape[0]))
|
|
77
|
+
num_arms = int(num_arms)
|
|
78
|
+
if num_arms <= 0:
|
|
79
|
+
raise ValueError(num_arms)
|
|
80
|
+
if not np.all(np.isfinite(mu)) or not np.all(np.isfinite(se)):
|
|
81
|
+
raise ValueError("mu and se must be finite")
|
|
78
82
|
|
|
79
|
-
|
|
80
|
-
|
|
83
|
+
def _pareto_front_2d_maximize(
|
|
84
|
+
mu_: np.ndarray, se_: np.ndarray, idx: np.ndarray
|
|
85
|
+
) -> np.ndarray:
|
|
86
|
+
order = np.lexsort((-se_[idx], -mu_[idx]))
|
|
87
|
+
sorted_idx = idx[order]
|
|
88
|
+
keep: list[int] = []
|
|
89
|
+
best_se = -float("inf")
|
|
90
|
+
last_mu = float("nan")
|
|
91
|
+
last_se = float("nan")
|
|
92
|
+
for i in sorted_idx.tolist():
|
|
93
|
+
s = float(se_[i])
|
|
94
|
+
m = float(mu_[i])
|
|
95
|
+
if s > best_se:
|
|
96
|
+
keep.append(i)
|
|
97
|
+
best_se = s
|
|
98
|
+
last_mu = m
|
|
99
|
+
last_se = s
|
|
100
|
+
elif s == best_se and m == last_mu and s == last_se:
|
|
101
|
+
keep.append(i)
|
|
102
|
+
return np.asarray(keep, dtype=int)
|
|
81
103
|
|
|
82
104
|
i_keep: list[int] = []
|
|
83
|
-
|
|
84
|
-
|
|
105
|
+
remaining = np.arange(mu.size, dtype=int)
|
|
106
|
+
while remaining.size > 0 and len(i_keep) < num_arms:
|
|
107
|
+
front_indices = _pareto_front_2d_maximize(mu, se, remaining)
|
|
108
|
+
if front_indices.size == 0:
|
|
109
|
+
raise RuntimeError("pareto front extraction failed")
|
|
85
110
|
front_indices = front_indices[np.argsort(-mu[front_indices])]
|
|
86
|
-
if len(i_keep) +
|
|
111
|
+
if len(i_keep) + int(front_indices.size) <= num_arms:
|
|
87
112
|
i_keep.extend(front_indices.tolist())
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
113
|
+
is_front = np.zeros(mu.size, dtype=bool)
|
|
114
|
+
is_front[front_indices] = True
|
|
115
|
+
remaining = remaining[~is_front[remaining]]
|
|
116
|
+
continue
|
|
117
|
+
remaining_arms = num_arms - len(i_keep)
|
|
118
|
+
i_keep.extend(
|
|
119
|
+
rng.choice(front_indices, size=remaining_arms, replace=False).tolist()
|
|
120
|
+
)
|
|
121
|
+
break
|
|
94
122
|
|
|
95
123
|
i_keep = np.array(i_keep)
|
|
96
124
|
return x_cand[i_keep[np.argsort(-mu[i_keep])]]
|
enn/turbo/base_turbo_impl.py
CHANGED
|
@@ -18,6 +18,7 @@ class BaseTurboImpl:
|
|
|
18
18
|
x_obs_list: list,
|
|
19
19
|
y_obs_list: list,
|
|
20
20
|
rng: Generator,
|
|
21
|
+
tr_state: Any = None,
|
|
21
22
|
) -> np.ndarray | None:
|
|
22
23
|
import numpy as np
|
|
23
24
|
|
|
@@ -26,24 +27,58 @@ class BaseTurboImpl:
|
|
|
26
27
|
y_array = np.asarray(y_obs_list, dtype=float)
|
|
27
28
|
if y_array.size == 0:
|
|
28
29
|
return None
|
|
29
|
-
idx = argmax_random_tie(y_array, rng=rng)
|
|
30
30
|
x_array = np.asarray(x_obs_list, dtype=float)
|
|
31
|
+
|
|
32
|
+
# For morbo: scalarize raw y observations
|
|
33
|
+
if self._config.tr_type == "morbo" and tr_state is not None:
|
|
34
|
+
if y_array.ndim == 1:
|
|
35
|
+
y_array = y_array.reshape(-1, tr_state.num_metrics)
|
|
36
|
+
scalarized = tr_state.scalarize(y_array, clip=True)
|
|
37
|
+
idx = argmax_random_tie(scalarized, rng=rng)
|
|
38
|
+
else:
|
|
39
|
+
idx = argmax_random_tie(y_array, rng=rng)
|
|
40
|
+
|
|
31
41
|
return x_array[idx]
|
|
32
42
|
|
|
33
43
|
def needs_tr_list(self) -> bool:
|
|
34
44
|
return False
|
|
35
45
|
|
|
36
|
-
def create_trust_region(
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
46
|
+
def create_trust_region(
|
|
47
|
+
self,
|
|
48
|
+
num_dim: int,
|
|
49
|
+
num_arms: int,
|
|
50
|
+
rng: Generator,
|
|
51
|
+
num_metrics: int | None = None,
|
|
52
|
+
) -> Any:
|
|
53
|
+
if self._config.tr_type == "none":
|
|
54
|
+
from .no_trust_region import NoTrustRegion
|
|
55
|
+
|
|
56
|
+
return NoTrustRegion(num_dim=num_dim, num_arms=num_arms)
|
|
57
|
+
elif self._config.tr_type == "turbo":
|
|
58
|
+
from .turbo_trust_region import TurboTrustRegion
|
|
59
|
+
|
|
60
|
+
return TurboTrustRegion(num_dim=num_dim, num_arms=num_arms)
|
|
61
|
+
elif self._config.tr_type == "morbo":
|
|
62
|
+
from .morbo_trust_region import MorboTrustRegion
|
|
63
|
+
|
|
64
|
+
effective_num_metrics = num_metrics or self._config.num_metrics
|
|
65
|
+
if effective_num_metrics is None:
|
|
66
|
+
raise ValueError("num_metrics required for tr_type='morbo'")
|
|
67
|
+
return MorboTrustRegion(
|
|
68
|
+
num_dim=num_dim,
|
|
69
|
+
num_arms=num_arms,
|
|
70
|
+
num_metrics=effective_num_metrics,
|
|
71
|
+
rng=rng,
|
|
72
|
+
)
|
|
73
|
+
else:
|
|
74
|
+
raise ValueError(f"Unknown tr_type: {self._config.tr_type!r}")
|
|
40
75
|
|
|
41
76
|
def try_early_ask(
|
|
42
77
|
self,
|
|
43
78
|
num_arms: int,
|
|
44
79
|
x_obs_list: list,
|
|
45
80
|
draw_initial_fn: Callable[[int], np.ndarray],
|
|
46
|
-
get_init_lhd_points_fn: Callable[[int], np.ndarray
|
|
81
|
+
get_init_lhd_points_fn: Callable[[int], np.ndarray],
|
|
47
82
|
) -> np.ndarray | None:
|
|
48
83
|
return None
|
|
49
84
|
|
|
@@ -55,6 +90,11 @@ class BaseTurboImpl:
|
|
|
55
90
|
init_idx: int,
|
|
56
91
|
num_init: int,
|
|
57
92
|
) -> tuple[bool, int]:
|
|
93
|
+
if self._config.tr_type == "morbo":
|
|
94
|
+
x_obs_list.clear()
|
|
95
|
+
y_obs_list.clear()
|
|
96
|
+
yvar_obs_list.clear()
|
|
97
|
+
return True, 0
|
|
58
98
|
return False, init_idx
|
|
59
99
|
|
|
60
100
|
def prepare_ask(
|
|
@@ -76,20 +116,26 @@ class BaseTurboImpl:
|
|
|
76
116
|
rng: Generator,
|
|
77
117
|
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
78
118
|
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
119
|
+
tr_state: Any = None,
|
|
79
120
|
) -> np.ndarray:
|
|
80
121
|
raise NotImplementedError("Subclasses must implement select_candidates")
|
|
81
122
|
|
|
82
123
|
def update_trust_region(
|
|
83
124
|
self,
|
|
84
125
|
tr_state: Any,
|
|
126
|
+
x_obs_list: list,
|
|
85
127
|
y_obs_list: list,
|
|
86
128
|
x_center: np.ndarray | None = None,
|
|
87
129
|
k: int | None = None,
|
|
88
130
|
) -> None:
|
|
89
131
|
import numpy as np
|
|
90
132
|
|
|
133
|
+
x_obs_array = np.asarray(x_obs_list, dtype=float)
|
|
91
134
|
y_obs_array = np.asarray(y_obs_list, dtype=float)
|
|
92
|
-
tr_state
|
|
135
|
+
if hasattr(tr_state, "update_xy"):
|
|
136
|
+
tr_state.update_xy(x_obs_array, y_obs_array, k=k)
|
|
137
|
+
else:
|
|
138
|
+
tr_state.update(y_obs_array)
|
|
93
139
|
|
|
94
140
|
def estimate_y(self, x_unit: np.ndarray, y_observed: np.ndarray) -> np.ndarray:
|
|
95
141
|
return y_observed
|
enn/turbo/lhd_only_impl.py
CHANGED
|
@@ -7,14 +7,19 @@ if TYPE_CHECKING:
|
|
|
7
7
|
from numpy.random import Generator
|
|
8
8
|
|
|
9
9
|
from .base_turbo_impl import BaseTurboImpl
|
|
10
|
+
from .turbo_config import LHDOnlyConfig
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class LHDOnlyImpl(BaseTurboImpl):
|
|
14
|
+
def __init__(self, config: LHDOnlyConfig) -> None:
|
|
15
|
+
super().__init__(config)
|
|
16
|
+
|
|
13
17
|
def get_x_center(
|
|
14
18
|
self,
|
|
15
19
|
x_obs_list: list,
|
|
16
20
|
y_obs_list: list,
|
|
17
21
|
rng: Generator,
|
|
22
|
+
tr_state: Any = None,
|
|
18
23
|
) -> np.ndarray | None:
|
|
19
24
|
return None
|
|
20
25
|
|
|
@@ -26,6 +31,7 @@ class LHDOnlyImpl(BaseTurboImpl):
|
|
|
26
31
|
rng: Generator,
|
|
27
32
|
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
28
33
|
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
34
|
+
tr_state: Any = None, # noqa: ARG002
|
|
29
35
|
) -> np.ndarray:
|
|
30
36
|
from .turbo_utils import latin_hypercube
|
|
31
37
|
|
|
@@ -35,6 +41,7 @@ class LHDOnlyImpl(BaseTurboImpl):
|
|
|
35
41
|
def update_trust_region(
|
|
36
42
|
self,
|
|
37
43
|
tr_state: Any,
|
|
44
|
+
x_obs_list: list,
|
|
38
45
|
y_obs_list: list,
|
|
39
46
|
x_center: np.ndarray | None = None,
|
|
40
47
|
k: int | None = None,
|