ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Any, Protocol
3
+
4
+ if TYPE_CHECKING:
5
+ import numpy as np
6
+ from .enn_params import ENNParams, PosteriorFlags
7
+
8
+
9
class ENNLike(Protocol):
    """Structural interface for objects that behave like an ENN model.

    Code typed against ``ENNLike`` relies on the private attributes and helper
    methods declared here in addition to the public ``posterior`` method, so
    implementations must provide all of them.
    """

    # Input dimensionality and number of output metrics.
    _num_dim: int
    _num_metrics: int
    # Input scaling state; presumably _x_scale is applied when _scale_x is
    # true — confirm against the concrete implementation.
    _x_scale: np.ndarray
    _scale_x: bool
    # Nearest-neighbour index backend; its concrete type is not constrained here.
    _enn_index: Any
    # Training targets and, when available, their observation variances.
    _train_y: np.ndarray
    _train_yvar: np.ndarray | None

    def __len__(self) -> int:
        """Return the number of training observations held by the model."""

    def posterior(self, x: np.ndarray, *, params: ENNParams, flags: PosteriorFlags):
        """Return the posterior at query points ``x``; implementations must override."""
        raise NotImplementedError()

    def _empty_posterior_internals(self, batch_size: int):
        """Build posterior internals for ``batch_size`` queries; must be overridden.

        NOTE(review): the name suggests the no-neighbour / empty-model case —
        confirm against the concrete implementation.
        """
        raise NotImplementedError()

    def _compute_weighted_stats(
        self,
        dist2s: np.ndarray,
        y_neighbors: np.ndarray,
        *,
        yvar_neighbors: np.ndarray | None,
        params: ENNParams,
        observation_noise: bool,
        y_scale: np.ndarray | None = None,
    ):
        """Combine neighbour targets into weighted statistics; must be overridden.

        ``yvar_neighbors`` and ``y_scale`` are optional and may be ``None``.
        """
        raise NotImplementedError()
enn/enn/enn_normal.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from __future__ import annotations
2
-
3
2
  from dataclasses import dataclass
4
3
  from typing import TYPE_CHECKING
5
4
 
enn/enn/enn_params.py CHANGED
@@ -1,23 +1,4 @@
1
- from __future__ import annotations
1
+ from .enn_params_class import ENNParams
2
+ from .posterior_flags import PosteriorFlags
2
3
 
3
- from dataclasses import dataclass
4
-
5
-
6
- @dataclass(frozen=True)
7
- class ENNParams:
8
- k: int
9
- epi_var_scale: float
10
- ale_homoscedastic_scale: float
11
-
12
- def __post_init__(self) -> None:
13
- import numpy as np
14
-
15
- k = int(self.k)
16
- if k <= 0:
17
- raise ValueError(f"k must be > 0, got {k}")
18
- epi_var_scale = float(self.epi_var_scale)
19
- if not np.isfinite(epi_var_scale) or epi_var_scale < 0.0:
20
- raise ValueError(f"epi_var_scale must be >= 0, got {epi_var_scale}")
21
- ale_scale = float(self.ale_homoscedastic_scale)
22
- if not np.isfinite(ale_scale) or ale_scale < 0.0:
23
- raise ValueError(f"ale_homoscedastic_scale must be >= 0, got {ale_scale}")
4
+ __all__ = ["ENNParams", "PosteriorFlags"]
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass(frozen=True)
class ENNParams:
    """Validated hyper-parameters of an ENN posterior.

    Attributes:
        k_num_neighbors: Number of nearest neighbours; must be a positive
            integer (integral floats such as ``5.0`` are accepted).
        epistemic_variance_scale: Finite, non-negative scale of the
            epistemic variance term.
        aleatoric_variance_scale: Finite, non-negative scale of the
            aleatoric variance term.

    Raises:
        ValueError: If any field is out of range or ``k_num_neighbors`` is
            not integral.
    """

    k_num_neighbors: int
    epistemic_variance_scale: float
    aleatoric_variance_scale: float

    def __post_init__(self) -> None:
        import numpy as np

        k = int(self.k_num_neighbors)
        # int() truncates silently (2.5 -> 2, 0.5 -> 0). Without this check a
        # non-integral k would pass validation yet be stored un-truncated.
        if k != self.k_num_neighbors:
            raise ValueError(
                f"k_num_neighbors must be an integer, got {self.k_num_neighbors!r}"
            )
        if k <= 0:
            raise ValueError(f"k_num_neighbors must be > 0, got {k}")
        epi_var_scale = float(self.epistemic_variance_scale)
        if not np.isfinite(epi_var_scale) or epi_var_scale < 0.0:
            raise ValueError(
                f"epistemic_variance_scale must be >= 0, got {epi_var_scale}"
            )
        ale_scale = float(self.aleatoric_variance_scale)
        if not np.isfinite(ale_scale) or ale_scale < 0.0:
            raise ValueError(f"aleatoric_variance_scale must be >= 0, got {ale_scale}")
enn/enn/enn_util.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from __future__ import annotations
2
-
3
2
  from typing import TYPE_CHECKING, Any
4
3
 
5
4
  if TYPE_CHECKING:
@@ -18,9 +17,7 @@ def standardize_y(y: np.ndarray | list[float] | Any) -> tuple[float, float]:
18
17
  return center, scale
19
18
 
20
19
 
21
- def calculate_sobol_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
22
- import numpy as np
23
-
20
+ def _validate_sobol_inputs(x, y):
24
21
  if x.ndim != 2:
25
22
  raise ValueError(f"x must be 2D, got shape {x.shape}")
26
23
  n, d = x.shape
@@ -28,16 +25,14 @@ def calculate_sobol_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
28
25
  raise ValueError(f"x must have at least 1 dimension, got {d}")
29
26
  if y.ndim == 2 and y.shape[1] == 1:
30
27
  y = y.reshape(-1)
31
- if y.ndim != 1:
32
- raise ValueError(f"y must be 1D, got shape {y.shape}")
33
- if y.shape[0] != n:
34
- raise ValueError(f"y length {y.shape[0]} != x rows {n}")
35
- if n < 9:
36
- return np.ones(d, dtype=x.dtype)
37
- mu = y.mean()
38
- vy = y.var(ddof=0)
39
- if not np.isfinite(vy) or vy <= 0:
40
- return np.ones(d, dtype=x.dtype)
28
+ if y.ndim != 1 or y.shape[0] != n:
29
+ raise ValueError(f"y shape {y.shape} incompatible with x rows {n}")
30
+ return n, d, y
31
+
32
+
33
+ def _compute_sobol_bins(x, y, n, d):
34
+ import numpy as np
35
+
41
36
  B = 10 if n >= 30 else 3
42
37
  order = np.argsort(x, axis=0)
43
38
  row_idx = np.arange(n).reshape(n, 1).repeat(d, axis=1)
@@ -46,17 +41,59 @@ def calculate_sobol_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
46
41
  idx = (ranks * B) // n
47
42
  oh = np.zeros((n, d, B), dtype=x.dtype)
48
43
  oh[np.arange(n)[:, None], np.arange(d)[None, :], idx] = 1.0
49
- counts = oh.sum(axis=0)
50
- sums = (oh * y.reshape(n, 1, 1)).sum(axis=0)
44
+ counts, sums = oh.sum(axis=0), (oh * y.reshape(n, 1, 1)).sum(axis=0)
51
45
  mu_b = np.zeros_like(sums)
52
46
  mask = counts > 0
53
47
  mu_b[mask] = sums[mask] / counts[mask]
54
- p_b = counts / float(n)
55
- diff = mu_b - mu
56
- S = (p_b * (diff * diff)).sum(axis=1) / vy
57
- var_x = x.var(axis=0, ddof=0)
58
- S = np.where(var_x <= 1e-12, np.zeros_like(S), S)
59
- return S
48
+ return counts / float(n), mu_b
49
+
50
+
51
def calculate_sobol_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Estimate first-order Sobol sensitivity indices by binning each input.

    Inputs are validated and ``y`` is flattened by ``_validate_sobol_inputs``.
    Returns an array with one entry per column of ``x``:
      * all ones when there are fewer than 9 rows, or when ``y`` has
        non-finite or non-positive variance;
      * otherwise, the variance of the per-bin means of ``y`` divided by
        ``Var[y]``, with columns whose variance is <= 1e-12 forced to zero.
    """
    import numpy as np

    n, d, y = _validate_sobol_inputs(x, y)
    if n < 9:
        # Too few samples for a meaningful binned estimate.
        return np.ones(d, dtype=x.dtype)

    y_mean = y.mean()
    y_var = y.var(ddof=0)
    if not (np.isfinite(y_var) and y_var > 0):
        return np.ones(d, dtype=x.dtype)

    bin_prob, bin_mean = _compute_sobol_bins(x, y, n, d)
    deviation = bin_mean - y_mean
    first_order = np.sum(bin_prob * deviation**2, axis=1) / y_var

    # A (near-)constant input cannot explain any variance.
    constant_cols = x.var(axis=0, ddof=0) <= 1e-12
    return np.where(constant_cols, np.zeros_like(first_order), first_order)
63
+
64
+
65
def pareto_front_2d_maximize(
    a: np.ndarray | Any, b: np.ndarray | Any, idx: np.ndarray | Any | None = None
) -> np.ndarray:
    """Return indices on the 2-D Pareto front when maximizing both ``a`` and ``b``.

    Args:
        a: 1-D array of first-objective values.
        b: 1-D array of second-objective values, same shape as ``a``.
        idx: Optional 1-D subset of indices into ``a``/``b`` to consider;
            defaults to all indices.

    Returns:
        Integer array of non-dominated indices, in order of descending ``a``
        (ties by descending ``b``). Exact duplicates of a front point are
        all kept.

    Raises:
        ValueError: If ``a``/``b`` are not 1-D arrays of identical shape, or
            if ``idx`` is not 1-D.
    """
    import numpy as np

    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape or a.ndim != 1:
        raise ValueError((a.shape, b.shape))
    if idx is None:
        idx = np.arange(a.size, dtype=int)
    else:
        idx = np.asarray(idx, dtype=int)
        if idx.ndim != 1:
            raise ValueError(idx.shape)
    # np.lexsort treats the LAST key as primary: sort by descending a,
    # breaking ties by descending b.
    order = np.lexsort((-b[idx], -a[idx]))
    sorted_idx = idx[order]
    keep: list[int] = []
    best_b = -float("inf")
    last_a = float("nan")
    last_b = float("nan")
    for i in sorted_idx.tolist():
        bi = float(b[i])
        ai = float(a[i])
        # The first point in sort order has the largest a (and the largest b
        # among ties), so it is never dominated -- keep it even if b == -inf,
        # which the strict ">" test alone would drop.
        if bi > best_b or (not keep and bi == best_b):
            keep.append(i)
            best_b = bi
            last_a = ai
            last_b = bi
        elif bi == best_b and ai == last_a and bi == last_b:
            # Exact duplicate of the most recently kept front point.
            keep.append(i)
    return np.asarray(keep, dtype=int)
60
97
 
61
98
 
62
99
  def arms_from_pareto_fronts(
@@ -79,32 +116,10 @@ def arms_from_pareto_fronts(
79
116
  raise ValueError(num_arms)
80
117
  if not np.all(np.isfinite(mu)) or not np.all(np.isfinite(se)):
81
118
  raise ValueError("mu and se must be finite")
82
-
83
- def _pareto_front_2d_maximize(
84
- mu_: np.ndarray, se_: np.ndarray, idx: np.ndarray
85
- ) -> np.ndarray:
86
- order = np.lexsort((-se_[idx], -mu_[idx]))
87
- sorted_idx = idx[order]
88
- keep: list[int] = []
89
- best_se = -float("inf")
90
- last_mu = float("nan")
91
- last_se = float("nan")
92
- for i in sorted_idx.tolist():
93
- s = float(se_[i])
94
- m = float(mu_[i])
95
- if s > best_se:
96
- keep.append(i)
97
- best_se = s
98
- last_mu = m
99
- last_se = s
100
- elif s == best_se and m == last_mu and s == last_se:
101
- keep.append(i)
102
- return np.asarray(keep, dtype=int)
103
-
104
119
  i_keep: list[int] = []
105
120
  remaining = np.arange(mu.size, dtype=int)
106
121
  while remaining.size > 0 and len(i_keep) < num_arms:
107
- front_indices = _pareto_front_2d_maximize(mu, se, remaining)
122
+ front_indices = pareto_front_2d_maximize(mu, se, remaining)
108
123
  if front_indices.size == 0:
109
124
  raise RuntimeError("pareto front extraction failed")
110
125
  front_indices = front_indices[np.argsort(-mu[front_indices])]
@@ -119,6 +134,5 @@ def arms_from_pareto_fronts(
119
134
  rng.choice(front_indices, size=remaining_arms, replace=False).tolist()
120
135
  )
121
136
  break
122
-
123
137
  i_keep = np.array(i_keep)
124
138
  return x_cand[i_keep[np.argsort(-mu[i_keep])]]
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+
8
+
9
@dataclass(frozen=True)
class NeighborData:
    """Immutable bundle of a nearest-neighbour lookup.

    Fields are stored as given; no validation or shape checking is done.
    """

    # Presumably squared distances to each neighbour (from the name) — confirm.
    dist2s: np.ndarray
    # Indices of the neighbours.
    idx: np.ndarray
    # Target values of the neighbours.
    y_neighbors: np.ndarray
    # Neighbour count associated with this lookup.
    k: int
enn/enn/neighbors.py ADDED
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+
8
+
9
@dataclass(frozen=True)
class Neighbors:
    """Immutable record of neighbour distances, ids, targets and variances.

    Fields are stored as given; ``yvar`` may be ``None`` when no observation
    variances are available.
    """

    # Presumably squared distances (from the name) — confirm against producers.
    dist2: np.ndarray
    # Identifiers of the neighbouring training points.
    ids: np.ndarray
    # Neighbour targets.
    y: np.ndarray
    # Neighbour observation variances, or None.
    yvar: np.ndarray | None
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass(frozen=True)
class PosteriorFlags:
    """Immutable switches controlling a posterior query; both default to off."""

    # Presumably excludes each query's nearest training point (e.g. for
    # leave-one-out evaluation) — confirm against the posterior implementation.
    exclude_nearest: bool = False
    # Presumably adds observation noise to the returned uncertainty — confirm.
    observation_noise: bool = False
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+
8
+
9
@dataclass(frozen=True)
class WeightedStats:
    """Immutable result of combining neighbour values with weights.

    Fields are stored as given; no validation is performed.
    """

    # Normalized neighbour weights.
    w_normalized: np.ndarray
    # NOTE(review): meaning of ``l2`` is not evident here — confirm with producer.
    l2: np.ndarray
    # Weighted mean and its standard error (from the names) — confirm.
    mu: np.ndarray
    se: np.ndarray
@@ -0,0 +1,41 @@
1
+ from .acquisition import (
2
+ HnRAcqOptimizer,
3
+ ParetoAcqOptimizer,
4
+ RandomAcqOptimizer,
5
+ ThompsonAcqOptimizer,
6
+ UCBAcqOptimizer,
7
+ )
8
+ from .incumbent_selector import (
9
+ ChebyshevIncumbentSelector,
10
+ IncumbentSelector,
11
+ NoIncumbentSelector,
12
+ ScalarIncumbentSelector,
13
+ )
14
+ from .protocols import (
15
+ AcquisitionOptimizer,
16
+ PosteriorResult,
17
+ Surrogate,
18
+ SurrogateResult,
19
+ TrustRegion,
20
+ )
21
+ from .surrogates import ENNSurrogate, GPSurrogate, NoSurrogate
22
+
23
+ __all__ = [
24
+ "AcquisitionOptimizer",
25
+ "ChebyshevIncumbentSelector",
26
+ "ENNSurrogate",
27
+ "GPSurrogate",
28
+ "HnRAcqOptimizer",
29
+ "IncumbentSelector",
30
+ "NoIncumbentSelector",
31
+ "NoSurrogate",
32
+ "ParetoAcqOptimizer",
33
+ "PosteriorResult",
34
+ "RandomAcqOptimizer",
35
+ "ScalarIncumbentSelector",
36
+ "Surrogate",
37
+ "SurrogateResult",
38
+ "ThompsonAcqOptimizer",
39
+ "TrustRegion",
40
+ "UCBAcqOptimizer",
41
+ ]
@@ -0,0 +1,13 @@
1
+ from .hnr_acq_optimizer import HnRAcqOptimizer
2
+ from .pareto_acq_optimizer import ParetoAcqOptimizer
3
+ from .random_acq_optimizer import RandomAcqOptimizer
4
+ from .thompson_acq_optimizer import ThompsonAcqOptimizer
5
+ from .ucb_acq_optimizer import UCBAcqOptimizer
6
+
7
+ __all__ = [
8
+ "HnRAcqOptimizer",
9
+ "ParetoAcqOptimizer",
10
+ "RandomAcqOptimizer",
11
+ "ThompsonAcqOptimizer",
12
+ "UCBAcqOptimizer",
13
+ ]
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Any, Protocol
3
+ from .surrogate_protocol import Surrogate
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from numpy.random import Generator
8
+
9
+
10
class AcquisitionOptimizer(Protocol):
    """Structural interface for strategies that choose arms from candidates."""

    def select(
        self,
        x_cand: np.ndarray,
        num_arms: int,
        surrogate: Surrogate,
        rng: Generator,
        *,
        tr_state: Any | None = None,
    ) -> np.ndarray:
        """Choose ``num_arms`` arms from ``x_cand`` guided by ``surrogate``.

        ``tr_state`` is optional trust-region state whose type is not
        constrained by this protocol. The returned array presumably holds the
        selected candidate rows — confirm against concrete implementations.
        """
@@ -0,0 +1,22 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING
3
+ from .acquisition import HnRAcqOptimizer, ThompsonAcqOptimizer, UCBAcqOptimizer
4
+
5
+ if TYPE_CHECKING:
6
+ from ..config.optimizer_config import OptimizerConfig
7
+ from .protocols import AcquisitionOptimizer, Surrogate
8
+
9
+
10
def build_surrogate(config: OptimizerConfig) -> Surrogate:
    """Construct the surrogate described by ``config.surrogate``.

    Construction is delegated entirely to the surrogate config's ``build()``.
    """
    surrogate_config = config.surrogate
    return surrogate_config.build()
12
+
13
+
14
def build_acquisition_optimizer(config: OptimizerConfig) -> AcquisitionOptimizer:
    """Build the acquisition optimizer described by ``config``.

    The base optimizer comes from ``config.acquisition.build()``. When
    ``config.acq_optimizer`` is an ``HnROptimizerConfig`` the base is wrapped
    in ``HnRAcqOptimizer``; that wrapping is only permitted for Thompson and
    UCB base optimizers.

    Raises:
        ValueError: If HnR is requested with any other base optimizer.
    """
    from ..config.acquisition import HnROptimizerConfig

    base = config.acquisition.build()
    if not isinstance(config.acq_optimizer, HnROptimizerConfig):
        return base
    # Guard clause: HnR wrapping only makes sense for sample/UCB-based bases.
    if not isinstance(base, (ThompsonAcqOptimizer, UCBAcqOptimizer)):
        raise ValueError(f"HnR not supported with {type(base).__name__}")
    return HnRAcqOptimizer(base)
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from numpy.random import Generator
8
+
9
+
10
@dataclass
class ChebyshevIncumbentSelector:
    """Pick an incumbent by augmented Chebyshev scalarization of several metrics.

    Weights are drawn from a flat Dirichlet distribution, lazily on the first
    ``select`` or explicitly via ``reset``. Each metric column is min-max
    normalized to [0, 1] over the observations, and each row is scored as
    ``min_j(w_j * z_j) + alpha * sum_j(w_j * z_j)``. The highest score wins,
    with ties broken at random.

    Attributes:
        num_metrics: Number of metric columns; must be >= 1.
        noise_aware: If True, score the surrogate means ``mu_obs`` instead of
            the raw observations ``y_obs``.
        alpha: Weight of the additive (sum) term of the score.
        _weights: Current scalarization weights, or None until sampled.
    """

    num_metrics: int
    noise_aware: bool
    alpha: float
    _weights: np.ndarray | None = None

    def __post_init__(self) -> None:
        if self.num_metrics < 1:
            raise ValueError(f"num_metrics must be >= 1, got {self.num_metrics}")

    @property
    def weights(self) -> np.ndarray | None:
        """Current scalarization weights (None before the first sample)."""
        return self._weights

    def _sample_weights(self, rng: Generator) -> None:
        """Draw new weights uniformly from the probability simplex."""
        import numpy as np

        # Dirichlet concentration of all ones == uniform over the simplex.
        concentration = np.ones(self.num_metrics, dtype=float)
        self._weights = np.asarray(rng.dirichlet(concentration), dtype=float)

    def reset(self, rng: Generator) -> None:
        """Resample the scalarization weights."""
        self._sample_weights(rng)

    def select(
        self,
        y_obs: np.ndarray,
        mu_obs: np.ndarray | None,
        rng: Generator,
    ) -> int:
        """Return the row index of the best observation under the current weights.

        Args:
            y_obs: Observations of shape (n, num_metrics), with n >= 1.
            mu_obs: Surrogate means; required when ``noise_aware``. It must
                have the same shape as ``y_obs``. A 1-D array is accepted
                when ``num_metrics == 1``.
            rng: Used for lazy weight sampling and random tie-breaking.

        Raises:
            ValueError: On malformed or empty inputs, or a missing or
                mismatched ``mu_obs`` in noise-aware mode.
        """
        import numpy as np

        if self._weights is None:
            self._sample_weights(rng)
        y = np.asarray(y_obs, dtype=float)
        if y.ndim != 2 or y.shape[1] != self.num_metrics:
            raise ValueError(
                f"Expected y with {self.num_metrics} metrics, got {y.shape}"
            )
        if y.shape[0] == 0:
            raise ValueError("y_obs must contain at least one observation")
        if self.noise_aware:
            if mu_obs is None:
                raise ValueError(
                    "noise_aware=True requires a surrogate that provides mu. "
                    "Either use a GP/ENN surrogate or set noise_aware=False."
                )
            values = np.asarray(mu_obs, dtype=float)
            if values.ndim == 1 and self.num_metrics == 1:
                # A flat (n,) mean would otherwise broadcast to (1, n) inside
                # _scalarize and always select index 0.
                values = values.reshape(-1, 1)
            if values.shape != y.shape:
                raise ValueError(
                    f"mu_obs shape {values.shape} does not match y_obs shape {y.shape}"
                )
        else:
            values = y
        scores = self._scalarize(values)
        from ..turbo_utils import argmax_random_tie

        return int(argmax_random_tie(scores, rng=rng))

    def _scalarize(self, values: np.ndarray) -> np.ndarray:
        """Score each row of ``values`` (n, num_metrics) with the augmented Chebyshev form."""
        import numpy as np

        if self._weights is None:
            raise RuntimeError("Weights not initialized; call reset() first")
        v_min = values.min(axis=0)
        v_max = values.max(axis=0)
        denom = v_max - v_min
        # Constant columns carry no ranking information; map them to 0.5.
        is_deg = denom <= 0.0
        denom_safe = np.where(is_deg, 1.0, denom)
        z = (values - v_min.reshape(1, -1)) / denom_safe.reshape(1, -1)
        z = np.where(is_deg, 0.5, z)
        z = np.clip(z, 0.0, 1.0)
        t = z * self._weights.reshape(1, -1)
        return np.min(t, axis=1) + self.alpha * np.sum(t, axis=1)
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Any
3
+ import numpy as np
4
+ from .posterior_result import PosteriorResult
5
+ from .surrogate_result import SurrogateResult
6
+
7
+ if TYPE_CHECKING:
8
+ from numpy.random import Generator
9
+ from ..config.surrogate import ENNSurrogateConfig
10
+
11
+
12
class ENNSurrogate:
    """Surrogate backed by an ENN model built with ``mk_enn``.

    ``fit`` must be called before ``predict`` or ``sample``. The ENN
    hyper-parameters returned by fitting are kept and reused as a warm start
    on later fits.
    """

    def __init__(self, config: ENNSurrogateConfig) -> None:
        self._config = config
        self._enn: Any | None = None
        self._params: Any | None = None

    @property
    def lengthscales(self) -> np.ndarray | None:
        """ENN surrogates expose no lengthscales; always None."""
        return None

    def fit(
        self,
        x_obs: np.ndarray,
        y_obs: np.ndarray,
        y_var: np.ndarray | None = None,
        *,
        num_steps: int = 0,
        rng: Generator | None = None,
    ) -> SurrogateResult:
        """Fit (or incrementally update) the ENN on the observations.

        With the HNSW index driver, when a model already exists and ``x_obs``
        has more rows than it, only the rows beyond the current model size
        are appended. This assumes earlier rows are unchanged — TODO confirm
        callers guarantee it. In that path the parameters are refit only if
        ``num_fit_samples`` is configured and ``rng`` is given. Otherwise the
        model is rebuilt from scratch via ``mk_enn``.

        ``num_steps`` is accepted for interface compatibility and is unused.
        """
        from ..proposal import mk_enn
        from ..config.enums import ENNIndexDriver

        k = self._config.k if self._config.k is not None else 10
        if (
            self._config.index_driver == ENNIndexDriver.HNSW
            and self._enn is not None
            and len(x_obs) > len(self._enn)
        ):
            n_old = len(self._enn)
            new_x = x_obs[n_old:]
            new_y = y_obs[n_old:]
            new_yvar = y_var[n_old:] if y_var is not None else None
            self._enn.add(new_x, new_y, new_yvar)
            if self._config.num_fit_samples is not None and rng is not None:
                from ...enn.enn_fit import enn_fit

                self._params = enn_fit(
                    self._enn,
                    k=k,
                    num_fit_candidates=self._config.num_fit_candidates
                    if self._config.num_fit_candidates is not None
                    else 30,
                    num_fit_samples=self._config.num_fit_samples,
                    rng=rng,
                    params_warm_start=self._params,
                )
        else:
            self._enn, self._params = mk_enn(
                list(x_obs),
                list(y_obs),
                k,
                list(y_var) if y_var is not None else [],
                num_fit_samples=self._config.num_fit_samples,
                num_fit_candidates=self._config.num_fit_candidates,
                scale_x=self._config.scale_x,
                index_driver=self._config.index_driver,
                rng=rng,
                params_warm_start=self._params,
            )
        return SurrogateResult(model=self._enn, lengthscales=None)

    def get_incumbent_candidate_indices(self, y_obs: np.ndarray) -> np.ndarray:
        """Return indices of the top-``k`` observations (per metric, unioned).

        For multi-column ``y_obs`` the top rows of every metric are unioned
        and returned sorted. Otherwise the top rows of the single metric are
        returned in no particular order. When ``config.k`` is None, ``k``
        defaults to ``min(n, max(10, 2 * num_fit_candidates))``, where
        ``num_fit_candidates`` defaults to 100. An empty ``y_obs`` yields an
        empty index array.
        """
        y_array = np.asarray(y_obs, dtype=float)
        if len(y_array) == 0:
            # argpartition with kth=-1 raises on a zero-length array.
            return np.array([], dtype=int)
        k = self._config.k
        if k is None:
            num_fit_candidates = (
                self._config.num_fit_candidates
                if self._config.num_fit_candidates is not None
                else 100
            )
            k = min(len(y_array), max(10, 2 * num_fit_candidates))
        if y_array.ndim == 2 and y_array.shape[1] > 1:
            num_top = min(k, len(y_array))
            union_indices: set[int] = set()
            for m in range(y_array.shape[1]):
                top_m = np.argpartition(-y_array[:, m], num_top - 1)[:num_top]
                union_indices.update(top_m.tolist())
            return np.array(sorted(union_indices), dtype=int)
        y_flat = y_array[:, 0] if y_array.ndim == 2 else y_array
        num_top = min(k, len(y_flat))
        return np.argpartition(-y_flat, num_top - 1)[:num_top]

    def predict(self, x: np.ndarray) -> PosteriorResult:
        """Return the posterior mean and standard error at ``x``.

        Raises:
            RuntimeError: If the surrogate has not been fitted.
        """
        if self._enn is None or self._params is None:
            raise RuntimeError("ENNSurrogate.predict requires a fitted model")
        posterior = self._enn.posterior(x, params=self._params)
        return PosteriorResult(mu=posterior.mu, sigma=posterior.se)

    def sample(self, x: np.ndarray, num_samples: int, rng: Generator) -> np.ndarray:
        """Draw ``num_samples`` posterior function samples at ``x``.

        Consecutive function seeds are derived from one draw of ``rng``.
        Returns an array of shape (num_samples, len(x), num_outputs).

        Raises:
            RuntimeError: If unfitted, or if the model returns a wrongly
                shaped sample array.
        """
        if self._enn is None or self._params is None:
            raise RuntimeError("ENNSurrogate.sample requires a fitted model")
        num_candidates = len(x)
        num_metrics = self._enn.num_outputs
        base_seed = rng.integers(0, 2**31)
        function_seeds = np.arange(base_seed, base_seed + num_samples, dtype=np.int64)
        samples = self._enn.posterior_function_draw(
            x, self._params, function_seeds=function_seeds
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if samples.shape != (num_samples, num_candidates, num_metrics):
            raise RuntimeError(
                f"ENN samples shape mismatch: got {samples.shape}, "
                f"expected ({num_samples}, {num_candidates}, {num_metrics})"
            )
        return samples