ennbo 0.1.0__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -229
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +77 -76
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +3 -3
  18. enn/enn/enn_params.py +3 -9
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +79 -37
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +250 -0
  84. enn/turbo/no_trust_region.py +58 -0
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +46 -39
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +9 -2
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +129 -63
  100. enn/turbo/turbo_utils.py +144 -117
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/METADATA +22 -14
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -98
  113. enn/turbo/lhd_only_impl.py +0 -42
  114. enn/turbo/turbo_config.py +0 -28
  115. enn/turbo/turbo_enn_impl.py +0 -176
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -67
  118. enn/turbo/turbo_one_impl.py +0 -163
  119. enn/turbo/turbo_optimizer.py +0 -337
  120. enn/turbo/turbo_zero_impl.py +0 -24
  121. ennbo-0.1.0.dist-info/RECORD +0 -27
  122. {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass, field
3
+ from typing import TYPE_CHECKING
4
+ from .candidate_gen_config import CandidateGenConfig
5
+ from .init_config import InitConfig
6
+ from .surrogate import NoSurrogateConfig, SurrogateConfig
7
+ from .trust_region import TrustRegionConfig, TurboTRConfig
8
+
9
+ if TYPE_CHECKING:
10
+ from .acquisition import AcqOptimizerConfig, AcquisitionConfig
11
+ from .enums import CandidateRV, RAASPDriver
12
+
13
+
14
@dataclass(frozen=True)
class ObservationHistoryConfig:
    """Immutable holder for the optional ``trailing_obs`` setting.

    ``trailing_obs`` is either ``None`` (the default) or a strictly positive
    number; anything ``<= 0`` is rejected at construction time.
    """

    trailing_obs: int | None = None

    def __post_init__(self) -> None:
        """Validate ``trailing_obs``.

        Raises:
            ValueError: if ``trailing_obs`` is set and is not greater than zero.
        """
        # ``None`` means "not configured" and is always accepted.
        if self.trailing_obs is None:
            return
        if self.trailing_obs <= 0:
            raise ValueError(f"trailing_obs must be > 0, got {self.trailing_obs}")
21
+
22
+
23
def _default_acquisition():
    """Build a fresh ``RandomAcquisitionConfig`` for use as a dataclass default.

    The import is done at call time rather than at module import time.
    NOTE(review): presumably this avoids a circular import with
    ``.acquisition`` -- confirm before hoisting it to the top of the module.
    """
    from .acquisition import RandomAcquisitionConfig

    default_config = RandomAcquisitionConfig()
    return default_config
27
+
28
+
29
def _default_acq_optimizer():
    """Build a fresh ``RAASPOptimizerConfig`` for use as a dataclass default.

    The import is done at call time rather than at module import time.
    NOTE(review): presumably this avoids a circular import with
    ``.acquisition`` -- confirm before hoisting it to the top of the module.
    """
    from .acquisition import RAASPOptimizerConfig

    default_config = RAASPOptimizerConfig()
    return default_config
33
+
34
+
35
@dataclass(frozen=True)
class OptimizerConfig:
    """Top-level, immutable optimizer configuration.

    Groups the trust-region, candidate-generation, initialisation, surrogate,
    acquisition, acquisition-optimizer and observation-history settings. The
    combination is passed to ``validate_optimizer_config`` as soon as the
    instance has been constructed.
    """

    # Defaults that are built once at class-definition time (frozen instances).
    trust_region: TrustRegionConfig = TurboTRConfig()
    candidates: CandidateGenConfig = CandidateGenConfig()
    init: InitConfig = InitConfig()
    surrogate: SurrogateConfig = NoSurrogateConfig()
    # Defaults that are built per instance through lazily importing factories.
    acquisition: AcquisitionConfig = field(default_factory=_default_acquisition)
    acq_optimizer: AcqOptimizerConfig = field(default_factory=_default_acq_optimizer)
    observation_history: ObservationHistoryConfig = ObservationHistoryConfig()

    def __post_init__(self) -> None:
        """Run the cross-field validation on the freshly built instance."""
        from .validation import validate_optimizer_config

        validate_optimizer_config(self)

    @property
    def num_metrics(self) -> int | None:
        """``num_metrics`` of a ``MorboTRConfig`` trust region, else ``None``."""
        from .morbo_tr_config import MorboTRConfig

        region = self.trust_region
        return region.num_metrics if isinstance(region, MorboTRConfig) else None

    @property
    def candidate_rv(self) -> CandidateRV:
        """Shortcut for ``self.candidates.candidate_rv``."""
        candidate_settings = self.candidates
        return candidate_settings.candidate_rv

    @property
    def raasp_driver(self) -> RAASPDriver:
        """Shortcut for ``self.candidates.raasp_driver``."""
        candidate_settings = self.candidates
        return candidate_settings.raasp_driver

    @property
    def num_candidates(self):
        """Shortcut for ``self.candidates.num_candidates``."""
        candidate_settings = self.candidates
        return candidate_settings.num_candidates

    @property
    def trailing_obs(self) -> int | None:
        """Shortcut for ``self.observation_history.trailing_obs``."""
        history_settings = self.observation_history
        return history_settings.trailing_obs
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import AcquisitionOptimizer
7
+
8
+
9
@dataclass(frozen=True)
class ParetoAcquisitionConfig:
    """Field-less, immutable config whose ``build`` yields a ``ParetoAcqOptimizer``."""

    def build(self) -> AcquisitionOptimizer:
        """Return a new ``ParetoAcqOptimizer`` (imported on first use)."""
        from ..components.acquisition import ParetoAcqOptimizer

        optimizer = ParetoAcqOptimizer()
        return optimizer
@@ -0,0 +1,6 @@
1
+ from enum import Enum, auto
2
+
3
+
4
class RAASPDriver(Enum):
    """Closed set of RAASP driver choices; values are assigned by ``auto()``."""

    # Declaration order matters: ``auto()`` numbers the members 1, 2, ...
    ORIG = auto()
    FAST = auto()
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass(frozen=True)
class RAASPOptimizerConfig:
    """Field-less, immutable marker config selecting the RAASP optimizer."""
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import AcquisitionOptimizer
7
+
8
+
9
@dataclass(frozen=True)
class RandomAcquisitionConfig:
    """Field-less, immutable config whose ``build`` yields a ``RandomAcqOptimizer``."""

    def build(self) -> AcquisitionOptimizer:
        """Return a new ``RandomAcqOptimizer`` (imported on first use)."""
        from ..components.acquisition import RandomAcqOptimizer

        optimizer = RandomAcqOptimizer()
        return optimizer
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+ from enum import Enum
3
+
4
+
5
class Rescalarize(Enum):
    """String-valued choices for when rescalarization takes place."""

    # Values are the lowercase spellings of the member names.
    ON_RESTART = "on_restart"
    ON_PROPOSE = "on_propose"
@@ -0,0 +1,12 @@
1
+ from .enn_surrogate_config import ENNFitConfig, ENNSurrogateConfig
2
+ from .gp_surrogate_config import GPSurrogateConfig
3
+ from .no_surrogate_config import NoSurrogateConfig
4
+
5
+ SurrogateConfig = NoSurrogateConfig | GPSurrogateConfig | ENNSurrogateConfig
6
+ __all__ = [
7
+ "ENNFitConfig",
8
+ "ENNSurrogateConfig",
9
+ "GPSurrogateConfig",
10
+ "NoSurrogateConfig",
11
+ "SurrogateConfig",
12
+ ]
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Protocol
4
+
5
+ from .morbo_tr_config import MorboTRConfig, MultiObjectiveConfig, RescalePolicyConfig
6
+ from .no_tr_config import NoTRConfig
7
+ from .turbo_tr_config import TRLengthConfig, TurboTRConfig
8
+
9
+ if TYPE_CHECKING:
10
+ from numpy.random import Generator
11
+
12
+ from ..components.protocols import TrustRegion
13
+ from .enums import CandidateRV
14
+
15
+
16
class TrustRegionConfig(Protocol):
    """Structural interface shared by the trust-region configuration classes.

    Any object exposing a compatible keyword-only ``build`` method satisfies
    this protocol.
    """

    def build(
        self,
        *,
        num_dim: int,
        rng: Generator,
        candidate_rv: CandidateRV,
    ) -> TrustRegion:
        """Create the trust region described by this configuration."""
        ...
24
+
25
+
26
+ __all__ = [
27
+ "MorboTRConfig",
28
+ "MultiObjectiveConfig",
29
+ "NoTRConfig",
30
+ "RescalePolicyConfig",
31
+ "TRLengthConfig",
32
+ "TrustRegionConfig",
33
+ "TurboTRConfig",
34
+ ]
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING:
7
+ from numpy.random import Generator
8
+
9
+ from ..components.protocols import TrustRegion
10
+ from .enums import CandidateRV
11
+
12
+
13
@dataclass(frozen=True)
class TRLengthConfig:
    """Immutable initial / minimum / maximum trust-region length settings.

    Construction enforces ``0 < length_min <= length_init <= length_max`` and
    ``length_min < length_max``. NaN values are rejected as well.
    """

    length_init: float = 0.8
    length_min: float = 0.5**7
    length_max: float = 1.6

    def __post_init__(self) -> None:
        """Validate the three lengths.

        Raises:
            ValueError: if any length is not strictly positive (this includes
                NaN), if ``length_min >= length_max``, if
                ``length_init > length_max`` or if ``length_min > length_init``.
        """
        for name in ("length_init", "length_min", "length_max"):
            value = getattr(self, name)
            # ``not value > 0`` instead of ``value <= 0``: every comparison
            # against NaN is False, so the latter silently lets NaN through
            # this check and all the ordering checks below.
            if not value > 0:
                raise ValueError(f"{name} must be > 0, got {value}")
        if self.length_min >= self.length_max:
            raise ValueError(
                f"length_min must be < length_max, got {self.length_min} >= {self.length_max}"
            )
        if self.length_init > self.length_max:
            raise ValueError(
                f"length_init must be <= length_max, got {self.length_init} > {self.length_max}"
            )
        if self.length_min > self.length_init:
            raise ValueError(
                f"length_min must be <= length_init, got {self.length_min} > {self.length_init}"
            )
38
+
39
+
40
@dataclass(frozen=True)
class TurboTRConfig:
    """Immutable configuration from which a ``TurboTrustRegion`` is built.

    Attributes are the nested ``TRLengthConfig`` and the ``noise_aware`` flag
    that is forwarded to the incumbent selector.
    """

    length: TRLengthConfig = TRLengthConfig()
    noise_aware: bool = False

    @property
    def length_init(self) -> float:
        """Shortcut for ``self.length.length_init``."""
        lengths = self.length
        return lengths.length_init

    @property
    def length_min(self) -> float:
        """Shortcut for ``self.length.length_min``."""
        lengths = self.length
        return lengths.length_min

    @property
    def length_max(self) -> float:
        """Shortcut for ``self.length.length_max``."""
        lengths = self.length
        return lengths.length_max

    def build(
        self,
        *,
        num_dim: int,
        rng: Generator,
        candidate_rv: CandidateRV | None = None,
    ) -> TrustRegion:
        """Create a ``TurboTrustRegion`` with a ``ScalarIncumbentSelector``.

        ``rng`` and ``candidate_rv`` are accepted but not used by this
        implementation; only ``num_dim`` and this config are passed on.
        """
        from ..components.incumbent_selector import ScalarIncumbentSelector
        from ..turbo_trust_region import TurboTrustRegion

        selector = ScalarIncumbentSelector(noise_aware=self.noise_aware)
        return TurboTrustRegion(
            config=self,
            num_dim=num_dim,
            incumbent_selector=selector,
        )
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import AcquisitionOptimizer
7
+
8
+
9
@dataclass(frozen=True)
class UCBAcquisitionConfig:
    """Field-less, immutable config whose ``build`` yields a ``UCBAcqOptimizer``."""

    def build(self) -> AcquisitionOptimizer:
        """Return a new ``UCBAcqOptimizer`` (imported on first use)."""
        from ..components.acquisition import UCBAcqOptimizer

        optimizer = UCBAcqOptimizer()
        return optimizer
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+ from typing import Any
3
+
4
+
5
def validate_optimizer_config(cfg: Any) -> None:
    """Reject incompatible combinations of optimizer sub-configurations.

    Reads ``cfg.init.init_strategy``, ``cfg.surrogate``, ``cfg.acquisition``
    and ``cfg.acq_optimizer`` and checks them pairwise by type.

    Raises:
        ValueError: for an unsupported combination.
        NotImplementedError: for GP surrogate + draw acquisition + HnR
            optimizer.
    """
    from .acquisition import (
        DrawAcquisitionConfig,
        HnROptimizerConfig,
        NDSOptimizerConfig,
        ParetoAcquisitionConfig,
        UCBAcquisitionConfig,
    )
    from .init_strategies import LHDOnlyInit
    from .surrogate import GPSurrogateConfig, NoSurrogateConfig

    # LHD-only initialisation is only allowed together with "no surrogate".
    if isinstance(cfg.init.init_strategy, LHDOnlyInit) and not isinstance(
        cfg.surrogate, NoSurrogateConfig
    ):
        raise ValueError(
            "init_strategy='lhd_only' requires NoSurrogateConfig surrogate"
        )

    # Draw and UCB acquisitions are refused when there is no surrogate.
    if isinstance(cfg.surrogate, NoSurrogateConfig) and isinstance(
        cfg.acquisition, DrawAcquisitionConfig
    ):
        raise ValueError(
            "DrawAcquisitionConfig (Thompson sampling) requires a surrogate. "
            "NoSurrogateConfig is not compatible with DrawAcquisitionConfig."
        )
    if isinstance(cfg.surrogate, NoSurrogateConfig) and isinstance(
        cfg.acquisition, UCBAcquisitionConfig
    ):
        raise ValueError(
            "UCBAcquisitionConfig requires a surrogate. "
            "NoSurrogateConfig is not compatible with UCBAcquisitionConfig."
        )

    # Pareto acquisition must be paired with the NDS optimizer.
    if isinstance(cfg.acquisition, ParetoAcquisitionConfig) and not isinstance(
        cfg.acq_optimizer, NDSOptimizerConfig
    ):
        raise ValueError("ParetoAcquisitionConfig requires NDSOptimizerConfig")

    # NOTE(review): this looks unreachable unless HnROptimizerConfig is also an
    # NDSOptimizerConfig -- the check above already rejects Pareto acquisition
    # with any non-NDS optimizer. Confirm against the class hierarchy.
    if isinstance(cfg.acq_optimizer, HnROptimizerConfig) and isinstance(
        cfg.acquisition, ParetoAcquisitionConfig
    ):
        raise ValueError(
            "HnROptimizerConfig is not compatible with ParetoAcquisitionConfig"
        )

    if (
        isinstance(cfg.acq_optimizer, HnROptimizerConfig)
        and isinstance(cfg.surrogate, GPSurrogateConfig)
        and isinstance(cfg.acquisition, DrawAcquisitionConfig)
    ):
        raise NotImplementedError(
            "GP surrogate with DrawAcquisitionConfig and HnROptimizerConfig is not yet implemented"
        )
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+
4
+
5
def hypervolume_2d_max(y: np.ndarray, ref_point: np.ndarray) -> float:
    """Return the 2-D hypervolume dominated by ``y`` under maximisation.

    Args:
        y: array-like of shape ``(n, 2)``; rows are objective vectors.
        ref_point: array-like of shape ``(2,)``; only rows strictly greater
            than it in both coordinates contribute.

    Returns:
        The area of the union of the rectangles spanned by ``ref_point`` and
        each contributing row, or ``0.0`` when nothing contributes.

    Raises:
        ValueError: if a non-empty ``y`` is not ``(n, 2)`` or ``ref_point`` is
            not ``(2,)``. An empty ``y`` returns before any shape check.
    """
    points = np.asarray(y, dtype=float)
    ref = np.asarray(ref_point, dtype=float)
    if points.size == 0:
        return 0.0
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError(points.shape)
    if ref.shape != (2,):
        raise ValueError(ref.shape)

    # Keep only the rows that strictly dominate the reference point.
    dominating = (points[:, 0] > ref[0]) & (points[:, 1] > ref[1])
    points = points[dominating]
    if points.size == 0:
        return 0.0

    # Stable ascending sort on the first objective, then reversed: sweep from
    # the largest first coordinate down towards the reference point.
    descending = np.argsort(points[:, 0], kind="mergesort")[::-1]
    points = points[descending]
    firsts = points[:, 0]
    seconds = points[:, 1]
    # Right-to-left strip boundaries: the next smaller first coordinate, and
    # the reference point for the last strip.
    lower_edges = np.append(firsts[1:], ref[0])

    total = 0.0
    tallest = ref[1]
    for upper, lower, second in zip(firsts, lower_edges, seconds):
        # Strip height is the best second objective seen so far.
        if second > tallest:
            tallest = second
        total += (upper - lower) * (tallest - ref[1])
    return float(total)
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Any
3
+
4
+ if TYPE_CHECKING:
5
+ import numpy as np
6
+ from numpy.random import Generator
7
+
8
+
9
def get_x_center_fallback(
    config: Any,
    x_obs_list: list,
    y_obs_list: list,
    rng: Generator,
    tr_state: Any = None,
) -> np.ndarray | None:
    """Pick a centre point from the observed inputs via an incumbent selector.

    Args:
        config: not read by this function.
        x_obs_list: observed inputs, converted to a float array.
        y_obs_list: observed outputs, converted to a float array.
        rng: random generator handed to the selector.
        tr_state: optional object; its ``incumbent_selector`` attribute is used
            when present, otherwise a non-noise-aware
            ``ScalarIncumbentSelector`` is created.

    Returns:
        ``None`` when there are no observed outputs, otherwise the entry of
        the input array at the index chosen by the selector.
    """
    import numpy as np
    from .components.incumbent_selector import ScalarIncumbentSelector

    y_values = np.asarray(y_obs_list, dtype=float)
    if not y_values.size:
        return None
    x_values = np.asarray(x_obs_list, dtype=float)

    has_own_selector = tr_state is not None and hasattr(
        tr_state, "incumbent_selector"
    )
    selector = (
        tr_state.incumbent_selector
        if has_own_selector
        else ScalarIncumbentSelector(noise_aware=False)
    )
    chosen = selector.select(y_values, None, rng)
    return x_values[chosen]
29
+
30
+
31
def handle_restart_clear_always(
    x_obs_list: list,
    y_obs_list: list,
    yvar_obs_list: list,
) -> tuple[bool, int]:
    """Empty all three observation lists in place.

    Returns:
        Always ``(True, 0)``.
    """
    for history in (x_obs_list, y_obs_list, yvar_obs_list):
        history.clear()
    return True, 0
40
+
41
+
42
def handle_restart_check_multi_objective(
    tr_state: Any,
    x_obs_list: list,
    y_obs_list: list,
    yvar_obs_list: list,
    init_idx: int,
) -> tuple[bool, int]:
    """Clear the observation lists only when ``tr_state`` is multi-metric.

    ``tr_state`` counts as multi-metric when it is not ``None``, has a
    ``num_metrics`` attribute, and that attribute is greater than one.

    Returns:
        ``(True, 0)`` after clearing the three lists in place, otherwise
        ``(False, init_idx)`` with the lists left untouched.
    """
    if tr_state is None or not hasattr(tr_state, "num_metrics"):
        return False, init_idx
    if not tr_state.num_metrics > 1:
        return False, init_idx

    for history in (x_obs_list, y_obs_list, yvar_obs_list):
        history.clear()
    return True, 0
60
+
61
+
62
def estimate_y_passthrough(y_observed: np.ndarray) -> np.ndarray:
    """Return the observations as a float array, with 1-D input as a column.

    A 1-D input of length ``n`` comes back with shape ``(n, 1)``; input of any
    other rank is returned with its shape unchanged.
    """
    import numpy as np

    values = np.asarray(y_observed, dtype=float)
    return values.reshape(-1, 1) if values.ndim == 1 else values
@@ -0,0 +1,250 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from .tr_helpers import ScalarIncumbentMixin
6
+
7
+ if TYPE_CHECKING:
8
+ import numpy as np
9
+ from numpy.random import Generator
10
+ from scipy.stats._qmc import QMCEngine
11
+
12
+ from .config.morbo_tr_config import MorboTRConfig
13
+ from .config.rescalarize import Rescalarize
14
+
15
+ from .config.enums import CandidateRV, RAASPDriver
16
+
17
+
18
class MorboTrustRegion(ScalarIncumbentMixin):
    """Multi-metric trust region layered on top of a ``TurboTrustRegion``.

    Multi-metric observations are reduced to a single score per row (see
    ``_scalarize_with_ranges``) and the length bookkeeping is delegated to an
    inner ``TurboTrustRegion`` that only ever sees those scores.
    """

    def __init__(
        self,
        config: MorboTRConfig,
        num_dim: int,
        *,
        rng: Generator,
        candidate_rv: CandidateRV = CandidateRV.SOBOL,
    ) -> None:
        """Set up the inner trust region and the weighted incumbent selector.

        Raises:
            ValueError: if ``config.num_metrics`` is not positive.
        """
        from .components.incumbent_selector import ChebyshevIncumbentSelector
        from .config.turbo_tr_config import TurboTRConfig
        from .turbo_trust_region import TurboTrustRegion

        self._config = config
        self._candidate_rv = candidate_rv
        # The inner region reuses only the length settings of this config.
        self._tr = TurboTrustRegion(
            config=TurboTRConfig(length=config.length),
            num_dim=num_dim,
        )
        self._num_dim = int(num_dim)
        self._num_metrics = int(config.num_metrics)
        if self._num_metrics <= 0:
            raise ValueError(self._num_metrics)
        self._alpha = float(config.alpha)
        self._rescalarize = config.rescalarize
        self.incumbent_selector = ChebyshevIncumbentSelector(
            num_metrics=self._num_metrics,
            alpha=self._alpha,
            noise_aware=config.noise_aware,
        )
        self.incumbent_selector.reset(rng)
        self._weights = self.incumbent_selector.weights
        # Per-metric observed ranges and the raw incumbent; unset until the
        # first non-empty ``update``.
        self._y_min: np.ndarray | Any | None = None
        self._y_max: np.ndarray | Any | None = None
        self._incumbent_y_raw: np.ndarray | None = None

    @property
    def num_dim(self) -> int:
        """Number of input dimensions given at construction."""
        return self._num_dim

    @property
    def num_metrics(self) -> int:
        """Number of metrics taken from the config."""
        return self._num_metrics

    @property
    def weights(self) -> np.ndarray:
        """Current weights as last read from the incumbent selector."""
        return self._weights

    @property
    def length(self) -> float:
        """Length of the inner trust region, as a ``float``."""
        return float(self._tr.length)

    @property
    def rescalarize(self) -> Rescalarize:
        """Rescalarization setting taken from the config."""
        return self._rescalarize

    def resample_weights(self, rng: Generator) -> None:
        """Reset the incumbent selector with ``rng`` and re-read its weights."""
        selector = self.incumbent_selector
        selector.reset(rng)
        self._weights = selector.weights

    def _update_ranges(self, y_obs):
        """Store the column-wise minimum and maximum of ``y_obs``."""
        self._y_min = y_obs.min(axis=0)
        self._y_max = y_obs.max(axis=0)

    def update(self, y_obs: np.ndarray | Any, y_incumbent: np.ndarray | Any) -> None:
        """Refresh the ranges and feed the incumbent's score to the inner region.

        Args:
            y_obs: array-like of shape ``(n, num_metrics)``.
            y_incumbent: array-like reshaped to ``(1, num_metrics)``.

        Raises:
            ValueError: on a shape mismatch, or if ``n`` is smaller than the
                inner region's ``prev_num_obs``.
        """
        import numpy as np

        y_obs = np.asarray(y_obs, dtype=float)
        if y_obs.ndim != 2 or y_obs.shape[1] != self._num_metrics:
            raise ValueError((y_obs.shape, self._num_metrics))
        n = int(y_obs.shape[0])
        if n == 0:
            # No observations: drop all derived state and restart the inner
            # region.
            self._y_min = None
            self._y_max = None
            self._incumbent_y_raw = None
            self._tr.restart()
            return
        prev_n = int(self._tr.prev_num_obs)
        if n < prev_n:
            raise ValueError((n, prev_n))
        self._update_ranges(y_obs)
        y_incumbent = np.asarray(y_incumbent, dtype=float).reshape(1, -1)
        if y_incumbent.shape != (1, self._num_metrics):
            raise ValueError(
                f"y_incumbent must have shape (1, {self._num_metrics}), got {y_incumbent.shape}"
            )
        if prev_n == 0 or self._incumbent_y_raw is None:
            self._handle_initial_update(y_incumbent, n)
            return

        # Score old and new incumbents under the *current* ranges so that the
        # two values are comparable.
        stacked = np.vstack([self._incumbent_y_raw, y_incumbent])
        old_score, new_score = (
            float(s) for s in self.scalarize(stacked, clip=True)[:2]
        )
        self._tr.best_value = old_score
        # The inner region receives an (n, 1) placeholder for the observations;
        # only its row count carries information here.
        self._tr.update(np.zeros((n, 1)), np.array([new_score]))
        if new_score > old_score:
            self._incumbent_y_raw = y_incumbent.copy()

    def _handle_initial_update(self, y_incumbent: np.ndarray, n: int) -> None:
        """Record the first incumbent and pass its score to the inner region."""
        import numpy as np

        self._incumbent_y_raw = y_incumbent.copy()
        first_score = self.scalarize(y_incumbent, clip=True)
        self._tr.update(np.zeros((n, 1)), first_score)

    def scalarize(self, y: np.ndarray | Any, *, clip: bool) -> np.ndarray:
        """Score each row of ``y`` using the stored per-metric ranges.

        Raises:
            ValueError: if ``y`` is not ``(n, num_metrics)``.
            RuntimeError: if no ranges have been stored yet.
        """
        import numpy as np

        y = np.asarray(y, dtype=float)
        if y.ndim != 2 or y.shape[1] != self._num_metrics:
            raise ValueError(y.shape)
        if self._y_min is None or self._y_max is None:
            raise RuntimeError("scalarize called before any observations")
        return self._scalarize_with_ranges(
            y, y_min=self._y_min, y_max=self._y_max, clip=clip
        )

    def _scalarize_with_ranges(
        self,
        y: np.ndarray | Any,
        *,
        y_min: np.ndarray,
        y_max: np.ndarray,
        clip: bool,
    ) -> np.ndarray:
        """Score rows of ``y`` after normalising by the given ranges.

        Each metric is mapped to ``(y - y_min) / (y_max - y_min)``; a metric
        whose range is not positive is set to ``0.5``. With ``clip`` the
        result is limited to ``[0, 1]``. The score of a row is
        ``min(w * z) + alpha * sum(w * z)`` over its metrics.

        Raises:
            ValueError: if ``y`` or the ranges have the wrong shape.
        """
        import numpy as np

        y = np.asarray(y, dtype=float)
        if y.ndim != 2 or y.shape[1] != self._num_metrics:
            raise ValueError(y.shape)
        y_min = np.asarray(y_min, dtype=float).reshape(-1)
        y_max = np.asarray(y_max, dtype=float).reshape(-1)
        expected = (self._num_metrics,)
        if y_min.shape != expected or y_max.shape != expected:
            raise ValueError((y_min.shape, y_max.shape, self._num_metrics))

        span = y_max - y_min
        degenerate = span <= 0.0
        # Divide by 1 where the span is degenerate to avoid a zero division;
        # those columns are overwritten with 0.5 right after.
        safe_span = np.where(degenerate, 1.0, span)
        normalised = (y - y_min.reshape(1, -1)) / safe_span.reshape(1, -1)
        normalised = np.where(degenerate, 0.5, normalised)
        if clip:
            normalised = np.clip(normalised, 0.0, 1.0)
        weighted = normalised * self._weights.reshape(1, -1)
        return np.min(weighted, axis=1) + self._alpha * np.sum(weighted, axis=1)

    def needs_restart(self) -> bool:
        """Delegate to the inner trust region."""
        return self._tr.needs_restart()

    def restart(self, rng: Generator | None = None) -> None:
        """Clear ranges and incumbent, restart the inner region.

        Weights are resampled only when ``rng`` is given and the
        rescalarization setting is ``Rescalarize.ON_RESTART``.
        """
        from .config.rescalarize import Rescalarize

        self._y_min = None
        self._y_max = None
        self._incumbent_y_raw = None
        self._tr.restart()
        if rng is None:
            return
        if self._rescalarize == Rescalarize.ON_RESTART:
            self.resample_weights(rng)

    def validate_request(self, num_arms: int, *, is_fallback: bool = False) -> None:
        """Delegate to the inner trust region."""
        return self._tr.validate_request(num_arms, is_fallback=is_fallback)

    def compute_bounds_1d(
        self, x_center: np.ndarray | Any, lengthscales: np.ndarray | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Delegate to the inner trust region."""
        return self._tr.compute_bounds_1d(x_center, lengthscales)

    def generate_candidates(
        self,
        x_center: np.ndarray,
        lengthscales: np.ndarray | None,
        num_candidates: int,
        rng: Generator,
        sobol_engine: QMCEngine,
        raasp_driver: RAASPDriver = RAASPDriver.ORIG,
        num_pert: int = 20,
    ) -> np.ndarray:
        """Generate candidates via ``generate_tr_candidates``.

        Bounds come from the inner trust region; the candidate random-variable
        kind is the one fixed at construction.
        """
        from .tr_helpers import generate_tr_candidates

        bounds_fn = self._tr.compute_bounds_1d
        return generate_tr_candidates(
            bounds_fn,
            x_center,
            lengthscales,
            num_candidates,
            rng=rng,
            candidate_rv=self._candidate_rv,
            sobol_engine=sobol_engine,
            raasp_driver=raasp_driver,
            num_pert=num_pert,
        )

    def get_incumbent_indices(
        self,
        y: np.ndarray | Any,
        rng: Generator,
    ) -> np.ndarray:
        """Return the row indices of ``y`` that lie on front ``0``.

        Non-dominated sorting is run on ``-y``. ``rng`` is not used here.

        Raises:
            ValueError: if ``y`` is not two-dimensional.
        """
        import numpy as np

        y = np.asarray(y, dtype=float)
        if y.ndim != 2:
            raise ValueError(y.shape)
        if y.shape[0] == 0:
            return np.array([], dtype=int)
        from nds import ndomsort

        front_of_row = np.array(
            ndomsort.non_domin_sort(-y, only_front_indices=True)
        )
        (first_front,) = np.where(front_of_row == 0)
        return first_front

    def get_incumbent_value(
        self,
        y_obs: np.ndarray | Any,
        rng: Generator,
        mu_obs: np.ndarray | None = None,
    ) -> np.ndarray:
        """Return the incumbent row as a ``(1, num_metrics)`` copy.

        The row index comes from ``get_incumbent_index``. Values are read from
        ``mu_obs`` when the selector's ``noise_aware`` attribute is truthy,
        otherwise from ``y_obs``. An empty ``y_obs`` yields an empty array.

        Raises:
            ValueError: if ``y_obs`` or the chosen value array is not
                ``(n, num_metrics)``.
        """
        import numpy as np

        y_obs = np.asarray(y_obs, dtype=float)
        if y_obs.ndim != 2 or y_obs.shape[1] != self._num_metrics:
            raise ValueError((y_obs.shape, self._num_metrics))
        if int(y_obs.shape[0]) == 0:
            return np.array([], dtype=float)
        idx = self.get_incumbent_index(y_obs, rng, mu=mu_obs)
        noise_aware = bool(getattr(self.incumbent_selector, "noise_aware", False))
        source = mu_obs if noise_aware else y_obs
        values = np.asarray(source, dtype=float)
        if values.ndim != 2 or values.shape[1] != self._num_metrics:
            raise ValueError((values.shape, self._num_metrics))
        return values[idx : idx + 1].copy()