ennbo 0.1.0__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/__init__.py +25 -13
- enn/benchmarks/__init__.py +3 -0
- enn/benchmarks/ackley.py +5 -0
- enn/benchmarks/ackley_class.py +17 -0
- enn/benchmarks/ackley_core.py +12 -0
- enn/benchmarks/double_ackley.py +24 -0
- enn/enn/candidates.py +14 -0
- enn/enn/conditional_posterior_draw_internals.py +15 -0
- enn/enn/draw_internals.py +15 -0
- enn/enn/enn.py +16 -229
- enn/enn/enn_class.py +423 -0
- enn/enn/enn_conditional.py +325 -0
- enn/enn/enn_fit.py +77 -76
- enn/enn/enn_hash.py +79 -0
- enn/enn/enn_index.py +92 -0
- enn/enn/enn_like_protocol.py +35 -0
- enn/enn/enn_normal.py +3 -3
- enn/enn/enn_params.py +3 -9
- enn/enn/enn_params_class.py +24 -0
- enn/enn/enn_util.py +79 -37
- enn/enn/neighbor_data.py +14 -0
- enn/enn/neighbors.py +14 -0
- enn/enn/posterior_flags.py +8 -0
- enn/enn/weighted_stats.py +14 -0
- enn/turbo/components/__init__.py +41 -0
- enn/turbo/components/acquisition.py +13 -0
- enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
- enn/turbo/components/builder.py +22 -0
- enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
- enn/turbo/components/enn_surrogate.py +115 -0
- enn/turbo/components/gp_surrogate.py +144 -0
- enn/turbo/components/hnr_acq_optimizer.py +83 -0
- enn/turbo/components/incumbent_selector.py +11 -0
- enn/turbo/components/incumbent_selector_protocol.py +16 -0
- enn/turbo/components/no_incumbent_selector.py +21 -0
- enn/turbo/components/no_surrogate.py +49 -0
- enn/turbo/components/pareto_acq_optimizer.py +49 -0
- enn/turbo/components/posterior_result.py +12 -0
- enn/turbo/components/protocols.py +13 -0
- enn/turbo/components/random_acq_optimizer.py +21 -0
- enn/turbo/components/scalar_incumbent_selector.py +39 -0
- enn/turbo/components/surrogate_protocol.py +32 -0
- enn/turbo/components/surrogate_result.py +12 -0
- enn/turbo/components/surrogates.py +5 -0
- enn/turbo/components/thompson_acq_optimizer.py +49 -0
- enn/turbo/components/trust_region_protocol.py +24 -0
- enn/turbo/components/ucb_acq_optimizer.py +49 -0
- enn/turbo/config/__init__.py +87 -0
- enn/turbo/config/acq_type.py +8 -0
- enn/turbo/config/acquisition.py +26 -0
- enn/turbo/config/base.py +4 -0
- enn/turbo/config/candidate_gen_config.py +49 -0
- enn/turbo/config/candidate_rv.py +7 -0
- enn/turbo/config/draw_acquisition_config.py +14 -0
- enn/turbo/config/enn_index_driver.py +6 -0
- enn/turbo/config/enn_surrogate_config.py +42 -0
- enn/turbo/config/enums.py +7 -0
- enn/turbo/config/factory.py +118 -0
- enn/turbo/config/gp_surrogate_config.py +14 -0
- enn/turbo/config/hnr_optimizer_config.py +7 -0
- enn/turbo/config/init_config.py +17 -0
- enn/turbo/config/init_strategies/__init__.py +9 -0
- enn/turbo/config/init_strategies/hybrid_init.py +23 -0
- enn/turbo/config/init_strategies/init_strategy.py +19 -0
- enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
- enn/turbo/config/morbo_tr_config.py +82 -0
- enn/turbo/config/nds_optimizer_config.py +7 -0
- enn/turbo/config/no_surrogate_config.py +14 -0
- enn/turbo/config/no_tr_config.py +31 -0
- enn/turbo/config/optimizer_config.py +72 -0
- enn/turbo/config/pareto_acquisition_config.py +14 -0
- enn/turbo/config/raasp_driver.py +6 -0
- enn/turbo/config/raasp_optimizer_config.py +7 -0
- enn/turbo/config/random_acquisition_config.py +14 -0
- enn/turbo/config/rescalarize.py +7 -0
- enn/turbo/config/surrogate.py +12 -0
- enn/turbo/config/trust_region.py +34 -0
- enn/turbo/config/turbo_tr_config.py +71 -0
- enn/turbo/config/ucb_acquisition_config.py +14 -0
- enn/turbo/config/validation.py +45 -0
- enn/turbo/hypervolume.py +30 -0
- enn/turbo/impl_helpers.py +68 -0
- enn/turbo/morbo_trust_region.py +250 -0
- enn/turbo/no_trust_region.py +58 -0
- enn/turbo/optimizer.py +300 -0
- enn/turbo/optimizer_config.py +8 -0
- enn/turbo/proposal.py +46 -39
- enn/turbo/sampling.py +21 -0
- enn/turbo/strategies/__init__.py +9 -0
- enn/turbo/strategies/lhd_only_strategy.py +36 -0
- enn/turbo/strategies/optimization_strategy.py +19 -0
- enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
- enn/turbo/tr_helpers.py +202 -0
- enn/turbo/turbo_gp.py +9 -2
- enn/turbo/turbo_gp_base.py +0 -1
- enn/turbo/turbo_gp_fit.py +187 -0
- enn/turbo/turbo_gp_noisy.py +0 -1
- enn/turbo/turbo_optimizer_utils.py +98 -0
- enn/turbo/turbo_trust_region.py +129 -63
- enn/turbo/turbo_utils.py +144 -117
- enn/turbo/types/__init__.py +7 -0
- enn/turbo/types/appendable_array.py +85 -0
- enn/turbo/types/gp_data_prep.py +13 -0
- enn/turbo/types/gp_fit_result.py +11 -0
- enn/turbo/types/obs_lists.py +10 -0
- enn/turbo/types/prepare_ask_result.py +14 -0
- enn/turbo/types/tell_inputs.py +14 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/METADATA +22 -14
- ennbo-0.1.7.dist-info/RECORD +111 -0
- enn/enn/__init__.py +0 -4
- enn/turbo/__init__.py +0 -11
- enn/turbo/base_turbo_impl.py +0 -98
- enn/turbo/lhd_only_impl.py +0 -42
- enn/turbo/turbo_config.py +0 -28
- enn/turbo/turbo_enn_impl.py +0 -176
- enn/turbo/turbo_mode.py +0 -10
- enn/turbo/turbo_mode_impl.py +0 -67
- enn/turbo/turbo_one_impl.py +0 -163
- enn/turbo/turbo_optimizer.py +0 -337
- enn/turbo/turbo_zero_impl.py +0 -24
- ennbo-0.1.0.dist-info/RECORD +0 -27
- {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
enn/turbo/turbo_optimizer.py
DELETED
|
@@ -1,337 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Callable
|
|
5
|
-
|
|
6
|
-
from .proposal import select_uniform
|
|
7
|
-
from .turbo_config import TurboConfig
|
|
8
|
-
from .turbo_utils import from_unit, latin_hypercube, to_unit
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@dataclass(frozen=True)
|
|
12
|
-
class Telemetry:
|
|
13
|
-
dt_fit: float
|
|
14
|
-
dt_sel: float
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
if TYPE_CHECKING:
|
|
18
|
-
import numpy as np
|
|
19
|
-
from numpy.random import Generator
|
|
20
|
-
|
|
21
|
-
from .turbo_mode import TurboMode
|
|
22
|
-
from .turbo_mode_impl import TurboModeImpl
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class TurboOptimizer:
|
|
26
|
-
def __init__(
|
|
27
|
-
self,
|
|
28
|
-
bounds: np.ndarray,
|
|
29
|
-
mode: TurboMode,
|
|
30
|
-
*,
|
|
31
|
-
rng: Generator,
|
|
32
|
-
config: TurboConfig | None = None,
|
|
33
|
-
) -> None:
|
|
34
|
-
import numpy as np
|
|
35
|
-
from scipy.stats import qmc
|
|
36
|
-
|
|
37
|
-
from .turbo_mode import TurboMode
|
|
38
|
-
|
|
39
|
-
if config is None:
|
|
40
|
-
config = TurboConfig()
|
|
41
|
-
self._config = config
|
|
42
|
-
|
|
43
|
-
if bounds.ndim != 2 or bounds.shape[1] != 2:
|
|
44
|
-
raise ValueError(bounds.shape)
|
|
45
|
-
self._bounds = np.asarray(bounds, dtype=float)
|
|
46
|
-
self._num_dim = self._bounds.shape[0]
|
|
47
|
-
self._mode = mode
|
|
48
|
-
num_candidates = config.num_candidates
|
|
49
|
-
if num_candidates is None:
|
|
50
|
-
num_candidates = min(5000, 100 * self._num_dim)
|
|
51
|
-
|
|
52
|
-
self._num_candidates = int(num_candidates)
|
|
53
|
-
if self._num_candidates <= 0:
|
|
54
|
-
raise ValueError(self._num_candidates)
|
|
55
|
-
self._rng = rng
|
|
56
|
-
sobol_seed = int(self._rng.integers(1_000_000))
|
|
57
|
-
self._sobol_engine = qmc.Sobol(d=self._num_dim, scramble=True, seed=sobol_seed)
|
|
58
|
-
self._x_obs_list: list = []
|
|
59
|
-
self._y_obs_list: list = []
|
|
60
|
-
self._yvar_obs_list: list = []
|
|
61
|
-
match mode:
|
|
62
|
-
case TurboMode.TURBO_ONE:
|
|
63
|
-
from .turbo_one_impl import TurboOneImpl
|
|
64
|
-
|
|
65
|
-
self._mode_impl: TurboModeImpl = TurboOneImpl(config)
|
|
66
|
-
case TurboMode.TURBO_ZERO:
|
|
67
|
-
from .turbo_zero_impl import TurboZeroImpl
|
|
68
|
-
|
|
69
|
-
self._mode_impl = TurboZeroImpl(config)
|
|
70
|
-
case TurboMode.TURBO_ENN:
|
|
71
|
-
from .turbo_enn_impl import TurboENNImpl
|
|
72
|
-
|
|
73
|
-
self._mode_impl = TurboENNImpl(config)
|
|
74
|
-
case TurboMode.LHD_ONLY:
|
|
75
|
-
from .lhd_only_impl import LHDOnlyImpl
|
|
76
|
-
|
|
77
|
-
self._mode_impl = LHDOnlyImpl(config)
|
|
78
|
-
case _:
|
|
79
|
-
raise ValueError(f"Unknown mode: {mode}")
|
|
80
|
-
self._tr_state: Any | None = None
|
|
81
|
-
self._gp_num_steps: int = 50
|
|
82
|
-
if config.k is not None:
|
|
83
|
-
k_val = int(config.k)
|
|
84
|
-
if k_val < 3:
|
|
85
|
-
raise ValueError(f"k must be >= 3, got {k_val}")
|
|
86
|
-
self._k = k_val
|
|
87
|
-
else:
|
|
88
|
-
self._k = None
|
|
89
|
-
if config.trailing_obs is not None:
|
|
90
|
-
trailing_obs_val = int(config.trailing_obs)
|
|
91
|
-
if trailing_obs_val <= 0:
|
|
92
|
-
raise ValueError(f"trailing_obs must be > 0, got {trailing_obs_val}")
|
|
93
|
-
self._trailing_obs = trailing_obs_val
|
|
94
|
-
else:
|
|
95
|
-
self._trailing_obs = None
|
|
96
|
-
num_init = config.num_init
|
|
97
|
-
if num_init is None:
|
|
98
|
-
num_init = 2 * self._num_dim
|
|
99
|
-
num_init_val = int(num_init)
|
|
100
|
-
if num_init_val <= 0:
|
|
101
|
-
raise ValueError(f"num_init must be > 0, got {num_init_val}")
|
|
102
|
-
self._num_init = num_init_val
|
|
103
|
-
if config.local_only:
|
|
104
|
-
center = 0.5 * (self._bounds[:, 0] + self._bounds[:, 1])
|
|
105
|
-
self._init_lhd = center.reshape(1, -1)
|
|
106
|
-
self._num_init = 1
|
|
107
|
-
else:
|
|
108
|
-
self._init_lhd = from_unit(
|
|
109
|
-
latin_hypercube(self._num_init, self._num_dim, rng=self._rng),
|
|
110
|
-
self._bounds,
|
|
111
|
-
)
|
|
112
|
-
self._init_idx = 0
|
|
113
|
-
self._dt_fit: float = 0.0
|
|
114
|
-
self._dt_sel: float = 0.0
|
|
115
|
-
self._local_only = config.local_only
|
|
116
|
-
|
|
117
|
-
@property
|
|
118
|
-
def tr_obs_count(self) -> int:
|
|
119
|
-
return len(self._y_obs_list)
|
|
120
|
-
|
|
121
|
-
@property
|
|
122
|
-
def best_tr_value(self) -> float | None:
|
|
123
|
-
import numpy as np
|
|
124
|
-
|
|
125
|
-
if len(self._y_obs_list) == 0:
|
|
126
|
-
return None
|
|
127
|
-
return float(np.max(self._y_obs_list))
|
|
128
|
-
|
|
129
|
-
@property
|
|
130
|
-
def tr_length(self) -> float | None:
|
|
131
|
-
if self._tr_state is None:
|
|
132
|
-
return None
|
|
133
|
-
return float(self._tr_state.length)
|
|
134
|
-
|
|
135
|
-
def telemetry(self) -> Telemetry:
|
|
136
|
-
return Telemetry(dt_fit=self._dt_fit, dt_sel=self._dt_sel)
|
|
137
|
-
|
|
138
|
-
def ask(self, num_arms: int) -> np.ndarray:
|
|
139
|
-
num_arms = int(num_arms)
|
|
140
|
-
if num_arms <= 0:
|
|
141
|
-
raise ValueError(num_arms)
|
|
142
|
-
if self._tr_state is None:
|
|
143
|
-
self._tr_state = self._mode_impl.create_trust_region(
|
|
144
|
-
self._num_dim, num_arms
|
|
145
|
-
)
|
|
146
|
-
if self._local_only:
|
|
147
|
-
self._tr_state.length_max = 0.1
|
|
148
|
-
self._tr_state.length = min(self._tr_state.length, 0.1)
|
|
149
|
-
self._tr_state.length_init = min(self._tr_state.length_init, 0.1)
|
|
150
|
-
early_result = self._mode_impl.try_early_ask(
|
|
151
|
-
num_arms,
|
|
152
|
-
self._x_obs_list,
|
|
153
|
-
self._draw_initial,
|
|
154
|
-
self._get_init_lhd_points,
|
|
155
|
-
)
|
|
156
|
-
if early_result is not None:
|
|
157
|
-
self._dt_fit = 0.0
|
|
158
|
-
self._dt_sel = 0.0
|
|
159
|
-
return early_result
|
|
160
|
-
if self._init_idx < self._num_init:
|
|
161
|
-
if len(self._x_obs_list) == 0:
|
|
162
|
-
fallback_fn = None
|
|
163
|
-
else:
|
|
164
|
-
|
|
165
|
-
def fallback_fn(n: int) -> np.ndarray:
|
|
166
|
-
return self._ask_normal(n, is_fallback=True)
|
|
167
|
-
|
|
168
|
-
self._dt_fit = 0.0
|
|
169
|
-
self._dt_sel = 0.0
|
|
170
|
-
return self._get_init_lhd_points(num_arms, fallback_fn=fallback_fn)
|
|
171
|
-
if len(self._x_obs_list) == 0:
|
|
172
|
-
self._dt_fit = 0.0
|
|
173
|
-
self._dt_sel = 0.0
|
|
174
|
-
return self._draw_initial(num_arms)
|
|
175
|
-
return self._ask_normal(num_arms)
|
|
176
|
-
|
|
177
|
-
def _ask_normal(self, num_arms: int, *, is_fallback: bool = False) -> np.ndarray:
|
|
178
|
-
import numpy as np
|
|
179
|
-
|
|
180
|
-
if self._tr_state.needs_restart():
|
|
181
|
-
self._tr_state.restart()
|
|
182
|
-
should_reset_init, new_init_idx = self._mode_impl.handle_restart(
|
|
183
|
-
self._x_obs_list,
|
|
184
|
-
self._y_obs_list,
|
|
185
|
-
self._yvar_obs_list,
|
|
186
|
-
self._init_idx,
|
|
187
|
-
self._num_init,
|
|
188
|
-
)
|
|
189
|
-
if should_reset_init:
|
|
190
|
-
self._init_idx = new_init_idx
|
|
191
|
-
self._init_lhd = from_unit(
|
|
192
|
-
latin_hypercube(self._num_init, self._num_dim, rng=self._rng),
|
|
193
|
-
self._bounds,
|
|
194
|
-
)
|
|
195
|
-
return self._get_init_lhd_points(num_arms)
|
|
196
|
-
|
|
197
|
-
def from_unit_fn(x):
|
|
198
|
-
return from_unit(x, self._bounds)
|
|
199
|
-
|
|
200
|
-
if self._mode_impl.needs_tr_list() and len(self._x_obs_list) == 0:
|
|
201
|
-
return self._get_init_lhd_points(num_arms)
|
|
202
|
-
|
|
203
|
-
import time
|
|
204
|
-
|
|
205
|
-
t0_fit = time.perf_counter()
|
|
206
|
-
_gp_model, _gp_y_mean_fitted, _gp_y_std_fitted, weights = (
|
|
207
|
-
self._mode_impl.prepare_ask(
|
|
208
|
-
self._x_obs_list,
|
|
209
|
-
self._y_obs_list,
|
|
210
|
-
self._yvar_obs_list,
|
|
211
|
-
self._num_dim,
|
|
212
|
-
self._gp_num_steps,
|
|
213
|
-
rng=self._rng,
|
|
214
|
-
)
|
|
215
|
-
)
|
|
216
|
-
self._dt_fit = time.perf_counter() - t0_fit
|
|
217
|
-
|
|
218
|
-
x_center = self._mode_impl.get_x_center(
|
|
219
|
-
self._x_obs_list, self._y_obs_list, self._rng
|
|
220
|
-
)
|
|
221
|
-
if x_center is None:
|
|
222
|
-
if len(self._y_obs_list) == 0:
|
|
223
|
-
raise RuntimeError("no observations")
|
|
224
|
-
x_center = np.full(self._num_dim, 0.5)
|
|
225
|
-
|
|
226
|
-
x_cand = self._tr_state.generate_candidates(
|
|
227
|
-
x_center,
|
|
228
|
-
weights,
|
|
229
|
-
self._num_candidates,
|
|
230
|
-
self._rng,
|
|
231
|
-
self._sobol_engine,
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
def fallback_fn(x, n):
|
|
235
|
-
return select_uniform(x, n, self._num_dim, self._rng, from_unit_fn)
|
|
236
|
-
|
|
237
|
-
self._tr_state.validate_request(num_arms, is_fallback=is_fallback)
|
|
238
|
-
|
|
239
|
-
t0_sel = time.perf_counter()
|
|
240
|
-
selected = self._mode_impl.select_candidates(
|
|
241
|
-
x_cand,
|
|
242
|
-
num_arms,
|
|
243
|
-
self._num_dim,
|
|
244
|
-
self._rng,
|
|
245
|
-
fallback_fn,
|
|
246
|
-
from_unit_fn,
|
|
247
|
-
)
|
|
248
|
-
self._dt_sel = time.perf_counter() - t0_sel
|
|
249
|
-
|
|
250
|
-
self._mode_impl.update_trust_region(
|
|
251
|
-
self._tr_state, self._y_obs_list, x_center=x_center, k=self._k
|
|
252
|
-
)
|
|
253
|
-
return selected
|
|
254
|
-
|
|
255
|
-
def _trim_trailing_obs(self) -> None:
|
|
256
|
-
import numpy as np
|
|
257
|
-
|
|
258
|
-
from .turbo_utils import argmax_random_tie
|
|
259
|
-
|
|
260
|
-
if len(self._x_obs_list) <= self._trailing_obs:
|
|
261
|
-
return
|
|
262
|
-
y_array = np.asarray(self._y_obs_list, dtype=float)
|
|
263
|
-
incumbent_idx = argmax_random_tie(y_array, rng=self._rng)
|
|
264
|
-
num_total = len(self._x_obs_list)
|
|
265
|
-
start_idx = max(0, num_total - self._trailing_obs)
|
|
266
|
-
if incumbent_idx < start_idx:
|
|
267
|
-
indices = np.array(
|
|
268
|
-
[incumbent_idx]
|
|
269
|
-
+ list(range(num_total - (self._trailing_obs - 1), num_total)),
|
|
270
|
-
dtype=int,
|
|
271
|
-
)
|
|
272
|
-
else:
|
|
273
|
-
indices = np.arange(start_idx, num_total, dtype=int)
|
|
274
|
-
if incumbent_idx not in indices:
|
|
275
|
-
raise RuntimeError("Incumbent must be included in trimmed list")
|
|
276
|
-
x_array = np.asarray(self._x_obs_list, dtype=float)
|
|
277
|
-
incumbent_value = y_array[incumbent_idx]
|
|
278
|
-
self._x_obs_list = x_array[indices].tolist()
|
|
279
|
-
self._y_obs_list = y_array[indices].tolist()
|
|
280
|
-
if len(self._yvar_obs_list) == len(y_array):
|
|
281
|
-
yvar_array = np.asarray(self._yvar_obs_list, dtype=float)
|
|
282
|
-
self._yvar_obs_list = yvar_array[indices].tolist()
|
|
283
|
-
y_trimmed = np.asarray(self._y_obs_list, dtype=float)
|
|
284
|
-
if not np.any(np.abs(y_trimmed - incumbent_value) < 1e-10):
|
|
285
|
-
raise RuntimeError("Incumbent value must be preserved in trimmed list")
|
|
286
|
-
|
|
287
|
-
def tell(
|
|
288
|
-
self,
|
|
289
|
-
x: np.ndarray | Any,
|
|
290
|
-
y: np.ndarray | Any,
|
|
291
|
-
y_var: np.ndarray | Any | None = None,
|
|
292
|
-
) -> np.ndarray:
|
|
293
|
-
import numpy as np
|
|
294
|
-
|
|
295
|
-
x = np.asarray(x, dtype=float)
|
|
296
|
-
y = np.asarray(y, dtype=float)
|
|
297
|
-
if x.ndim != 2 or x.shape[1] != self._num_dim:
|
|
298
|
-
raise ValueError(x.shape)
|
|
299
|
-
if y.ndim != 1 or y.shape[0] != x.shape[0]:
|
|
300
|
-
raise ValueError((x.shape, y.shape))
|
|
301
|
-
if y_var is not None:
|
|
302
|
-
y_var = np.asarray(y_var, dtype=float)
|
|
303
|
-
if y_var.shape != y.shape:
|
|
304
|
-
raise ValueError((y.shape, y_var.shape))
|
|
305
|
-
if x.shape[0] == 0:
|
|
306
|
-
return np.array([], dtype=float)
|
|
307
|
-
x_unit = to_unit(x, self._bounds)
|
|
308
|
-
y_estimate = self._mode_impl.estimate_y(x_unit, y)
|
|
309
|
-
self._x_obs_list.extend(x_unit.tolist())
|
|
310
|
-
self._y_obs_list.extend(y.tolist())
|
|
311
|
-
if y_var is not None:
|
|
312
|
-
self._yvar_obs_list.extend(y_var.tolist())
|
|
313
|
-
if self._trailing_obs is not None:
|
|
314
|
-
self._trim_trailing_obs()
|
|
315
|
-
self._mode_impl.update_trust_region(self._tr_state, self._y_obs_list)
|
|
316
|
-
return y_estimate
|
|
317
|
-
|
|
318
|
-
def _draw_initial(self, num_arms: int) -> np.ndarray:
|
|
319
|
-
unit = latin_hypercube(num_arms, self._num_dim, rng=self._rng)
|
|
320
|
-
return from_unit(unit, self._bounds)
|
|
321
|
-
|
|
322
|
-
def _get_init_lhd_points(
|
|
323
|
-
self, num_arms: int, fallback_fn: Callable[[int], np.ndarray] | None = None
|
|
324
|
-
) -> np.ndarray:
|
|
325
|
-
import numpy as np
|
|
326
|
-
|
|
327
|
-
remaining_init = self._num_init - self._init_idx
|
|
328
|
-
num_to_return = min(num_arms, remaining_init)
|
|
329
|
-
result = self._init_lhd[self._init_idx : self._init_idx + num_to_return]
|
|
330
|
-
self._init_idx += num_to_return
|
|
331
|
-
if num_to_return < num_arms:
|
|
332
|
-
num_remaining = num_arms - num_to_return
|
|
333
|
-
if fallback_fn is not None:
|
|
334
|
-
result = np.vstack([result, fallback_fn(num_remaining)])
|
|
335
|
-
else:
|
|
336
|
-
result = np.vstack([result, self._draw_initial(num_remaining)])
|
|
337
|
-
return result
|
enn/turbo/turbo_zero_impl.py
DELETED
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import TYPE_CHECKING, Callable
|
|
4
|
-
|
|
5
|
-
if TYPE_CHECKING:
|
|
6
|
-
import numpy as np
|
|
7
|
-
from numpy.random import Generator
|
|
8
|
-
|
|
9
|
-
from .base_turbo_impl import BaseTurboImpl
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class TurboZeroImpl(BaseTurboImpl):
|
|
13
|
-
def select_candidates(
|
|
14
|
-
self,
|
|
15
|
-
x_cand: np.ndarray,
|
|
16
|
-
num_arms: int,
|
|
17
|
-
num_dim: int,
|
|
18
|
-
rng: Generator,
|
|
19
|
-
fallback_fn: Callable[[np.ndarray, int], np.ndarray],
|
|
20
|
-
from_unit_fn: Callable[[np.ndarray], np.ndarray],
|
|
21
|
-
) -> np.ndarray:
|
|
22
|
-
from .proposal import select_uniform
|
|
23
|
-
|
|
24
|
-
return select_uniform(x_cand, num_arms, num_dim, rng, from_unit_fn)
|
ennbo-0.1.0.dist-info/RECORD
DELETED
|
@@ -1,27 +0,0 @@
|
|
|
1
|
-
enn/__init__.py,sha256=VYIuOTCjhUFIJm78IoJv0WXtvA_IuZhY1sSMJJM3dx8,507
|
|
2
|
-
enn/enn/__init__.py,sha256=K3rntg_ZkITStmXMTBcEhxeS1kel1bb7wB_C7-2WE5Y,135
|
|
3
|
-
enn/enn/enn.py,sha256=ZdDPivZj4SL9e87FolU1oscdPdcwUeIByIrvBLsoCfE,8060
|
|
4
|
-
enn/enn/enn_fit.py,sha256=uv1BHO-nbxVXkR_tM1Ggoh6YNuR-VrjVECFxLquC7u8,4328
|
|
5
|
-
enn/enn/enn_normal.py,sha256=3kOymSx2kzcBMavScXLflPm_gDDLGF9fYLBJ816I3xg,596
|
|
6
|
-
enn/enn/enn_params.py,sha256=fwLZTA8ciRp4XUF5L_VAVsC3EvFuOzR85OYLVtv6TSw,184
|
|
7
|
-
enn/enn/enn_util.py,sha256=ZELPVeyUl0wiHOxjHYKjxeDz88ExmKMeX3P-bQ6tCoE,3075
|
|
8
|
-
enn/turbo/__init__.py,sha256=utnD3CLZgjCvw-46AAu5Tv2M2Vbg5YXK-_TycGk5BU4,197
|
|
9
|
-
enn/turbo/base_turbo_impl.py,sha256=wThjwXGboRrVTamsnvzmM0WNIOZ91GNJ-BmGzjgqdhg,2699
|
|
10
|
-
enn/turbo/lhd_only_impl.py,sha256=yWsOw7Oq0xfEnyXg5AXJSzZFjM7162pqNY37fHQtJQ4,1023
|
|
11
|
-
enn/turbo/proposal.py,sha256=w1izo3ooiiravNRoFWK5ZK7BH-f_HWgqYP8heVtLmYs,3977
|
|
12
|
-
enn/turbo/turbo_config.py,sha256=J0ww_qKDDMpbFVXdntuSbJtUTbdnXrFJyGD1svzG3RM,980
|
|
13
|
-
enn/turbo/turbo_enn_impl.py,sha256=YMAS4krpPXPNtlh46RRG3VLMuGyYLFw5UkPRBU29mzA,5837
|
|
14
|
-
enn/turbo/turbo_gp.py,sha256=i1bxVHima0Nv4MCLlADtlRzt1cENcnVLYk3S9vCoF4c,797
|
|
15
|
-
enn/turbo/turbo_gp_base.py,sha256=tnE5uX_eAt1Db-gemyy83ZvKpdNbMg_tsWkh6sG7zaM,638
|
|
16
|
-
enn/turbo/turbo_gp_noisy.py,sha256=itTL9jUCjE566jwDODT0P36fozsfU_bXACyuKqxYMXs,1080
|
|
17
|
-
enn/turbo/turbo_mode.py,sha256=JMP1jkFCRwPtOzU95MWWd04Sgze7eKF0xNkiPqtQ8SI,181
|
|
18
|
-
enn/turbo/turbo_mode_impl.py,sha256=3HKBjOS96Wn-R_znctQm9Ivrm3FhgZFTuBp7McNDQ88,1749
|
|
19
|
-
enn/turbo/turbo_one_impl.py,sha256=nS02RdRMcEsi3II07jzcrQbsFsfWYTeahUcqoyhig4Q,5207
|
|
20
|
-
enn/turbo/turbo_optimizer.py,sha256=IlofW9_ogCeQMVXa7n8xWEg5fbJBUkvAkeLKe3MoXlA,11902
|
|
21
|
-
enn/turbo/turbo_trust_region.py,sha256=VHNYKWtKLt3iKHI0enL9qMMu1Bwi1nupo20L0Sv-vYY,3759
|
|
22
|
-
enn/turbo/turbo_utils.py,sha256=XU9-YtW1u5-HKk3bA_M-hVNFPAuNcIYozAmej7ulVsY,7532
|
|
23
|
-
enn/turbo/turbo_zero_impl.py,sha256=S4TEHYkVDowtyWSVxWO0ncd1OUIFpeV3IR-eanGr1vg,643
|
|
24
|
-
ennbo-0.1.0.dist-info/METADATA,sha256=slkhtsGXaO31u8w35LNKXN2noxUJYTqHQF7bv1DZMmA,5930
|
|
25
|
-
ennbo-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
26
|
-
ennbo-0.1.0.dist-info/licenses/LICENSE,sha256=KTA0NjGalsl_JGrjT_x6SSq9ZYVO3gQ-hLVMEaekc5w,1070
|
|
27
|
-
ennbo-0.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|