ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass, field
3
+ from typing import TYPE_CHECKING
4
+ from .candidate_gen_config import CandidateGenConfig
5
+ from .init_config import InitConfig
6
+ from .surrogate import NoSurrogateConfig, SurrogateConfig
7
+ from .trust_region import TrustRegionConfig, TurboTRConfig
8
+
9
+ if TYPE_CHECKING:
10
+ from .acquisition import AcqOptimizerConfig, AcquisitionConfig
11
+ from .enums import CandidateRV, RAASPDriver
12
+
13
+
14
@dataclass(frozen=True)
class ObservationHistoryConfig:
    """Validated holder for the ``trailing_obs`` setting.

    ``trailing_obs`` is either ``None`` or a strictly positive number.
    """

    trailing_obs: int | None = None

    def __post_init__(self) -> None:
        """Reject a ``trailing_obs`` that is set but not positive.

        Raises:
            ValueError: If ``trailing_obs`` is not ``None`` and is ``<= 0``.
        """
        limit = self.trailing_obs
        if limit is None:
            # Nothing to validate when the setting is absent.
            return
        if limit <= 0:
            raise ValueError(f"trailing_obs must be > 0, got {self.trailing_obs}")
21
+
22
+
23
def _default_acquisition():
    """Return a new ``RandomAcquisitionConfig`` instance."""
    # The import runs at call time, not when this module is loaded.
    from .acquisition import RandomAcquisitionConfig as _DefaultAcquisition

    return _DefaultAcquisition()
27
+
28
+
29
def _default_acq_optimizer():
    """Return a new ``RAASPOptimizerConfig`` instance."""
    # The import runs at call time, not when this module is loaded.
    from .acquisition import RAASPOptimizerConfig as _DefaultOptimizer

    return _DefaultOptimizer()
33
+
34
+
35
@dataclass(frozen=True)
class OptimizerConfig:
    """Frozen bundle of the optimizer's sub-configurations.

    Every instance is handed to ``validate_optimizer_config`` right after
    construction. The read-only properties forward to values held by the
    nested config objects.
    """

    trust_region: TrustRegionConfig = TurboTRConfig()
    candidates: CandidateGenConfig = CandidateGenConfig()
    init: InitConfig = InitConfig()
    surrogate: SurrogateConfig = NoSurrogateConfig()
    # These two defaults come from factory functions that import lazily.
    acquisition: AcquisitionConfig = field(default_factory=_default_acquisition)
    acq_optimizer: AcqOptimizerConfig = field(default_factory=_default_acq_optimizer)
    observation_history: ObservationHistoryConfig = ObservationHistoryConfig()

    def __post_init__(self) -> None:
        """Run ``validate_optimizer_config`` on the freshly built instance."""
        from .validation import validate_optimizer_config as _validate

        _validate(self)

    @property
    def num_metrics(self) -> int | None:
        """``trust_region.num_metrics`` for a ``MorboTRConfig``, else ``None``."""
        from .morbo_tr_config import MorboTRConfig

        region = self.trust_region
        return region.num_metrics if isinstance(region, MorboTRConfig) else None

    @property
    def candidate_rv(self) -> CandidateRV:
        """Forward ``candidates.candidate_rv``."""
        generation = self.candidates
        return generation.candidate_rv

    @property
    def raasp_driver(self) -> RAASPDriver:
        """Forward ``candidates.raasp_driver``."""
        generation = self.candidates
        return generation.raasp_driver

    @property
    def num_candidates(self):
        """Forward ``candidates.num_candidates``."""
        generation = self.candidates
        return generation.num_candidates

    @property
    def trailing_obs(self) -> int | None:
        """Forward ``observation_history.trailing_obs``."""
        history = self.observation_history
        return history.trailing_obs
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import AcquisitionOptimizer
7
+
8
+
9
@dataclass(frozen=True)
class ParetoAcquisitionConfig:
    """Field-less frozen config whose ``build`` returns a ``ParetoAcqOptimizer``."""

    def build(self) -> AcquisitionOptimizer:
        """Create a ``ParetoAcqOptimizer`` with no arguments."""
        # Imported at call time, not at module load.
        from ..components.acquisition import ParetoAcqOptimizer as _Optimizer

        return _Optimizer()
@@ -0,0 +1,6 @@
1
+ from enum import Enum, auto
2
+
3
+
4
class RAASPDriver(Enum):
    """Enumeration with two members, ``ORIG`` and ``FAST``."""

    # auto() assigns consecutive integer values in definition order.
    ORIG = auto()
    FAST = auto()
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass(frozen=True)
class RAASPOptimizerConfig:
    """Frozen dataclass that declares no fields."""
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import AcquisitionOptimizer
7
+
8
+
9
@dataclass(frozen=True)
class RandomAcquisitionConfig:
    """Field-less frozen config whose ``build`` returns a ``RandomAcqOptimizer``."""

    def build(self) -> AcquisitionOptimizer:
        """Create a ``RandomAcqOptimizer`` with no arguments."""
        # Imported at call time, not at module load.
        from ..components.acquisition import RandomAcqOptimizer as _Optimizer

        return _Optimizer()
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+ from enum import Enum
3
+
4
+
5
class Rescalarize(Enum):
    """Enumeration with string-valued members ``ON_RESTART`` and ``ON_PROPOSE``."""

    ON_RESTART = "on_restart"
    ON_PROPOSE = "on_propose"
@@ -0,0 +1,12 @@
1
+ from .enn_surrogate_config import ENNFitConfig, ENNSurrogateConfig
2
+ from .gp_surrogate_config import GPSurrogateConfig
3
+ from .no_surrogate_config import NoSurrogateConfig
4
+
5
# Union alias over the three surrogate config classes imported above.
SurrogateConfig = NoSurrogateConfig | GPSurrogateConfig | ENNSurrogateConfig

# Names re-exported from this module.
__all__ = [
    "ENNFitConfig",
    "ENNSurrogateConfig",
    "GPSurrogateConfig",
    "NoSurrogateConfig",
    "SurrogateConfig",
]
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Protocol
4
+
5
+ from .morbo_tr_config import MorboTRConfig, MultiObjectiveConfig, RescalePolicyConfig
6
+ from .no_tr_config import NoTRConfig
7
+ from .turbo_tr_config import TRLengthConfig, TurboTRConfig
8
+
9
+ if TYPE_CHECKING:
10
+ from numpy.random import Generator
11
+
12
+ from ..components.protocols import TrustRegion
13
+ from .enums import CandidateRV
14
+
15
+
16
class TrustRegionConfig(Protocol):
    """Structural type for configs that expose a keyword-only ``build`` method."""

    def build(
        self,
        *,
        num_dim: int,
        rng: Generator,
        candidate_rv: CandidateRV,
    ) -> TrustRegion:
        """Return a ``TrustRegion``; this declaration has no implementation."""
        ...
24
+
25
+
26
# Names re-exported from this module.
__all__ = [
    "MorboTRConfig",
    "MultiObjectiveConfig",
    "NoTRConfig",
    "RescalePolicyConfig",
    "TRLengthConfig",
    "TrustRegionConfig",
    "TurboTRConfig",
]
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING:
7
+ from numpy.random import Generator
8
+
9
+ from ..components.protocols import TrustRegion
10
+ from .enums import CandidateRV
11
+
12
+
13
@dataclass(frozen=True)
class TRLengthConfig:
    """Initial, minimum and maximum length values, validated on construction.

    A successfully constructed instance satisfies
    ``0 < length_min <= length_init <= length_max`` and
    ``length_min < length_max``.
    """

    length_init: float = 0.8
    length_min: float = 0.5**7
    length_max: float = 1.6

    def __post_init__(self) -> None:
        """Check positivity and ordering of the three lengths.

        Raises:
            ValueError: If any length is not strictly positive (NaN included),
                if ``length_min >= length_max``, if
                ``length_init > length_max``, or if
                ``length_min > length_init``.
        """
        # ``not value > 0`` rather than ``value <= 0``: every comparison with
        # NaN is False, so the latter form would silently accept NaN.
        for name in ("length_init", "length_min", "length_max"):
            value = getattr(self, name)
            if not value > 0:
                raise ValueError(f"{name} must be > 0, got {value}")
        if self.length_min >= self.length_max:
            raise ValueError(
                f"length_min must be < length_max, got {self.length_min} >= {self.length_max}"
            )
        if self.length_init > self.length_max:
            raise ValueError(
                f"length_init must be <= length_max, got {self.length_init} > {self.length_max}"
            )
        if self.length_min > self.length_init:
            raise ValueError(
                f"length_min must be <= length_init, got {self.length_min} > {self.length_init}"
            )
38
+
39
+
40
@dataclass(frozen=True)
class TurboTRConfig:
    """Frozen config holding a ``TRLengthConfig`` and a ``noise_aware`` flag.

    The ``length_*`` properties forward to the nested ``length`` config, and
    ``build`` constructs a ``TurboTrustRegion`` from this config.
    """

    length: TRLengthConfig = TRLengthConfig()
    noise_aware: bool = False

    @property
    def length_init(self) -> float:
        """Forward ``length.length_init``."""
        lengths = self.length
        return lengths.length_init

    @property
    def length_min(self) -> float:
        """Forward ``length.length_min``."""
        lengths = self.length
        return lengths.length_min

    @property
    def length_max(self) -> float:
        """Forward ``length.length_max``."""
        lengths = self.length
        return lengths.length_max

    def build(
        self,
        *,
        num_dim: int,
        rng: Generator,
        candidate_rv: CandidateRV | None = None,
    ) -> TrustRegion:
        """Construct a ``TurboTrustRegion`` bound to this config.

        ``rng`` and ``candidate_rv`` are accepted but not read by this method.
        """
        from ..components.incumbent_selector import ScalarIncumbentSelector
        from ..turbo_trust_region import TurboTrustRegion

        selector = ScalarIncumbentSelector(noise_aware=self.noise_aware)
        return TurboTrustRegion(
            config=self,
            num_dim=num_dim,
            incumbent_selector=selector,
        )
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import AcquisitionOptimizer
7
+
8
+
9
@dataclass(frozen=True)
class UCBAcquisitionConfig:
    """Field-less frozen config whose ``build`` returns a ``UCBAcqOptimizer``."""

    def build(self) -> AcquisitionOptimizer:
        """Create a ``UCBAcqOptimizer`` with no arguments."""
        # Imported at call time, not at module load.
        from ..components.acquisition import UCBAcqOptimizer as _Optimizer

        return _Optimizer()
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+ from typing import Any
3
+
4
+
5
def validate_optimizer_config(cfg: Any) -> None:
    """Reject combinations of sub-configs that the checks below disallow.

    Reads ``cfg.init.init_strategy``, ``cfg.surrogate``, ``cfg.acquisition``
    and ``cfg.acq_optimizer``.

    Raises:
        ValueError: For an incompatible init/surrogate/acquisition/optimizer
            combination.
        NotImplementedError: For a GP surrogate combined with
            ``DrawAcquisitionConfig`` and ``HnROptimizerConfig``.
    """
    from .acquisition import (
        DrawAcquisitionConfig,
        HnROptimizerConfig,
        NDSOptimizerConfig,
        ParetoAcquisitionConfig,
        UCBAcquisitionConfig,
    )
    from .init_strategies import LHDOnlyInit
    from .surrogate import GPSurrogateConfig, NoSurrogateConfig

    lhd_only = isinstance(cfg.init.init_strategy, LHDOnlyInit)
    surrogate = cfg.surrogate
    no_surrogate = isinstance(surrogate, NoSurrogateConfig)
    if lhd_only and not no_surrogate:
        raise ValueError(
            "init_strategy='lhd_only' requires NoSurrogateConfig surrogate"
        )

    acquisition = cfg.acquisition
    uses_draw = isinstance(acquisition, DrawAcquisitionConfig)
    if no_surrogate and uses_draw:
        raise ValueError(
            "DrawAcquisitionConfig (Thompson sampling) requires a surrogate. "
            "NoSurrogateConfig is not compatible with DrawAcquisitionConfig."
        )
    if no_surrogate and isinstance(acquisition, UCBAcquisitionConfig):
        raise ValueError(
            "UCBAcquisitionConfig requires a surrogate. "
            "NoSurrogateConfig is not compatible with UCBAcquisitionConfig."
        )

    acq_optimizer = cfg.acq_optimizer
    uses_pareto = isinstance(acquisition, ParetoAcquisitionConfig)
    if uses_pareto and not isinstance(acq_optimizer, NDSOptimizerConfig):
        raise ValueError("ParetoAcquisitionConfig requires NDSOptimizerConfig")

    # The remaining checks apply only to the HnR optimizer.
    if not isinstance(acq_optimizer, HnROptimizerConfig):
        return
    # NOTE(review): the NDS check above fires first for Pareto, so this looks
    # unreachable unless HnROptimizerConfig subclasses NDSOptimizerConfig.
    if uses_pareto:
        raise ValueError(
            "HnROptimizerConfig is not compatible with ParetoAcquisitionConfig"
        )
    if isinstance(surrogate, GPSurrogateConfig) and uses_draw:
        raise NotImplementedError(
            "GP surrogate with DrawAcquisitionConfig and HnROptimizerConfig is not yet implemented"
        )
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+
4
+
5
def hypervolume_2d_max(y: np.ndarray, ref_point: np.ndarray) -> float:
    """Return the 2-D hypervolume of ``y`` above ``ref_point`` (maximisation).

    The result is the area of the union of the axis-aligned rectangles
    spanned by ``ref_point`` and each row of ``y``. Only rows that are
    strictly greater than ``ref_point`` in both coordinates contribute.

    Args:
        y: Array-like of shape ``(n, 2)``. An input with no elements returns
            ``0.0`` before any shape check.
        ref_point: Array-like of shape ``(2,)``.

    Returns:
        The hypervolume as a Python ``float``.

    Raises:
        ValueError: If a non-empty ``y`` is not of shape ``(n, 2)`` or
            ``ref_point`` is not of shape ``(2,)``.
    """
    y = np.asarray(y, dtype=float)
    ref_point = np.asarray(ref_point, dtype=float)
    if y.size == 0:
        return 0.0
    if y.ndim != 2 or y.shape[1] != 2:
        raise ValueError(f"y must have shape (n, 2), got {y.shape}")
    if ref_point.shape != (2,):
        raise ValueError(f"ref_point must have shape (2,), got {ref_point.shape}")

    # Rows at or below the reference in either coordinate add no area.
    above = (y[:, 0] > ref_point[0]) & (y[:, 1] > ref_point[1])
    y = y[above]
    if y.size == 0:
        return 0.0

    # Sweep from the largest first coordinate down to the reference. Each
    # strip between consecutive first coordinates is as tall as the best
    # second coordinate seen so far.
    y = y[np.argsort(y[:, 0], kind="mergesort")[::-1]]
    heights = np.maximum.accumulate(y[:, 1]) - ref_point[1]
    widths = y[:, 0] - np.append(y[1:, 0], ref_point[0])
    return float(np.dot(widths, heights))
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Any
3
+
4
+ if TYPE_CHECKING:
5
+ import numpy as np
6
+ from numpy.random import Generator
7
+
8
+
9
def get_x_center_fallback(
    config: Any,
    x_obs_list: list,
    y_obs_list: list,
    rng: Generator,
    tr_state: Any = None,
) -> np.ndarray | None:
    """Pick one observed x using an incumbent selector, or ``None`` if empty.

    The selector is ``tr_state.incumbent_selector`` when ``tr_state`` has that
    attribute; otherwise a ``ScalarIncumbentSelector(noise_aware=False)`` is
    created. ``config`` is accepted but not read here.

    Returns:
        ``None`` when the y observations contain no elements, otherwise the
        entry of the x array at the index returned by ``selector.select``.
    """
    import numpy as np
    from .components.incumbent_selector import ScalarIncumbentSelector

    observed_y = np.asarray(y_obs_list, dtype=float)
    if observed_y.size == 0:
        return None
    observed_x = np.asarray(x_obs_list, dtype=float)

    reuse_selector = tr_state is not None and hasattr(
        tr_state, "incumbent_selector"
    )
    selector = (
        tr_state.incumbent_selector
        if reuse_selector
        else ScalarIncumbentSelector(noise_aware=False)
    )
    return observed_x[selector.select(observed_y, None, rng)]
29
+
30
+
31
def handle_restart_clear_always(
    x_obs_list: list,
    y_obs_list: list,
    yvar_obs_list: list,
) -> tuple[bool, int]:
    """Empty all three observation lists in place and return ``(True, 0)``."""
    for history in (x_obs_list, y_obs_list, yvar_obs_list):
        history.clear()
    return True, 0
40
+
41
+
42
def handle_restart_check_multi_objective(
    tr_state: Any,
    x_obs_list: list,
    y_obs_list: list,
    yvar_obs_list: list,
    init_idx: int,
) -> tuple[bool, int]:
    """Clear the observation lists only when ``tr_state.num_metrics > 1``.

    Returns:
        ``(True, 0)`` after clearing all three lists in place when
        ``tr_state`` has a ``num_metrics`` attribute greater than one;
        otherwise ``(False, init_idx)`` with the lists untouched.
    """
    if tr_state is None or not hasattr(tr_state, "num_metrics"):
        return False, init_idx
    if tr_state.num_metrics > 1:
        for history in (x_obs_list, y_obs_list, yvar_obs_list):
            history.clear()
        return True, 0
    return False, init_idx
60
+
61
+
62
def estimate_y_passthrough(y_observed: np.ndarray) -> np.ndarray:
    """Return ``y_observed`` as a float array, turning 1-D input into a column.

    A 1-D input of length ``n`` comes back with shape ``(n, 1)``; input of any
    other rank is returned as converted by ``np.asarray``.
    """
    import numpy as np

    values = np.asarray(y_observed, dtype=float)
    return values[:, np.newaxis] if values.ndim == 1 else values