ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -1,525 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import TYPE_CHECKING, Any, Callable
5
-
6
- from .proposal import select_uniform
7
- from .turbo_config import (
8
- LHDOnlyConfig,
9
- TurboConfig,
10
- TurboENNConfig,
11
- TurboOneConfig,
12
- TurboZeroConfig,
13
- )
14
- from .turbo_utils import from_unit, latin_hypercube, to_unit
15
-
16
-
17
@dataclass(frozen=True)
class Telemetry:
    """Immutable snapshot of the timings recorded by the most recent ``ask``.

    Both values are ``time.perf_counter()`` differences measured by the
    optimizer; a value of ``0.0`` means that phase did not run.
    """

    # Time spent preparing the surrogate (``prepare_ask``).
    dt_fit: float
    # Time spent choosing arms from the candidate pool (``select_candidates``).
    dt_sel: float
-
22
-
23
- if TYPE_CHECKING:
24
- import numpy as np
25
- from numpy.random import Generator
26
-
27
- from .turbo_mode import TurboMode
28
- from .turbo_mode_impl import TurboModeImpl
29
-
30
-
31
- class TurboOptimizer:
32
- def __init__(
33
- self,
34
- bounds: np.ndarray,
35
- mode: TurboMode,
36
- *,
37
- rng: Generator,
38
- config: TurboConfig | None = None,
39
- ) -> None:
40
- import numpy as np
41
-
42
- from .turbo_mode import TurboMode
43
-
44
- if config is None:
45
- match mode:
46
- case TurboMode.TURBO_ONE:
47
- config = TurboOneConfig()
48
- case TurboMode.TURBO_ZERO:
49
- config = TurboZeroConfig()
50
- case TurboMode.TURBO_ENN:
51
- config = TurboENNConfig()
52
- case TurboMode.LHD_ONLY:
53
- config = LHDOnlyConfig()
54
- case _:
55
- raise ValueError(f"Unknown mode: {mode}")
56
- else:
57
- match mode:
58
- case TurboMode.TURBO_ONE:
59
- if not isinstance(config, TurboOneConfig):
60
- raise ValueError(
61
- f"mode={mode} requires TurboOneConfig, got {type(config).__name__}"
62
- )
63
- case TurboMode.TURBO_ZERO:
64
- if not isinstance(config, TurboZeroConfig):
65
- raise ValueError(
66
- f"mode={mode} requires TurboZeroConfig, got {type(config).__name__}"
67
- )
68
- case TurboMode.TURBO_ENN:
69
- if not isinstance(config, TurboENNConfig):
70
- raise ValueError(
71
- f"mode={mode} requires TurboENNConfig, got {type(config).__name__}"
72
- )
73
- case TurboMode.LHD_ONLY:
74
- if not isinstance(config, LHDOnlyConfig):
75
- raise ValueError(
76
- f"mode={mode} requires LHDOnlyConfig, got {type(config).__name__}"
77
- )
78
- case _:
79
- raise ValueError(f"Unknown mode: {mode}")
80
- self._config = config
81
-
82
- bounds = np.asarray(bounds, dtype=float)
83
- if bounds.ndim != 2 or bounds.shape[1] != 2:
84
- raise ValueError(bounds.shape)
85
- self._bounds = bounds
86
- self._num_dim = self._bounds.shape[0]
87
- self._mode = mode
88
- num_candidates = config.num_candidates
89
- if num_candidates is None:
90
- num_candidates = min(5000, 100 * self._num_dim)
91
-
92
- self._num_candidates = int(num_candidates)
93
- if self._num_candidates <= 0:
94
- raise ValueError(self._num_candidates)
95
- self._rng = rng
96
- self._sobol_seed_base = int(self._rng.integers(2**31 - 1))
97
- self._x_obs_list: list[list[float]] = []
98
- self._y_obs_list: list[float] | list[list[float]] = []
99
- self._y_tr_list: list[float] = []
100
- self._yvar_obs_list: list[float] | list[list[float]] = []
101
- self._expects_yvar: bool | None = None
102
- match mode:
103
- case TurboMode.TURBO_ONE:
104
- from .turbo_one_impl import TurboOneImpl
105
-
106
- self._mode_impl: TurboModeImpl = TurboOneImpl(config)
107
- case TurboMode.TURBO_ZERO:
108
- from .turbo_zero_impl import TurboZeroImpl
109
-
110
- self._mode_impl = TurboZeroImpl(config)
111
- case TurboMode.TURBO_ENN:
112
- from .turbo_enn_impl import TurboENNImpl
113
-
114
- self._mode_impl = TurboENNImpl(config)
115
- case TurboMode.LHD_ONLY:
116
- from .lhd_only_impl import LHDOnlyImpl
117
-
118
- self._mode_impl = LHDOnlyImpl(config)
119
- case _:
120
- raise ValueError(f"Unknown mode: {mode}")
121
- self._tr_state: Any | None = None
122
- self._gp_num_steps: int = 50
123
- if config.k is not None:
124
- k_val = int(config.k)
125
- if k_val < 3:
126
- raise ValueError(f"k must be >= 3, got {k_val}")
127
- self._k = k_val
128
- else:
129
- self._k = None
130
- if config.trailing_obs is not None:
131
- trailing_obs_val = int(config.trailing_obs)
132
- if trailing_obs_val <= 0:
133
- raise ValueError(f"trailing_obs must be > 0, got {trailing_obs_val}")
134
- self._trailing_obs = trailing_obs_val
135
- else:
136
- self._trailing_obs = None
137
- num_init = config.num_init
138
- if num_init is None:
139
- num_init = 2 * self._num_dim
140
- num_init_val = int(num_init)
141
- if num_init_val <= 0:
142
- raise ValueError(f"num_init must be > 0, got {num_init_val}")
143
- self._num_init = num_init_val
144
- self._init_lhd = from_unit(
145
- latin_hypercube(self._num_init, self._num_dim, rng=self._rng),
146
- self._bounds,
147
- )
148
- self._init_idx = 0
149
- self._dt_fit: float = 0.0
150
- self._dt_sel: float = 0.0
151
-
152
- def _sobol_seed_for_state(self, *, n_obs: int, num_arms: int) -> int:
153
- mask64 = (1 << 64) - 1
154
-
155
- x = int(self._sobol_seed_base) & mask64
156
- x ^= (int(n_obs) + 1) * 0x9E3779B97F4A7C15 & mask64
157
- x ^= (int(num_arms) + 1) * 0xBF58476D1CE4E5B9 & mask64
158
- x = (x + 0x9E3779B97F4A7C15) & mask64
159
- z = x
160
- z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9 & mask64
161
- z = (z ^ (z >> 27)) * 0x94D049BB133111EB & mask64
162
- z = z ^ (z >> 31)
163
- return int(z & 0xFFFFFFFF)
164
-
165
- @property
166
- def tr_obs_count(self) -> int:
167
- return len(self._y_obs_list)
168
-
169
- @property
170
- def tr_length(self) -> float | None:
171
- if self._tr_state is None:
172
- return None
173
- if not hasattr(self._tr_state, "length"):
174
- return None
175
- return float(self._tr_state.length)
176
-
177
- def telemetry(self) -> Telemetry:
178
- return Telemetry(dt_fit=self._dt_fit, dt_sel=self._dt_sel)
179
-
180
- def ask(self, num_arms: int) -> np.ndarray:
181
- num_arms = int(num_arms)
182
- if num_arms <= 0:
183
- raise ValueError(num_arms)
184
- # For morbo, defer TR creation until tell() when we can infer num_metrics
185
- is_morbo = self._config.tr_type == "morbo"
186
- if self._tr_state is None and not is_morbo:
187
- self._tr_state = self._mode_impl.create_trust_region(
188
- self._num_dim, num_arms, self._rng
189
- )
190
- if self._tr_state is not None:
191
- self._tr_state.validate_request(num_arms)
192
- early_result = self._mode_impl.try_early_ask(
193
- num_arms,
194
- self._x_obs_list,
195
- self._draw_initial,
196
- self._get_init_lhd_points,
197
- )
198
- if early_result is not None:
199
- self._dt_fit = 0.0
200
- self._dt_sel = 0.0
201
- return early_result
202
- if self._init_idx < self._num_init:
203
- if len(self._x_obs_list) == 0:
204
- fallback_fn = None
205
- else:
206
-
207
- def fallback_fn(n: int) -> np.ndarray:
208
- return self._ask_normal(n, is_fallback=True)
209
-
210
- self._dt_fit = 0.0
211
- self._dt_sel = 0.0
212
- return self._get_init_lhd_points(num_arms, fallback_fn=fallback_fn)
213
- if len(self._x_obs_list) == 0:
214
- self._dt_fit = 0.0
215
- self._dt_sel = 0.0
216
- return self._draw_initial(num_arms)
217
- return self._ask_normal(num_arms)
218
-
219
- def _ask_normal(self, num_arms: int, *, is_fallback: bool = False) -> np.ndarray:
220
- import numpy as np
221
- from scipy.stats import qmc
222
-
223
- # For morbo, TR is created in tell() - if still None, return LHD
224
- if self._tr_state is None:
225
- return self._draw_initial(num_arms)
226
-
227
- if self._tr_state.needs_restart():
228
- self._tr_state.restart()
229
- should_reset_init, new_init_idx = self._mode_impl.handle_restart(
230
- self._x_obs_list,
231
- self._y_obs_list,
232
- self._yvar_obs_list,
233
- self._init_idx,
234
- self._num_init,
235
- )
236
- if should_reset_init:
237
- self._y_tr_list = []
238
- self._init_idx = new_init_idx
239
- self._init_lhd = from_unit(
240
- latin_hypercube(self._num_init, self._num_dim, rng=self._rng),
241
- self._bounds,
242
- )
243
- return self._get_init_lhd_points(num_arms)
244
-
245
- def from_unit_fn(x):
246
- return from_unit(x, self._bounds)
247
-
248
- if self._mode_impl.needs_tr_list() and len(self._x_obs_list) == 0:
249
- return self._get_init_lhd_points(num_arms)
250
-
251
- import time
252
-
253
- t0_fit = time.perf_counter()
254
- _gp_model, _gp_y_mean_fitted, _gp_y_std_fitted, lengthscales = (
255
- self._mode_impl.prepare_ask(
256
- self._x_obs_list,
257
- self._y_obs_list,
258
- self._yvar_obs_list,
259
- self._num_dim,
260
- self._gp_num_steps,
261
- rng=self._rng,
262
- )
263
- )
264
- self._dt_fit = time.perf_counter() - t0_fit
265
-
266
- x_center = self._mode_impl.get_x_center(
267
- self._x_obs_list,
268
- self._y_obs_list,
269
- self._rng,
270
- self._tr_state,
271
- )
272
- if x_center is None:
273
- if len(self._y_obs_list) == 0:
274
- raise RuntimeError("no observations")
275
- x_center = np.full(self._num_dim, 0.5)
276
-
277
- sobol_seed = self._sobol_seed_for_state(
278
- n_obs=len(self._x_obs_list),
279
- num_arms=num_arms,
280
- )
281
- sobol_engine = qmc.Sobol(d=self._num_dim, scramble=True, seed=sobol_seed)
282
- x_cand = self._tr_state.generate_candidates(
283
- x_center,
284
- lengthscales,
285
- self._num_candidates,
286
- self._rng,
287
- sobol_engine,
288
- )
289
-
290
- def fallback_fn(x, n):
291
- return select_uniform(x, n, self._num_dim, self._rng, from_unit_fn)
292
-
293
- self._tr_state.validate_request(num_arms, is_fallback=is_fallback)
294
-
295
- t0_sel = time.perf_counter()
296
- selected = self._mode_impl.select_candidates(
297
- x_cand,
298
- num_arms,
299
- self._num_dim,
300
- self._rng,
301
- fallback_fn,
302
- from_unit_fn,
303
- tr_state=self._tr_state,
304
- )
305
- self._dt_sel = time.perf_counter() - t0_sel
306
-
307
- # For morbo, TR is updated in tell() with raw multi-objective y
308
- if self._config.tr_type != "morbo":
309
- self._mode_impl.update_trust_region(
310
- self._tr_state,
311
- self._x_obs_list,
312
- self._y_tr_list,
313
- x_center=x_center,
314
- k=self._k,
315
- )
316
- return selected
317
-
318
- def _trim_trailing_obs(self) -> None:
319
- import numpy as np
320
-
321
- from .turbo_utils import argmax_random_tie
322
-
323
- if len(self._x_obs_list) <= self._trailing_obs:
324
- return
325
- y_tr_array = np.asarray(self._y_tr_list, dtype=float)
326
- incumbent_idx = argmax_random_tie(y_tr_array, rng=self._rng)
327
- num_total = len(self._x_obs_list)
328
- start_idx = max(0, num_total - self._trailing_obs)
329
- if incumbent_idx < start_idx:
330
- indices = np.array(
331
- [incumbent_idx]
332
- + list(range(num_total - (self._trailing_obs - 1), num_total)),
333
- dtype=int,
334
- )
335
- else:
336
- indices = np.arange(start_idx, num_total, dtype=int)
337
- if incumbent_idx not in indices:
338
- raise RuntimeError("Incumbent must be included in trimmed list")
339
- x_array = np.asarray(self._x_obs_list, dtype=float)
340
- incumbent_value = y_tr_array[incumbent_idx]
341
- self._x_obs_list = x_array[indices].tolist()
342
- y_obs_array = np.asarray(self._y_obs_list, dtype=float)
343
- self._y_obs_list = y_obs_array[indices].tolist()
344
- self._y_tr_list = y_tr_array[indices].tolist()
345
- if len(self._yvar_obs_list) == len(y_obs_array):
346
- yvar_array = np.asarray(self._yvar_obs_list, dtype=float)
347
- self._yvar_obs_list = yvar_array[indices].tolist()
348
- y_trimmed = np.asarray(self._y_tr_list, dtype=float)
349
- if not np.any(np.abs(y_trimmed - incumbent_value) < 1e-10):
350
- raise RuntimeError("Incumbent value must be preserved in trimmed list")
351
-
352
- def tell(
353
- self,
354
- x: np.ndarray,
355
- y: np.ndarray,
356
- y_var: np.ndarray | None = None,
357
- ) -> np.ndarray:
358
- import numpy as np
359
-
360
- x = np.asarray(x, dtype=float)
361
- y = np.asarray(y, dtype=float)
362
- if x.ndim != 2 or x.shape[1] != self._num_dim:
363
- raise ValueError(x.shape)
364
-
365
- # morbo accepts 2D y with shape (n, num_metrics)
366
- is_morbo = self._config.tr_type == "morbo"
367
- if is_morbo:
368
- if y.ndim == 1:
369
- y = y.reshape(-1, 1)
370
- if y.ndim != 2 or y.shape[0] != x.shape[0]:
371
- raise ValueError((x.shape, y.shape))
372
- num_metrics = y.shape[1]
373
- # Create TR lazily for morbo, inferring num_metrics from y
374
- if self._tr_state is None:
375
- self._tr_state = self._mode_impl.create_trust_region(
376
- self._num_dim, x.shape[0], self._rng, num_metrics=num_metrics
377
- )
378
- cfg_num_metrics = self._config.num_metrics
379
- if cfg_num_metrics is not None and num_metrics != cfg_num_metrics:
380
- raise ValueError(
381
- f"y has {num_metrics} metrics but expected {cfg_num_metrics}"
382
- )
383
- else:
384
- if self._tr_state is None:
385
- raise ValueError("tell() called before ask()")
386
- if y.ndim != 1 or y.shape[0] != x.shape[0]:
387
- raise ValueError((x.shape, y.shape))
388
-
389
- if self._expects_yvar is None:
390
- self._expects_yvar = y_var is not None
391
- if (y_var is not None) != bool(self._expects_yvar):
392
- raise ValueError(
393
- f"y_var must be {'provided' if self._expects_yvar else 'omitted'} on every tell() call"
394
- )
395
- if y_var is not None:
396
- y_var = np.asarray(y_var, dtype=float)
397
- if y_var.shape != y.shape:
398
- raise ValueError((y.shape, y_var.shape))
399
- if x.shape[0] == 0:
400
- return np.array([], dtype=float)
401
- x_unit = to_unit(x, self._bounds)
402
- self._x_obs_list.extend(x_unit.tolist())
403
-
404
- if is_morbo:
405
- y_estimate = y
406
- self._y_obs_list.extend(y.tolist())
407
- if y_var is not None:
408
- self._yvar_obs_list.extend(y_var.tolist())
409
- y_all = np.asarray(self._y_obs_list, dtype=float)
410
- if y_all.ndim == 1:
411
- y_all = y_all.reshape(-1, num_metrics)
412
- x_all = np.asarray(self._x_obs_list, dtype=float)
413
- self._tr_state.update_xy(x_all, y_all, k=self._k)
414
- else:
415
- from .turbo_mode import TurboMode
416
-
417
- self._y_obs_list.extend(y.tolist())
418
- if y_var is not None:
419
- self._yvar_obs_list.extend(y_var.tolist())
420
-
421
- if self._mode in (TurboMode.TURBO_ONE, TurboMode.TURBO_ENN):
422
- self._mode_impl.prepare_ask(
423
- self._x_obs_list,
424
- self._y_obs_list,
425
- self._yvar_obs_list,
426
- self._num_dim,
427
- 0,
428
- rng=self._rng,
429
- )
430
- x_all = np.asarray(self._x_obs_list, dtype=float)
431
- y_all = np.asarray(self._y_obs_list, dtype=float)
432
- if self._mode == TurboMode.TURBO_ONE:
433
- # We intentionally evaluate the GP posterior at the training inputs
434
- # (the observed points) right after conditioning the model. GPyTorch
435
- # warns about this in debug mode, but it's expected for our TR logic.
436
- import warnings
437
-
438
- try:
439
- from gpytorch.utils.warnings import GPInputWarning
440
- except Exception: # pragma: no cover
441
- GPInputWarning = None
442
-
443
- if GPInputWarning is None:
444
- mu_all = np.asarray(
445
- self._mode_impl.estimate_y(x_all, y_all), dtype=float
446
- ).reshape(-1)
447
- else:
448
- with warnings.catch_warnings():
449
- warnings.filterwarnings(
450
- "ignore",
451
- message=r"The input matches the stored training data\..*",
452
- category=GPInputWarning,
453
- )
454
- mu_all = np.asarray(
455
- self._mode_impl.estimate_y(x_all, y_all), dtype=float
456
- ).reshape(-1)
457
- else:
458
- mu_all = np.asarray(
459
- self._mode_impl.estimate_y(x_all, y_all), dtype=float
460
- ).reshape(-1)
461
- self._y_tr_list = mu_all.tolist()
462
- if self._mode == TurboMode.TURBO_ONE:
463
- import warnings
464
-
465
- try:
466
- from gpytorch.utils.warnings import GPInputWarning
467
- except Exception: # pragma: no cover
468
- GPInputWarning = None
469
-
470
- if GPInputWarning is None:
471
- y_estimate = np.asarray(
472
- self._mode_impl.estimate_y(x_unit, y), dtype=float
473
- )
474
- else:
475
- with warnings.catch_warnings():
476
- warnings.filterwarnings(
477
- "ignore",
478
- message=r"The input matches the stored training data\..*",
479
- category=GPInputWarning,
480
- )
481
- y_estimate = np.asarray(
482
- self._mode_impl.estimate_y(x_unit, y), dtype=float
483
- )
484
- else:
485
- y_estimate = np.asarray(
486
- self._mode_impl.estimate_y(x_unit, y), dtype=float
487
- )
488
- else:
489
- y_estimate = self._mode_impl.estimate_y(x_unit, y)
490
- self._y_tr_list.extend(np.asarray(y_estimate, dtype=float).tolist())
491
-
492
- if self._trailing_obs is not None:
493
- self._trim_trailing_obs()
494
- prev_n = int(getattr(self._tr_state, "prev_num_obs", 0))
495
- if prev_n > 0 and prev_n <= len(self._y_tr_list):
496
- if hasattr(self._tr_state, "best_value"):
497
- self._tr_state.best_value = float(
498
- np.max(np.asarray(self._y_tr_list, dtype=float)[:prev_n])
499
- )
500
- self._mode_impl.update_trust_region(
501
- self._tr_state, self._x_obs_list, self._y_tr_list, k=self._k
502
- )
503
-
504
- return y_estimate
505
-
506
- def _draw_initial(self, num_arms: int) -> np.ndarray:
507
- unit = latin_hypercube(num_arms, self._num_dim, rng=self._rng)
508
- return from_unit(unit, self._bounds)
509
-
510
- def _get_init_lhd_points(
511
- self, num_arms: int, fallback_fn: Callable[[int], np.ndarray] | None = None
512
- ) -> np.ndarray:
513
- import numpy as np
514
-
515
- remaining_init = self._num_init - self._init_idx
516
- num_to_return = min(num_arms, remaining_init)
517
- result = self._init_lhd[self._init_idx : self._init_idx + num_to_return]
518
- self._init_idx += num_to_return
519
- if num_to_return < num_arms:
520
- num_remaining = num_arms - num_to_return
521
- if fallback_fn is not None:
522
- result = np.vstack([result, fallback_fn(num_remaining)])
523
- else:
524
- result = np.vstack([result, self._draw_initial(num_remaining)])
525
- return result
@@ -1,29 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import TYPE_CHECKING, Callable
4
-
5
- if TYPE_CHECKING:
6
- import numpy as np
7
- from numpy.random import Generator
8
-
9
- from .base_turbo_impl import BaseTurboImpl
10
- from .turbo_config import TurboZeroConfig
11
-
12
-
13
class TurboZeroImpl(BaseTurboImpl):
    """Mode implementation whose candidate selection is ``select_uniform``.

    Everything except :meth:`select_candidates` is inherited from
    ``BaseTurboImpl``.
    """

    def __init__(self, config: TurboZeroConfig) -> None:
        super().__init__(config)

    def select_candidates(
        self,
        x_cand: np.ndarray,
        num_arms: int,
        num_dim: int,
        rng: Generator,
        fallback_fn: Callable[[np.ndarray, int], np.ndarray],  # noqa: ARG002
        from_unit_fn: Callable[[np.ndarray], np.ndarray],
        tr_state: object | None = None,  # noqa: ARG002
    ) -> np.ndarray:
        """Choose ``num_arms`` rows of ``x_cand`` with ``select_uniform``.

        ``fallback_fn`` and ``tr_state`` are accepted only to match the shared
        interface; they are not used here.
        """
        from .proposal import select_uniform as pick_uniform

        chosen = pick_uniform(x_cand, num_arms, num_dim, rng, from_unit_fn)
        return chosen
@@ -1,29 +0,0 @@
1
- enn/__init__.py,sha256=VYIuOTCjhUFIJm78IoJv0WXtvA_IuZhY1sSMJJM3dx8,507
2
- enn/enn/__init__.py,sha256=K3rntg_ZkITStmXMTBcEhxeS1kel1bb7wB_C7-2WE5Y,135
3
- enn/enn/enn.py,sha256=HfdrK2gXoI1JvvARsh4NdGOpVpCY2qY_A2RpK4JFVZ4,9310
4
- enn/enn/enn_fit.py,sha256=RkyFYX4-nUteGivNNS195M2mdRWiGOrLzypVL2b_FsE,4450
5
- enn/enn/enn_normal.py,sha256=Lm9n-eW5WRn33nb3b9xTGv44Dfn9xAhjys5UJZm2xlc,662
6
- enn/enn/enn_params.py,sha256=v53qHKwUxnZFNBlcSWI5WqpgpoyFeOzvWs0BxcDqt4o,747
7
- enn/enn/enn_util.py,sha256=PSeYmxZHz4xLJ6pMr9n22MLxvVmmgBqmQl1ckPrLzDo,4142
8
- enn/turbo/__init__.py,sha256=utnD3CLZgjCvw-46AAu5Tv2M2Vbg5YXK-_TycGk5BU4,197
9
- enn/turbo/base_turbo_impl.py,sha256=y4SP9FDT8OaNDOBhKvx1WuTksBd0d-vD8NQxj91QAuA,4396
10
- enn/turbo/lhd_only_impl.py,sha256=czqLTwhb8d-pKq6jy7g28JcDWMSoopLjNrDeP1dd-3A,1254
11
- enn/turbo/morbo_trust_region.py,sha256=9z_DgXHEEUfgajRnZ5ieJdeoVAUt7BlLNGu-Thi5Tg4,5973
12
- enn/turbo/no_trust_region.py,sha256=IxZB1nvmLFgb6GjAXrNpBe1TzgzwJDPYEB8wa_ZPX3k,1813
13
- enn/turbo/proposal.py,sha256=obFqVyXtZ49veqwnktTJ0_F0nqCERvgFnInstgqhllM,4252
14
- enn/turbo/turbo_config.py,sha256=tci_GODIED3UHE63XiB9XyxPu_0J5_8R90lxBVxidOQ,2500
15
- enn/turbo/turbo_enn_impl.py,sha256=qcUssC4xaoMu7usIlp5oJtXrsLY1Rn-mvHUW95zlWZQ,7249
16
- enn/turbo/turbo_gp.py,sha256=Hi11t0nw5YEG4WM6DeoOW4X-w-M5KGG-P3Zc1sPPx1k,1069
17
- enn/turbo/turbo_gp_base.py,sha256=tnE5uX_eAt1Db-gemyy83ZvKpdNbMg_tsWkh6sG7zaM,638
18
- enn/turbo/turbo_gp_noisy.py,sha256=itTL9jUCjE566jwDODT0P36fozsfU_bXACyuKqxYMXs,1080
19
- enn/turbo/turbo_mode.py,sha256=JMP1jkFCRwPtOzU95MWWd04Sgze7eKF0xNkiPqtQ8SI,181
20
- enn/turbo/turbo_mode_impl.py,sha256=ubUkV4reOPJH3jbAh6R65cutEHOF23z7Uw1bBE3s9T0,1923
21
- enn/turbo/turbo_one_impl.py,sha256=PXaBNdLKCgtsLC8Q2z4pHoHYu0edPqKAXb6Tmr4Guvs,11098
22
- enn/turbo/turbo_optimizer.py,sha256=h6Mu3Pqb2yQRrrqEh6ODxhKIt4Nnt-C241XrNluqxtk,20274
23
- enn/turbo/turbo_trust_region.py,sha256=0wlN_LhsfMeLqqjhq3xhkJpXtTYIzUon0zbvUjRm2_Q,3797
24
- enn/turbo/turbo_utils.py,sha256=bEe1F3hBUOUVmodF1WJQv_EKRmp4X74kMxWLNnuV-z0,10733
25
- enn/turbo/turbo_zero_impl.py,sha256=SvPexeUTQzDPbAwPdZib5lRcIHPOmwD3-ewMxED-nlQ,832
26
- ennbo-0.1.2.dist-info/METADATA,sha256=-Twds_sAT4LLkcN00vzmS9_a5jUSWRsqi1bO2_RKtHw,5960
27
- ennbo-0.1.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
28
- ennbo-0.1.2.dist-info/licenses/LICENSE,sha256=KTA0NjGalsl_JGrjT_x6SSq9ZYVO3gQ-hLVMEaekc5w,1070
29
- ennbo-0.1.2.dist-info/RECORD,,
File without changes