ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/__init__.py +25 -13
- enn/benchmarks/__init__.py +3 -0
- enn/benchmarks/ackley.py +5 -0
- enn/benchmarks/ackley_class.py +17 -0
- enn/benchmarks/ackley_core.py +12 -0
- enn/benchmarks/double_ackley.py +24 -0
- enn/enn/candidates.py +14 -0
- enn/enn/conditional_posterior_draw_internals.py +15 -0
- enn/enn/draw_internals.py +15 -0
- enn/enn/enn.py +16 -269
- enn/enn/enn_class.py +423 -0
- enn/enn/enn_conditional.py +325 -0
- enn/enn/enn_fit.py +69 -70
- enn/enn/enn_hash.py +79 -0
- enn/enn/enn_index.py +92 -0
- enn/enn/enn_like_protocol.py +35 -0
- enn/enn/enn_normal.py +0 -1
- enn/enn/enn_params.py +3 -22
- enn/enn/enn_params_class.py +24 -0
- enn/enn/enn_util.py +60 -46
- enn/enn/neighbor_data.py +14 -0
- enn/enn/neighbors.py +14 -0
- enn/enn/posterior_flags.py +8 -0
- enn/enn/weighted_stats.py +14 -0
- enn/turbo/components/__init__.py +41 -0
- enn/turbo/components/acquisition.py +13 -0
- enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
- enn/turbo/components/builder.py +22 -0
- enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
- enn/turbo/components/enn_surrogate.py +115 -0
- enn/turbo/components/gp_surrogate.py +144 -0
- enn/turbo/components/hnr_acq_optimizer.py +83 -0
- enn/turbo/components/incumbent_selector.py +11 -0
- enn/turbo/components/incumbent_selector_protocol.py +16 -0
- enn/turbo/components/no_incumbent_selector.py +21 -0
- enn/turbo/components/no_surrogate.py +49 -0
- enn/turbo/components/pareto_acq_optimizer.py +49 -0
- enn/turbo/components/posterior_result.py +12 -0
- enn/turbo/components/protocols.py +13 -0
- enn/turbo/components/random_acq_optimizer.py +21 -0
- enn/turbo/components/scalar_incumbent_selector.py +39 -0
- enn/turbo/components/surrogate_protocol.py +32 -0
- enn/turbo/components/surrogate_result.py +12 -0
- enn/turbo/components/surrogates.py +5 -0
- enn/turbo/components/thompson_acq_optimizer.py +49 -0
- enn/turbo/components/trust_region_protocol.py +24 -0
- enn/turbo/components/ucb_acq_optimizer.py +49 -0
- enn/turbo/config/__init__.py +87 -0
- enn/turbo/config/acq_type.py +8 -0
- enn/turbo/config/acquisition.py +26 -0
- enn/turbo/config/base.py +4 -0
- enn/turbo/config/candidate_gen_config.py +49 -0
- enn/turbo/config/candidate_rv.py +7 -0
- enn/turbo/config/draw_acquisition_config.py +14 -0
- enn/turbo/config/enn_index_driver.py +6 -0
- enn/turbo/config/enn_surrogate_config.py +42 -0
- enn/turbo/config/enums.py +7 -0
- enn/turbo/config/factory.py +118 -0
- enn/turbo/config/gp_surrogate_config.py +14 -0
- enn/turbo/config/hnr_optimizer_config.py +7 -0
- enn/turbo/config/init_config.py +17 -0
- enn/turbo/config/init_strategies/__init__.py +9 -0
- enn/turbo/config/init_strategies/hybrid_init.py +23 -0
- enn/turbo/config/init_strategies/init_strategy.py +19 -0
- enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
- enn/turbo/config/morbo_tr_config.py +82 -0
- enn/turbo/config/nds_optimizer_config.py +7 -0
- enn/turbo/config/no_surrogate_config.py +14 -0
- enn/turbo/config/no_tr_config.py +31 -0
- enn/turbo/config/optimizer_config.py +72 -0
- enn/turbo/config/pareto_acquisition_config.py +14 -0
- enn/turbo/config/raasp_driver.py +6 -0
- enn/turbo/config/raasp_optimizer_config.py +7 -0
- enn/turbo/config/random_acquisition_config.py +14 -0
- enn/turbo/config/rescalarize.py +7 -0
- enn/turbo/config/surrogate.py +12 -0
- enn/turbo/config/trust_region.py +34 -0
- enn/turbo/config/turbo_tr_config.py +71 -0
- enn/turbo/config/ucb_acquisition_config.py +14 -0
- enn/turbo/config/validation.py +45 -0
- enn/turbo/hypervolume.py +30 -0
- enn/turbo/impl_helpers.py +68 -0
- enn/turbo/morbo_trust_region.py +131 -70
- enn/turbo/no_trust_region.py +32 -39
- enn/turbo/optimizer.py +300 -0
- enn/turbo/optimizer_config.py +8 -0
- enn/turbo/proposal.py +36 -38
- enn/turbo/sampling.py +21 -0
- enn/turbo/strategies/__init__.py +9 -0
- enn/turbo/strategies/lhd_only_strategy.py +36 -0
- enn/turbo/strategies/optimization_strategy.py +19 -0
- enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
- enn/turbo/tr_helpers.py +202 -0
- enn/turbo/turbo_gp.py +0 -1
- enn/turbo/turbo_gp_base.py +0 -1
- enn/turbo/turbo_gp_fit.py +187 -0
- enn/turbo/turbo_gp_noisy.py +0 -1
- enn/turbo/turbo_optimizer_utils.py +98 -0
- enn/turbo/turbo_trust_region.py +126 -58
- enn/turbo/turbo_utils.py +98 -161
- enn/turbo/types/__init__.py +7 -0
- enn/turbo/types/appendable_array.py +85 -0
- enn/turbo/types/gp_data_prep.py +13 -0
- enn/turbo/types/gp_fit_result.py +11 -0
- enn/turbo/types/obs_lists.py +10 -0
- enn/turbo/types/prepare_ask_result.py +14 -0
- enn/turbo/types/tell_inputs.py +14 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
- ennbo-0.1.7.dist-info/RECORD +111 -0
- enn/enn/__init__.py +0 -4
- enn/turbo/__init__.py +0 -11
- enn/turbo/base_turbo_impl.py +0 -144
- enn/turbo/lhd_only_impl.py +0 -49
- enn/turbo/turbo_config.py +0 -72
- enn/turbo/turbo_enn_impl.py +0 -201
- enn/turbo/turbo_mode.py +0 -10
- enn/turbo/turbo_mode_impl.py +0 -76
- enn/turbo/turbo_one_impl.py +0 -302
- enn/turbo/turbo_optimizer.py +0 -525
- enn/turbo/turbo_zero_impl.py +0 -29
- ennbo-0.1.2.dist-info/RECORD +0 -29
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
from .candidate_gen_config import CandidateGenConfig
|
|
5
|
+
from .init_config import InitConfig
|
|
6
|
+
from .surrogate import NoSurrogateConfig, SurrogateConfig
|
|
7
|
+
from .trust_region import TrustRegionConfig, TurboTRConfig
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from .acquisition import AcqOptimizerConfig, AcquisitionConfig
|
|
11
|
+
from .enums import CandidateRV, RAASPDriver
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class ObservationHistoryConfig:
    """Immutable setting for how much observation history is kept.

    Attributes:
        trailing_obs: ``None`` (the default) or a strictly positive int.
            Presumably caps the history to the most recent ``trailing_obs``
            observations -- the consumer is not visible here, confirm there.
    """

    trailing_obs: int | None = None

    def __post_init__(self) -> None:
        """Reject a non-positive ``trailing_obs``.

        Raises:
            ValueError: if ``trailing_obs`` is set and ``<= 0``.
        """
        limit = self.trailing_obs
        if limit is None:
            # No limit requested; nothing to validate.
            return
        if limit <= 0:
            raise ValueError(f"trailing_obs must be > 0, got {self.trailing_obs}")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _default_acquisition():
    """Build a fresh ``RandomAcquisitionConfig``.

    Used as the ``default_factory`` of ``OptimizerConfig.acquisition``. The
    import happens at call time rather than at module import (presumably to
    avoid an import cycle with ``.acquisition`` -- confirm).
    """
    from .acquisition import RandomAcquisitionConfig as _DefaultAcquisition

    return _DefaultAcquisition()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _default_acq_optimizer():
    """Build a fresh ``RAASPOptimizerConfig``.

    Used as the ``default_factory`` of ``OptimizerConfig.acq_optimizer``. Like
    ``_default_acquisition``, the import is deferred to call time.
    """
    from .acquisition import RAASPOptimizerConfig as _DefaultOptimizer

    return _DefaultOptimizer()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(frozen=True)
class OptimizerConfig:
    """Immutable top-level optimizer configuration.

    Groups the trust-region, candidate-generation, initialization, surrogate,
    acquisition, acquisition-optimizer and observation-history sub-configs.
    Cross-field compatibility is checked on construction by
    ``validate_optimizer_config``.
    """

    trust_region: TrustRegionConfig = TurboTRConfig()
    candidates: CandidateGenConfig = CandidateGenConfig()
    init: InitConfig = InitConfig()
    surrogate: SurrogateConfig = NoSurrogateConfig()
    # The two acquisition defaults come from factory functions whose imports
    # run at call time instead of at module import.
    acquisition: AcquisitionConfig = field(default_factory=_default_acquisition)
    acq_optimizer: AcqOptimizerConfig = field(default_factory=_default_acq_optimizer)
    observation_history: ObservationHistoryConfig = ObservationHistoryConfig()

    def __post_init__(self) -> None:
        """Run cross-field validation; raises whatever the validator raises."""
        from .validation import validate_optimizer_config as _validate

        _validate(self)

    @property
    def num_metrics(self) -> int | None:
        """Metric count of a ``MorboTRConfig`` trust region, else ``None``."""
        from .morbo_tr_config import MorboTRConfig

        region = self.trust_region
        return region.num_metrics if isinstance(region, MorboTRConfig) else None

    @property
    def candidate_rv(self) -> CandidateRV:
        """Shortcut for ``self.candidates.candidate_rv``."""
        gen = self.candidates
        return gen.candidate_rv

    @property
    def raasp_driver(self) -> RAASPDriver:
        """Shortcut for ``self.candidates.raasp_driver``."""
        gen = self.candidates
        return gen.raasp_driver

    @property
    def num_candidates(self):
        """Shortcut for ``self.candidates.num_candidates``."""
        gen = self.candidates
        return gen.num_candidates

    @property
    def trailing_obs(self) -> int | None:
        """Shortcut for ``self.observation_history.trailing_obs``."""
        history = self.observation_history
        return history.trailing_obs
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from ..components.protocols import AcquisitionOptimizer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class ParetoAcquisitionConfig:
    """Field-less, immutable config that selects the Pareto acquisition optimizer."""

    def build(self) -> AcquisitionOptimizer:
        """Return a new ``ParetoAcqOptimizer`` instance (imported at call time)."""
        from ..components.acquisition import ParetoAcqOptimizer as _Optimizer

        optimizer = _Optimizer()
        return optimizer
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from ..components.protocols import AcquisitionOptimizer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class RandomAcquisitionConfig:
    """Field-less, immutable config that selects the random acquisition optimizer."""

    def build(self) -> AcquisitionOptimizer:
        """Return a new ``RandomAcqOptimizer`` instance (imported at call time)."""
        from ..components.acquisition import RandomAcqOptimizer as _Optimizer

        optimizer = _Optimizer()
        return optimizer
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .enn_surrogate_config import ENNFitConfig, ENNSurrogateConfig
|
|
2
|
+
from .gp_surrogate_config import GPSurrogateConfig
|
|
3
|
+
from .no_surrogate_config import NoSurrogateConfig
|
|
4
|
+
|
|
5
|
+
# Union of all supported surrogate config classes. It is evaluated eagerly at
# import time, so ``|`` between classes requires Python 3.10+.
SurrogateConfig = (
    NoSurrogateConfig
    | GPSurrogateConfig
    | ENNSurrogateConfig
)

# Public re-exports, listed in sorted order.
__all__ = [
    "ENNFitConfig", "ENNSurrogateConfig", "GPSurrogateConfig",
    "NoSurrogateConfig", "SurrogateConfig",
]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Protocol
|
|
4
|
+
|
|
5
|
+
from .morbo_tr_config import MorboTRConfig, MultiObjectiveConfig, RescalePolicyConfig
|
|
6
|
+
from .no_tr_config import NoTRConfig
|
|
7
|
+
from .turbo_tr_config import TRLengthConfig, TurboTRConfig
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from numpy.random import Generator
|
|
11
|
+
|
|
12
|
+
from ..components.protocols import TrustRegion
|
|
13
|
+
from .enums import CandidateRV
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TrustRegionConfig(Protocol):
    """Structural interface for trust-region configs.

    Any object with a matching keyword-only ``build`` method satisfies it.
    """

    def build(
        self, *, num_dim: int, rng: Generator, candidate_rv: CandidateRV
    ) -> TrustRegion:
        """Construct the trust region described by this config."""
        ...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Public re-exports of the trust-region config package, listed in sorted order.
__all__ = [
    "MorboTRConfig", "MultiObjectiveConfig", "NoTRConfig",
    "RescalePolicyConfig", "TRLengthConfig", "TrustRegionConfig",
    "TurboTRConfig",
]
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from numpy.random import Generator
|
|
8
|
+
|
|
9
|
+
from ..components.protocols import TrustRegion
|
|
10
|
+
from .enums import CandidateRV
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
class TRLengthConfig:
    """Initial, minimum and maximum trust-region lengths.

    Validated on construction. All three lengths must be strictly positive
    (NaN is rejected), and they must satisfy
    ``length_min <= length_init <= length_max`` with
    ``length_min < length_max``.
    """

    length_init: float = 0.8
    length_min: float = 0.5**7
    length_max: float = 1.6

    def __post_init__(self) -> None:
        """Validate the length bounds.

        Raises:
            ValueError: if any length is not strictly positive (including
                NaN), or if the bounds are mutually inconsistent.
        """
        for name in ("length_init", "length_min", "length_max"):
            value = getattr(self, name)
            # ``not (value > 0)`` rather than ``value <= 0``: every comparison
            # with NaN is False, so ``nan <= 0`` used to let NaN through.
            if not (value > 0):
                raise ValueError(f"{name} must be > 0, got {value}")
        if self.length_min >= self.length_max:
            raise ValueError(
                f"length_min must be < length_max, got {self.length_min} >= {self.length_max}"
            )
        if self.length_init > self.length_max:
            raise ValueError(
                f"length_init must be <= length_max, got {self.length_init} > {self.length_max}"
            )
        if self.length_min > self.length_init:
            raise ValueError(
                f"length_min must be <= length_init, got {self.length_min} > {self.length_init}"
            )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class TurboTRConfig:
    """Immutable config for a TuRBO-style trust region.

    Attributes:
        length: Length bounds (initial / min / max) for the region.
        noise_aware: Forwarded to the ``ScalarIncumbentSelector`` built in
            :meth:`build`.
    """

    length: TRLengthConfig = TRLengthConfig()
    noise_aware: bool = False

    @property
    def length_init(self) -> float:
        """Shortcut for ``self.length.length_init``."""
        bounds = self.length
        return bounds.length_init

    @property
    def length_min(self) -> float:
        """Shortcut for ``self.length.length_min``."""
        bounds = self.length
        return bounds.length_min

    @property
    def length_max(self) -> float:
        """Shortcut for ``self.length.length_max``."""
        bounds = self.length
        return bounds.length_max

    def build(
        self,
        *,
        num_dim: int,
        rng: Generator,
        candidate_rv: CandidateRV | None = None,
    ) -> TrustRegion:
        """Create a ``TurboTrustRegion`` driven by this config.

        ``rng`` and ``candidate_rv`` are accepted so the signature matches the
        ``TrustRegionConfig`` protocol, but they are not used here.
        """
        from ..components.incumbent_selector import ScalarIncumbentSelector
        from ..turbo_trust_region import TurboTrustRegion

        selector = ScalarIncumbentSelector(noise_aware=self.noise_aware)
        return TurboTrustRegion(
            config=self, num_dim=num_dim, incumbent_selector=selector
        )
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from ..components.protocols import AcquisitionOptimizer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class UCBAcquisitionConfig:
    """Field-less, immutable config that selects the UCB acquisition optimizer."""

    def build(self) -> AcquisitionOptimizer:
        """Return a new ``UCBAcqOptimizer`` instance (imported at call time)."""
        from ..components.acquisition import UCBAcqOptimizer as _Optimizer

        optimizer = _Optimizer()
        return optimizer
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def validate_optimizer_config(cfg: Any) -> None:
    """Reject incompatible combinations of optimizer sub-configs.

    ``cfg`` is duck-typed. Only ``cfg.init.init_strategy``, ``cfg.surrogate``,
    ``cfg.acquisition`` and ``cfg.acq_optimizer`` are read.

    Raises:
        ValueError: if
            * the LHD-only init strategy is paired with any surrogate other
              than ``NoSurrogateConfig``;
            * ``NoSurrogateConfig`` is paired with ``DrawAcquisitionConfig``
              (Thompson sampling) or ``UCBAcquisitionConfig``;
            * ``HnROptimizerConfig`` is paired with ``ParetoAcquisitionConfig``;
            * ``ParetoAcquisitionConfig`` is paired with any acquisition
              optimizer other than ``NDSOptimizerConfig``.
        NotImplementedError: for a GP surrogate with ``DrawAcquisitionConfig``
            and ``HnROptimizerConfig``.
    """
    # Imported at call time rather than at module level.
    from .acquisition import (
        DrawAcquisitionConfig,
        HnROptimizerConfig,
        NDSOptimizerConfig,
        ParetoAcquisitionConfig,
        UCBAcquisitionConfig,
    )
    from .init_strategies import LHDOnlyInit
    from .surrogate import GPSurrogateConfig, NoSurrogateConfig

    init_strategy = cfg.init.init_strategy
    surrogate = cfg.surrogate
    acquisition = cfg.acquisition
    acq_optimizer = cfg.acq_optimizer

    if isinstance(init_strategy, LHDOnlyInit) and not isinstance(
        surrogate, NoSurrogateConfig
    ):
        raise ValueError(
            "init_strategy='lhd_only' requires NoSurrogateConfig surrogate"
        )

    if isinstance(surrogate, NoSurrogateConfig):
        if isinstance(acquisition, DrawAcquisitionConfig):
            raise ValueError(
                "DrawAcquisitionConfig (Thompson sampling) requires a surrogate. "
                "NoSurrogateConfig is not compatible with DrawAcquisitionConfig."
            )
        if isinstance(acquisition, UCBAcquisitionConfig):
            raise ValueError(
                "UCBAcquisitionConfig requires a surrogate. "
                "NoSurrogateConfig is not compatible with UCBAcquisitionConfig."
            )

    # The HnR checks run BEFORE the generic Pareto/NDS check below. In the
    # previous order, a Pareto + HnR config always failed first with
    # "requires NDSOptimizerConfig", which made the dedicated HnR/Pareto error
    # unreachable. (This assumes HnROptimizerConfig does not subclass
    # NDSOptimizerConfig; they live in separate config modules.)
    if isinstance(acq_optimizer, HnROptimizerConfig):
        if isinstance(acquisition, ParetoAcquisitionConfig):
            raise ValueError(
                "HnROptimizerConfig is not compatible with ParetoAcquisitionConfig"
            )
        if isinstance(surrogate, GPSurrogateConfig) and isinstance(
            acquisition, DrawAcquisitionConfig
        ):
            raise NotImplementedError(
                "GP surrogate with DrawAcquisitionConfig and HnROptimizerConfig "
                "is not yet implemented"
            )

    if isinstance(acquisition, ParetoAcquisitionConfig) and not isinstance(
        acq_optimizer, NDSOptimizerConfig
    ):
        raise ValueError("ParetoAcquisitionConfig requires NDSOptimizerConfig")
|
enn/turbo/hypervolume.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def hypervolume_2d_max(y: np.ndarray, ref_point: np.ndarray) -> float:
    """Return the 2-D hypervolume dominated by ``y`` relative to ``ref_point``.

    Both objectives are maximized. The result is the area of the union of
    the boxes spanned between ``ref_point`` and every point that strictly
    exceeds ``ref_point`` in both coordinates. All other points are ignored.

    Args:
        y: Array-like of shape ``(n, 2)``, one row per point. An empty input
            returns ``0.0`` without further validation.
        ref_point: Array-like of shape ``(2,)``.

    Returns:
        The dominated area as a Python ``float``.

    Raises:
        ValueError: if a non-empty ``y`` is not of shape ``(n, 2)``, or if
            ``ref_point`` is not of shape ``(2,)``.
    """
    y = np.asarray(y, dtype=float)
    ref_point = np.asarray(ref_point, dtype=float)
    if y.size == 0:
        return 0.0
    if y.ndim != 2 or y.shape[1] != 2:
        raise ValueError(f"y must have shape (n, 2), got {y.shape}")
    if ref_point.shape != (2,):
        raise ValueError(f"ref_point must have shape (2,), got {ref_point.shape}")

    ref_x, ref_y = ref_point[0], ref_point[1]
    # Keep only points strictly better than the reference in both objectives.
    # NaN rows fail both comparisons and are dropped here.
    pts = y[(y[:, 0] > ref_x) & (y[:, 1] > ref_y)]
    if pts.size == 0:
        return 0.0

    # Sweep from the largest to the smallest first objective. The vertical
    # strip between consecutive x values is covered up to the best second
    # objective seen so far.
    order = np.argsort(pts[:, 0], kind="mergesort")[::-1]
    pts = pts[order]
    hv = 0.0
    best_y1 = ref_y
    for i, (x0, y1) in enumerate(pts):
        if y1 > best_y1:
            best_y1 = y1
        x_next = pts[i + 1, 0] if i + 1 < len(pts) else ref_x
        hv += (x0 - x_next) * (best_y1 - ref_y)
    return float(hv)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.random import Generator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_x_center_fallback(
    config: Any,
    x_obs_list: list,
    y_obs_list: list,
    rng: Generator,
    tr_state: Any = None,
) -> np.ndarray | None:
    """Pick the observed x of the incumbent, or ``None`` with no observations.

    If ``tr_state`` has an ``incumbent_selector`` attribute, that selector is
    used. Otherwise a ``ScalarIncumbentSelector(noise_aware=False)`` is used.
    ``config`` is accepted but not read.

    The selector's ``select(y, None, rng)`` result is used to index the
    stacked ``x`` observations. It is presumably a row index; confirm against
    the selector implementations.
    """
    import numpy as np

    from .components.incumbent_selector import ScalarIncumbentSelector

    ys = np.asarray(y_obs_list, dtype=float)
    if ys.size == 0:
        return None
    xs = np.asarray(x_obs_list, dtype=float)
    use_tr_selector = tr_state is not None and hasattr(tr_state, "incumbent_selector")
    selector = (
        tr_state.incumbent_selector
        if use_tr_selector
        else ScalarIncumbentSelector(noise_aware=False)
    )
    best = selector.select(ys, None, rng)
    return xs[best]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def handle_restart_clear_always(
    x_obs_list: list,
    y_obs_list: list,
    yvar_obs_list: list,
) -> tuple[bool, int]:
    """Empty all three observation lists in place.

    Returns:
        ``(True, 0)``: the lists were cleared, with 0 as the new index.
    """
    for bucket in (x_obs_list, y_obs_list, yvar_obs_list):
        bucket.clear()
    return True, 0
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def handle_restart_check_multi_objective(
    tr_state: Any,
    x_obs_list: list,
    y_obs_list: list,
    yvar_obs_list: list,
    init_idx: int,
) -> tuple[bool, int]:
    """Clear the observation lists in place only for multi-metric states.

    A state counts as multi-metric when it is not ``None``, has a
    ``num_metrics`` attribute, and that attribute is ``> 1``.

    Returns:
        ``(True, 0)`` after clearing; otherwise ``(False, init_idx)`` with the
        lists untouched.
    """
    if tr_state is None or not hasattr(tr_state, "num_metrics"):
        return False, init_idx
    if not (tr_state.num_metrics > 1):
        return False, init_idx
    for bucket in (x_obs_list, y_obs_list, yvar_obs_list):
        bucket.clear()
    return True, 0
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def estimate_y_passthrough(y_observed: np.ndarray) -> np.ndarray:
    """Return the observations unchanged, as a float array.

    A 1-D input is reshaped to a single column ``(n, 1)``. Any other rank is
    returned as-is after conversion to ``float``.
    """
    import numpy as np

    arr = np.asarray(y_observed, dtype=float)
    return arr.reshape(-1, 1) if arr.ndim == 1 else arr
|