ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,87 @@
1
+ from .acquisition import (
2
+ AcqOptimizerConfig,
3
+ AcquisitionConfig,
4
+ DrawAcquisitionConfig,
5
+ HnROptimizerConfig,
6
+ NDSOptimizerConfig,
7
+ ParetoAcquisitionConfig,
8
+ RAASPOptimizerConfig,
9
+ RandomAcquisitionConfig,
10
+ UCBAcquisitionConfig,
11
+ )
12
+ from .base import (
13
+ CandidateGenConfig,
14
+ InitConfig,
15
+ )
16
+ from .enums import (
17
+ AcqType,
18
+ CandidateRV,
19
+ )
20
+ from .init_strategies import HybridInit, InitStrategy, LHDOnlyInit
21
+ from .optimizer_config import OptimizerConfig
22
+ from .surrogate import (
23
+ ENNFitConfig,
24
+ ENNSurrogateConfig,
25
+ GPSurrogateConfig,
26
+ NoSurrogateConfig,
27
+ SurrogateConfig,
28
+ )
29
+ from .trust_region import (
30
+ MorboTRConfig,
31
+ MultiObjectiveConfig,
32
+ NoTRConfig,
33
+ RescalePolicyConfig,
34
+ TRLengthConfig,
35
+ TrustRegionConfig,
36
+ TurboTRConfig,
37
+ )
38
+
39
+
40
def __getattr__(name: str) -> object:
    """Resolve the factory presets lazily (PEP 562 module ``__getattr__``).

    ``lhd_only_config``, ``turbo_enn_config``, ``turbo_one_config`` and
    ``turbo_zero_config`` live in the ``factory`` submodule. That submodule
    is imported only when one of these names is requested, presumably to
    avoid import-time cycles or cost. Every other unknown attribute raises
    ``AttributeError``.
    """
    lazy_factory_names = (
        "lhd_only_config",
        "turbo_enn_config",
        "turbo_one_config",
        "turbo_zero_config",
    )
    if name not in lazy_factory_names:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    from . import factory

    return getattr(factory, name)
51
+
52
+
53
# Public API of the config package. The four lower-case factory presets
# (lhd_only_config, turbo_enn_config, turbo_one_config, turbo_zero_config)
# are not imported eagerly; the module-level __getattr__ resolves them from
# the ``factory`` submodule on first access (this also covers ``import *``).
__all__ = [
    "AcqOptimizerConfig",
    "AcqType",
    "AcquisitionConfig",
    "CandidateGenConfig",
    "CandidateRV",
    "DrawAcquisitionConfig",
    "ENNFitConfig",
    "ENNSurrogateConfig",
    "GPSurrogateConfig",
    "HnROptimizerConfig",
    "InitConfig",
    "InitStrategy",
    "HybridInit",
    "LHDOnlyInit",
    "lhd_only_config",
    "MorboTRConfig",
    "MultiObjectiveConfig",
    "NDSOptimizerConfig",
    "NoSurrogateConfig",
    "NoTRConfig",
    "OptimizerConfig",
    "ParetoAcquisitionConfig",
    "RAASPOptimizerConfig",
    "RandomAcquisitionConfig",
    "RescalePolicyConfig",
    "SurrogateConfig",
    "TRLengthConfig",
    "TrustRegionConfig",
    "turbo_enn_config",
    "turbo_one_config",
    "TurboTRConfig",
    "turbo_zero_config",
    "UCBAcquisitionConfig",
]
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations
2
+ from enum import Enum
3
+
4
+
5
class AcqType(Enum):
    """Acquisition family selectable through ``turbo_enn_config``.

    In that factory, ``PARETO`` pairs ``ParetoAcquisitionConfig`` with the
    NDS acquisition optimizer. ``UCB`` and ``THOMPSON`` pair
    ``UCBAcquisitionConfig`` / ``DrawAcquisitionConfig`` with the RAASP
    optimizer.
    """

    THOMPSON = "thompson"
    PARETO = "pareto"
    UCB = "ucb"
@@ -0,0 +1,26 @@
1
+ from .draw_acquisition_config import DrawAcquisitionConfig
2
+ from .hnr_optimizer_config import HnROptimizerConfig
3
+ from .nds_optimizer_config import NDSOptimizerConfig
4
+ from .pareto_acquisition_config import ParetoAcquisitionConfig
5
+ from .raasp_optimizer_config import RAASPOptimizerConfig
6
+ from .random_acquisition_config import RandomAcquisitionConfig
7
+ from .ucb_acquisition_config import UCBAcquisitionConfig
8
+
9
# Runtime PEP 604 unions (``types.UnionType``, Python 3.10+) grouping the
# interchangeable config classes. They work both in annotations and as the
# second argument of isinstance().
# Acquisition-function configs:
AcquisitionConfig = (
    UCBAcquisitionConfig
    | DrawAcquisitionConfig
    | ParetoAcquisitionConfig
    | RandomAcquisitionConfig
)
# Acquisition-optimizer configs:
AcqOptimizerConfig = RAASPOptimizerConfig | HnROptimizerConfig | NDSOptimizerConfig
__all__ = [
    "AcqOptimizerConfig",
    "AcquisitionConfig",
    "DrawAcquisitionConfig",
    "HnROptimizerConfig",
    "NDSOptimizerConfig",
    "ParetoAcquisitionConfig",
    "RAASPOptimizerConfig",
    "RandomAcquisitionConfig",
    "UCBAcquisitionConfig",
]
@@ -0,0 +1,4 @@
1
+ from .candidate_gen_config import CandidateGenConfig
2
+ from .init_config import InitConfig
3
+
4
# Public re-exports: the candidate-generation and initial-design configs.
__all__ = ["CandidateGenConfig", "InitConfig"]
@@ -0,0 +1,49 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass, field
3
+ from typing import TYPE_CHECKING, Any
4
+ from .candidate_rv import CandidateRV
5
+ from .raasp_driver import RAASPDriver
6
+
7
if TYPE_CHECKING:
    from typing import Protocol

    class NumCandidatesFn(Protocol):
        """Callable returning how many candidates to generate.

        Declared as a ``Protocol`` (structural type) so that plain functions
        such as ``default_num_candidates`` and the closures returned by
        ``const_num_candidates`` type-check as ``NumCandidatesFn``. A plain
        class with ``__call__`` would only accept its own instances.
        """

        def __call__(self, *, num_dim: int, num_arms: int) -> int: ...

else:
    # At runtime the alias only has to exist so that importing it works
    # (factory.py imports it). With ``from __future__ import annotations``,
    # annotations that mention it are never evaluated.
    NumCandidatesFn = Any
13
+
14
+
15
def default_num_candidates(*, num_dim: int, num_arms: int) -> int:
    """Default candidate-count rule: 100 candidates per dimension, capped at 5000.

    ``num_arms`` is accepted to satisfy the ``NumCandidatesFn`` call
    signature but does not affect the result.
    """
    per_dim_total = 100 * int(num_dim)
    return per_dim_total if per_dim_total < 5000 else 5000
17
+
18
+
19
def const_num_candidates(n: int) -> NumCandidatesFn:
    """Return a ``NumCandidatesFn`` that always yields the constant ``n``.

    ``n`` is coerced with ``int()`` first, so floats are truncated.

    Raises:
        ValueError: if the coerced value is not strictly positive.
    """
    count = int(n)
    if count > 0:

        def fn(*, num_dim: int, num_arms: int) -> int:
            # The problem size is ignored: the count is fixed up front.
            return count

        return fn
    raise ValueError(f"num_candidates must be > 0, got {count}")
28
+
29
+
30
@dataclass(frozen=True)
class CandidateGenConfig:
    """How candidate points are generated before acquisition.

    Attributes:
        candidate_rv: ``CandidateRV`` member choosing the candidate source
            (``SOBOL`` by default).
        num_candidates: Callable ``(*, num_dim, num_arms) -> int`` giving the
            number of candidates; defaults to ``default_num_candidates``.
        raasp_driver: ``RAASPDriver`` member (default ``RAASPDriver.ORIG``).

    Raises:
        ValueError: on construction, if ``candidate_rv`` is not a
            ``CandidateRV``, if ``num_candidates`` is not callable, or if
            ``num_candidates`` returns a non-positive count for a 1-D,
            1-arm problem.
    """

    candidate_rv: CandidateRV = CandidateRV.SOBOL
    num_candidates: NumCandidatesFn = field(
        default_factory=lambda: default_num_candidates
    )
    raasp_driver: RAASPDriver = RAASPDriver.ORIG

    def __post_init__(self) -> None:
        rv = self.candidate_rv
        if not isinstance(rv, CandidateRV):
            raise ValueError(f"candidate_rv must be a CandidateRV enum, got {rv!r}")

        count_fn = self.num_candidates
        if not callable(count_fn):
            raise ValueError(
                f"num_candidates must be callable, got {type(count_fn)!r}"
            )

        # Probe the callable once with the smallest problem size so a bad
        # rule fails at configuration time rather than mid-optimization.
        probe = int(count_fn(num_dim=1, num_arms=1))
        if probe <= 0:
            raise ValueError(f"num_candidates must be > 0, got {probe}")
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+ from enum import Enum
3
+
4
+
5
class CandidateRV(Enum):
    """Source of raw candidate points (see ``CandidateGenConfig.candidate_rv``)."""

    SOBOL = "sobol"  # Sobol quasi-random sequence, per the name; the default in CandidateGenConfig
    UNIFORM = "uniform"  # presumably i.i.d. uniform draws — confirm in the sampling code
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import AcquisitionOptimizer
7
+
8
+
9
@dataclass(frozen=True)
class DrawAcquisitionConfig:
    """Field-less config selecting draw-based (Thompson) acquisition."""

    def build(self) -> AcquisitionOptimizer:
        """Construct and return a new ``ThompsonAcqOptimizer``."""
        # Deferred import, presumably to keep the config package free of
        # import-time dependencies on the components package.
        from ..components.acquisition import ThompsonAcqOptimizer

        optimizer = ThompsonAcqOptimizer()
        return optimizer
@@ -0,0 +1,6 @@
1
+ from enum import Enum, auto
2
+
3
+
4
class ENNIndexDriver(Enum):
    """Neighbour-index backend for the ENN surrogate (``ENNSurrogateConfig.index_driver``)."""

    FLAT = auto()  # default in ENNSurrogateConfig; presumably exact brute-force search — confirm
    HNSW = auto()  # presumably an HNSW approximate graph index — confirm
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+ from .enums import ENNIndexDriver
5
+
6
+ if TYPE_CHECKING:
7
+ from ..components.protocols import Surrogate
8
+
9
+
10
@dataclass(frozen=True)
class ENNFitConfig:
    """Fitting settings for the ENN surrogate.

    Attributes:
        num_fit_samples: Optional positive count; ``None`` means unset.
        num_fit_candidates: Optional positive count; ``None`` means unset.

    Raises:
        ValueError: if either value is given and is not strictly positive.
            ``num_fit_samples`` is checked first.
    """

    num_fit_samples: int | None = None
    num_fit_candidates: int | None = None

    def __post_init__(self) -> None:
        # Both fields share the same rule and message format, so check them
        # in declaration order from one table.
        for label in ("num_fit_samples", "num_fit_candidates"):
            value = getattr(self, label)
            if value is not None and value <= 0:
                raise ValueError(f"{label} must be > 0, got {value}")
22
+
23
+
24
@dataclass(frozen=True)
class ENNSurrogateConfig:
    """Configuration for the ENN surrogate.

    Attributes:
        k: Optional integer, ``None`` by default (presumably the neighbour
            count — confirm in ENNSurrogate).
        fit: ``ENNFitConfig`` sub-config. Its fields are re-exposed as the
            read-only properties below.
        scale_x: Boolean flag, ``False`` by default.
        index_driver: ``ENNIndexDriver`` member, ``FLAT`` by default.
    """

    k: int | None = None
    fit: ENNFitConfig = ENNFitConfig()
    scale_x: bool = False
    index_driver: ENNIndexDriver = ENNIndexDriver.FLAT

    @property
    def num_fit_samples(self) -> int | None:
        """Shortcut for ``fit.num_fit_samples``."""
        fit_cfg = self.fit
        return fit_cfg.num_fit_samples

    @property
    def num_fit_candidates(self) -> int | None:
        """Shortcut for ``fit.num_fit_candidates``."""
        fit_cfg = self.fit
        return fit_cfg.num_fit_candidates

    def build(self) -> Surrogate:
        """Construct an ``ENNSurrogate`` configured by this object."""
        from ..components.surrogates import ENNSurrogate

        surrogate = ENNSurrogate(self)
        return surrogate
@@ -0,0 +1,7 @@
1
+ from .acq_type import AcqType
2
+ from .candidate_rv import CandidateRV
3
+ from .enn_index_driver import ENNIndexDriver
4
+ from .raasp_driver import RAASPDriver
5
+ from .rescalarize import Rescalarize
6
+
7
# Public re-exports of the enum types defined in the sibling modules.
__all__ = ["AcqType", "CandidateRV", "ENNIndexDriver", "RAASPDriver", "Rescalarize"]
@@ -0,0 +1,118 @@
1
+ from __future__ import annotations
2
+ from . import acquisition as acq
3
+ from . import surrogate as sur
4
+ from . import trust_region as tr
5
+ from .candidate_gen_config import (
6
+ CandidateGenConfig,
7
+ NumCandidatesFn,
8
+ const_num_candidates,
9
+ )
10
+ from .enums import AcqType, CandidateRV
11
+ from .init_config import InitConfig
12
+ from .optimizer_config import ObservationHistoryConfig, OptimizerConfig
13
+
14
+
15
def _make_candidate_gen_config(
    candidate_rv: CandidateRV,
    num_candidates: NumCandidatesFn | int | None,
) -> CandidateGenConfig:
    """Build a ``CandidateGenConfig`` from a flexible ``num_candidates`` spec.

    Args:
        candidate_rv: Candidate source, passed through unchanged.
        num_candidates: ``None`` keeps ``CandidateGenConfig``'s default rule.
            Any integral value (``int``, ``bool``, or another
            ``numbers.Integral`` such as a NumPy integer scalar) is wrapped
            with ``const_num_candidates``. Anything else is passed through as
            the rule itself, and ``CandidateGenConfig`` validates it.

    Raises:
        ValueError: propagated from ``const_num_candidates`` (non-positive
            constant) or from ``CandidateGenConfig`` validation.
    """
    import numbers

    if num_candidates is None:
        return CandidateGenConfig(candidate_rv=candidate_rv)
    # Check numbers.Integral rather than int so that NumPy integer scalars
    # (registered as Integral) are treated as constants. Otherwise they
    # would be rejected as "not callable".
    if isinstance(num_candidates, numbers.Integral):
        num_candidates = const_num_candidates(num_candidates)
    return CandidateGenConfig(candidate_rv=candidate_rv, num_candidates=num_candidates)
24
+
25
+
26
def turbo_one_config(
    *,
    num_candidates: int | None = None,
    num_init: int | None = None,
    trailing_obs: int | None = None,
    trust_region: tr.TrustRegionConfig | None = None,
    candidate_rv: CandidateRV = CandidateRV.SOBOL,
) -> OptimizerConfig:
    """Return an ``OptimizerConfig`` preset for TuRBO-1.

    Components: GP surrogate, draw-based (Thompson) acquisition and the RAASP
    acquisition optimizer, inside a ``TurboTRConfig`` trust region unless one
    is supplied.

    Args:
        num_candidates: Constant candidate count. ``None`` keeps the
            ``CandidateGenConfig`` default rule.
        num_init: Forwarded to ``InitConfig``.
        trailing_obs: Forwarded to ``ObservationHistoryConfig``.
        trust_region: Trust-region config. A falsy value such as ``None``
            selects ``tr.TurboTRConfig()``.
        candidate_rv: Candidate source for ``CandidateGenConfig``.
    """
    # Sub-configs are built in the same order the keyword arguments were
    # evaluated originally.
    region = trust_region or tr.TurboTRConfig()
    candidate_cfg = _make_candidate_gen_config(candidate_rv, num_candidates)
    init_cfg = InitConfig(num_init=num_init)
    surrogate_cfg = sur.GPSurrogateConfig()
    acquisition_cfg = acq.DrawAcquisitionConfig()
    optimizer_cfg = acq.RAASPOptimizerConfig()
    history_cfg = ObservationHistoryConfig(trailing_obs=trailing_obs)
    return OptimizerConfig(
        trust_region=region,
        candidates=candidate_cfg,
        init=init_cfg,
        surrogate=surrogate_cfg,
        acquisition=acquisition_cfg,
        acq_optimizer=optimizer_cfg,
        observation_history=history_cfg,
    )
43
+
44
+
45
def turbo_zero_config(
    *,
    num_candidates: int | None = None,
    num_init: int | None = None,
    trailing_obs: int | None = None,
    trust_region: tr.TrustRegionConfig | None = None,
    candidate_rv: CandidateRV = CandidateRV.SOBOL,
) -> OptimizerConfig:
    """Return an ``OptimizerConfig`` preset for TuRBO-0.

    Components: no surrogate model, random acquisition and the RAASP
    acquisition optimizer, inside a ``TurboTRConfig`` trust region unless one
    is supplied.

    Args:
        num_candidates: Constant candidate count. ``None`` keeps the
            ``CandidateGenConfig`` default rule.
        num_init: Forwarded to ``InitConfig``.
        trailing_obs: Forwarded to ``ObservationHistoryConfig``.
        trust_region: Trust-region config. A falsy value such as ``None``
            selects ``tr.TurboTRConfig()``.
        candidate_rv: Candidate source for ``CandidateGenConfig``.
    """
    region = trust_region or tr.TurboTRConfig()
    candidate_cfg = _make_candidate_gen_config(candidate_rv, num_candidates)
    init_cfg = InitConfig(num_init=num_init)
    surrogate_cfg = sur.NoSurrogateConfig()
    acquisition_cfg = acq.RandomAcquisitionConfig()
    optimizer_cfg = acq.RAASPOptimizerConfig()
    history_cfg = ObservationHistoryConfig(trailing_obs=trailing_obs)
    return OptimizerConfig(
        trust_region=region,
        candidates=candidate_cfg,
        init=init_cfg,
        surrogate=surrogate_cfg,
        acquisition=acquisition_cfg,
        acq_optimizer=optimizer_cfg,
        observation_history=history_cfg,
    )
62
+
63
+
64
+ def turbo_enn_config(
65
+ *,
66
+ enn: sur.ENNSurrogateConfig | None = None,
67
+ trust_region: tr.TrustRegionConfig | None = None,
68
+ candidates: CandidateGenConfig | None = None,
69
+ num_init: int | None = None,
70
+ trailing_obs: int | None = None,
71
+ acq_type: AcqType = AcqType.PARETO,
72
+ ) -> OptimizerConfig:
73
+ if acq_type == AcqType.PARETO:
74
+ acquisition = acq.ParetoAcquisitionConfig()
75
+ acq_optimizer = acq.NDSOptimizerConfig()
76
+ elif acq_type == AcqType.UCB:
77
+ acquisition = acq.UCBAcquisitionConfig()
78
+ acq_optimizer = acq.RAASPOptimizerConfig()
79
+ elif acq_type == AcqType.THOMPSON:
80
+ acquisition = acq.DrawAcquisitionConfig()
81
+ acq_optimizer = acq.RAASPOptimizerConfig()
82
+ else:
83
+ raise ValueError(
84
+ f"acq_type must be AcqType.THOMPSON, AcqType.PARETO, or AcqType.UCB, got {acq_type!r}"
85
+ )
86
+ surrogate = enn if enn is not None else sur.ENNSurrogateConfig()
87
+ if surrogate.num_fit_samples is None and acq_type != AcqType.PARETO:
88
+ raise ValueError(f"enn.num_fit_samples required for acq_type={acq_type!r}")
89
+ return OptimizerConfig(
90
+ trust_region=trust_region or tr.TurboTRConfig(),
91
+ candidates=candidates or CandidateGenConfig(),
92
+ init=InitConfig(num_init=num_init),
93
+ surrogate=surrogate,
94
+ acquisition=acquisition,
95
+ acq_optimizer=acq_optimizer,
96
+ observation_history=ObservationHistoryConfig(trailing_obs=trailing_obs),
97
+ )
98
+
99
+
100
def lhd_only_config(
    *,
    num_candidates: int | None = None,
    num_init: int | None = None,
    trailing_obs: int | None = None,
    trust_region: tr.TrustRegionConfig | None = None,
    candidate_rv: CandidateRV = CandidateRV.SOBOL,
) -> OptimizerConfig:
    """Return an ``OptimizerConfig`` preset driven purely by ``LHDOnlyInit``.

    Components: ``NoTRConfig`` trust region unless one is supplied, the
    LHD-only init strategy, no surrogate, random acquisition and the RAASP
    acquisition optimizer.

    Args:
        num_candidates: Constant candidate count. ``None`` keeps the
            ``CandidateGenConfig`` default rule.
        num_init: Forwarded to ``InitConfig``. ``LHDOnlyInit`` itself ignores it.
        trailing_obs: Forwarded to ``ObservationHistoryConfig``.
        trust_region: Falsy value selects ``tr.NoTRConfig()``.
        candidate_rv: Candidate source for ``CandidateGenConfig``.
    """
    from .init_strategies import LHDOnlyInit

    region = trust_region or tr.NoTRConfig()
    candidate_cfg = _make_candidate_gen_config(candidate_rv, num_candidates)
    init_cfg = InitConfig(init_strategy=LHDOnlyInit(), num_init=num_init)
    surrogate_cfg = sur.NoSurrogateConfig()
    acquisition_cfg = acq.RandomAcquisitionConfig()
    optimizer_cfg = acq.RAASPOptimizerConfig()
    history_cfg = ObservationHistoryConfig(trailing_obs=trailing_obs)
    return OptimizerConfig(
        trust_region=region,
        candidates=candidate_cfg,
        init=init_cfg,
        surrogate=surrogate_cfg,
        acquisition=acquisition_cfg,
        acq_optimizer=optimizer_cfg,
        observation_history=history_cfg,
    )
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import Surrogate
7
+
8
+
9
@dataclass(frozen=True)
class GPSurrogateConfig:
    """Field-less config selecting the Gaussian-process surrogate."""

    def build(self) -> Surrogate:
        """Construct and return a new ``GPSurrogate``."""
        from ..components.surrogates import GPSurrogate

        surrogate = GPSurrogate()
        return surrogate
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass(frozen=True)
class HnROptimizerConfig:
    """Field-less marker config selecting the HnR acquisition optimizer."""
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from .init_strategies import HybridInit, InitStrategy
4
+
5
+
6
@dataclass(frozen=True)
class InitConfig:
    """Settings for the initial-design phase.

    Attributes:
        init_strategy: ``InitStrategy`` instance, ``HybridInit()`` by default.
        num_init: Optional positive integer. ``None`` means unspecified.

    Raises:
        ValueError: if ``init_strategy`` is not an ``InitStrategy``, or if
            ``num_init`` is given and is not strictly positive.
    """

    init_strategy: InitStrategy = HybridInit()
    num_init: int | None = None

    def __post_init__(self) -> None:
        strategy = self.init_strategy
        if not isinstance(strategy, InitStrategy):
            raise ValueError(f"init_strategy must be an InitStrategy, got {strategy!r}")
        count = self.num_init
        if count is None:
            return
        if count <= 0:
            raise ValueError(f"num_init must be > 0, got {count}")
@@ -0,0 +1,9 @@
1
+ from .hybrid_init import HybridInit
2
+ from .init_strategy import InitStrategy
3
+ from .lhd_only_init import LHDOnlyInit
4
+
5
# Public API: the abstract InitStrategy base and its two concrete strategies.
__all__ = [
    "HybridInit",
    "InitStrategy",
    "LHDOnlyInit",
]
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+ from .init_strategy import InitStrategy
5
+
6
+ if TYPE_CHECKING:
7
+ import numpy as np
8
+ from numpy.random import Generator
9
+ from ...strategies import OptimizationStrategy
10
+
11
+
12
@dataclass(frozen=True)
class HybridInit(InitStrategy):
    """Init strategy that delegates to ``TurboHybridStrategy``."""

    def create_runtime_strategy(
        self,
        *,
        bounds: np.ndarray,
        rng: Generator,
        num_init: int | None,
    ) -> OptimizationStrategy:
        """Build a ``TurboHybridStrategy`` from the bounds, RNG and ``num_init``.

        All three arguments are forwarded unchanged to
        ``TurboHybridStrategy.create``.
        """
        # Deferred import, presumably to avoid a config <-> strategies cycle.
        from ...strategies import TurboHybridStrategy

        create_kwargs = {"bounds": bounds, "rng": rng, "num_init": num_init}
        return TurboHybridStrategy.create(**create_kwargs)
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+ from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from numpy.random import Generator
8
+ from ...strategies import OptimizationStrategy
9
+
10
+
11
class InitStrategy(ABC):
    """Abstract base for initialization strategies held by ``InitConfig``.

    Concrete subclasses such as ``HybridInit`` and ``LHDOnlyInit`` turn the
    static configuration into a runtime ``OptimizationStrategy``.
    """

    @abstractmethod
    def create_runtime_strategy(
        self,
        *,
        bounds: np.ndarray,
        rng: Generator,
        num_init: int | None,
    ) -> OptimizationStrategy:
        """Create the runtime optimization strategy.

        Args:
            bounds: Search-space bounds array (presumably shape
                ``(num_dim, 2)`` — TODO confirm against callers).
            rng: NumPy random ``Generator``.
            num_init: Optional initial-design size. Implementations may
                ignore it; ``LHDOnlyInit`` does.

        Returns:
            An ``OptimizationStrategy`` instance.
        """
        ...
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+ from .init_strategy import InitStrategy
5
+
6
+ if TYPE_CHECKING:
7
+ import numpy as np
8
+ from numpy.random import Generator
9
+ from ...strategies import OptimizationStrategy
10
+
11
+
12
@dataclass(frozen=True)
class LHDOnlyInit(InitStrategy):
    """Init strategy that delegates to ``LHDOnlyStrategy``."""

    def create_runtime_strategy(
        self,
        *,
        bounds: np.ndarray,
        rng: Generator,
        num_init: int | None,
    ) -> OptimizationStrategy:
        """Build an ``LHDOnlyStrategy`` from the bounds and RNG.

        ``num_init`` is accepted to satisfy the ``InitStrategy`` interface
        but is intentionally not used.
        """
        from ...strategies import LHDOnlyStrategy

        strategy = LHDOnlyStrategy.create(bounds=bounds, rng=rng)
        return strategy
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ from .rescalarize import Rescalarize
7
+ from .turbo_tr_config import TRLengthConfig
8
+
9
+ if TYPE_CHECKING:
10
+ from numpy.random import Generator
11
+
12
+ from ..components.protocols import TrustRegion
13
+ from .enums import CandidateRV
14
+
15
+ from .enums import CandidateRV
16
+
17
+
18
@dataclass(frozen=True)
class MultiObjectiveConfig:
    """Multi-objective settings used by ``MorboTRConfig``.

    Attributes:
        num_metrics: Number of objectives; must be at least 2.
        alpha: Strictly positive float, default ``0.05`` (its algorithmic
            meaning is defined by the MORBO trust region — not shown here).

    Raises:
        ValueError: if ``num_metrics < 2`` or ``alpha`` is not strictly
            positive (NaN included).
    """

    num_metrics: int
    alpha: float = 0.05

    def __post_init__(self) -> None:
        if self.num_metrics < 2:
            raise ValueError(
                f"num_metrics must be >= 2 for MORBO, got {self.num_metrics}"
            )
        # Written as ``not alpha > 0`` rather than ``alpha <= 0`` so that NaN,
        # for which every comparison is False, is rejected too.
        if not self.alpha > 0:
            raise ValueError(f"alpha must be > 0, got {self.alpha}")
30
+
31
+
32
@dataclass(frozen=True)
class RescalePolicyConfig:
    """Rescalarization policy; ``MorboTRConfig`` exposes it via its ``rescalarize`` property.

    Attributes:
        rescalarize: ``Rescalarize`` member, ``Rescalarize.ON_PROPOSE`` by default.
    """

    rescalarize: Rescalarize = Rescalarize.ON_PROPOSE
35
+
36
+
37
@dataclass(frozen=True)
class MorboTRConfig:
    """Trust-region configuration for the multi-objective (MORBO) variant.

    Settings are grouped into sub-configs. The read-only properties flatten
    them so callers can write e.g. ``config.num_metrics`` or
    ``config.length_min``.

    Attributes:
        multi_objective: ``MultiObjectiveConfig`` (required).
        length: ``TRLengthConfig`` with the init/min/max lengths.
        rescale_policy: ``RescalePolicyConfig`` (default ``ON_PROPOSE``).
        noise_aware: Boolean flag, ``False`` by default.
    """

    multi_objective: MultiObjectiveConfig
    length: TRLengthConfig = TRLengthConfig()
    rescale_policy: RescalePolicyConfig = RescalePolicyConfig()
    noise_aware: bool = False

    # -- flattened, read-only views of the nested sub-configs --

    @property
    def rescalarize(self) -> Rescalarize:
        """Shortcut for ``rescale_policy.rescalarize``."""
        policy = self.rescale_policy
        return policy.rescalarize

    @property
    def num_metrics(self) -> int:
        """Shortcut for ``multi_objective.num_metrics``."""
        mo = self.multi_objective
        return mo.num_metrics

    @property
    def alpha(self) -> float:
        """Shortcut for ``multi_objective.alpha``."""
        mo = self.multi_objective
        return mo.alpha

    @property
    def length_init(self) -> float:
        """Shortcut for ``length.length_init``."""
        lengths = self.length
        return lengths.length_init

    @property
    def length_min(self) -> float:
        """Shortcut for ``length.length_min``."""
        lengths = self.length
        return lengths.length_min

    @property
    def length_max(self) -> float:
        """Shortcut for ``length.length_max``."""
        lengths = self.length
        return lengths.length_max

    def build(
        self,
        *,
        num_dim: int,
        rng: Generator,
        candidate_rv: CandidateRV = CandidateRV.SOBOL,
    ) -> TrustRegion:
        """Construct a ``MorboTrustRegion`` bound to this config."""
        from ..morbo_trust_region import MorboTrustRegion

        region = MorboTrustRegion(
            config=self,
            num_dim=num_dim,
            rng=rng,
            candidate_rv=candidate_rv,
        )
        return region
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass(frozen=True)
class NDSOptimizerConfig:
    """Field-less marker config selecting the NDS acquisition optimizer.

    ``turbo_enn_config`` pairs it with ``ParetoAcquisitionConfig``.
    """
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ..components.protocols import Surrogate
7
+
8
+
9
@dataclass(frozen=True)
class NoSurrogateConfig:
    """Field-less config selecting the ``NoSurrogate`` placeholder surrogate."""

    def build(self) -> Surrogate:
        """Construct and return a new ``NoSurrogate``."""
        from ..components.surrogates import NoSurrogate

        surrogate = NoSurrogate()
        return surrogate
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING:
7
+ from numpy.random import Generator
8
+
9
+ from ..components.protocols import TrustRegion
10
+ from .enums import CandidateRV
11
+
12
+
13
@dataclass(frozen=True)
class NoTRConfig:
    """Configuration for the ``NoTrustRegion`` variant.

    Attributes:
        noise_aware: Forwarded to the ``ScalarIncumbentSelector`` built in
            ``build``; ``False`` by default.
    """

    noise_aware: bool = False

    def build(
        self,
        *,
        num_dim: int,
        rng: Generator,
        candidate_rv: CandidateRV | None = None,
    ) -> TrustRegion:
        """Construct a ``NoTrustRegion`` with a scalar incumbent selector.

        ``rng`` and ``candidate_rv`` are accepted so the signature mirrors
        ``MorboTRConfig.build``, but neither is used here.
        """
        from ..components.incumbent_selector import ScalarIncumbentSelector
        from ..no_trust_region import NoTrustRegion

        selector = ScalarIncumbentSelector(noise_aware=self.noise_aware)
        return NoTrustRegion(
            config=self,
            num_dim=num_dim,
            incumbent_selector=selector,
        )