ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,187 @@
+ from __future__ import annotations
+ from typing import Any
+ from enn.enn.enn_util import standardize_y
+ from .types import GPDataPrep, GPFitResult
+
+
+ def _prepare_gp_data(
+     x_obs_list: list, y_obs_list: list, yvar_obs_list: list | None
+ ) -> GPDataPrep:
+     import numpy as np
+     import torch
+
+     x = np.asarray(x_obs_list, dtype=float)
+     y = np.asarray(y_obs_list, dtype=float)
+     if y.ndim not in (1, 2):
+         raise ValueError(y.shape)
+     is_multi = y.ndim == 2 and y.shape[1] > 1
+     if yvar_obs_list is not None:
+         if len(yvar_obs_list) != len(y_obs_list):
+             raise ValueError(
+                 f"yvar_obs_list length {len(yvar_obs_list)} != y_obs_list length {len(y_obs_list)}"
+             )
+         if is_multi:
+             raise ValueError("yvar_obs_list not supported for multi-output GP")
+     if is_multi:
+         y_mean, y_std = y.mean(axis=0), y.std(axis=0)
+         y_std = np.where(y_std < 1e-6, 1.0, y_std)
+         z = (y - y_mean) / y_std
+         train_y = torch.as_tensor(z.T, dtype=torch.float64)
+     else:
+         y_mean, y_std = standardize_y(y)
+         z = (y - y_mean) / y_std
+         train_y = torch.as_tensor(z, dtype=torch.float64)
+     return GPDataPrep(
+         train_x=torch.as_tensor(x, dtype=torch.float64),
+         train_y=train_y,
+         is_multi=is_multi,
+         y_mean=y_mean,
+         y_std=y_std,
+         y_raw=y,
+     )
+
+
+ def _build_gp_model(
+     train_x: Any,
+     train_y: Any,
+     is_multi: bool,
+     num_dim: int,
+     *,
+     yvar_obs_list: list | None,
+     gp_y_std: Any,
+     y: Any,
+ ) -> tuple[Any, Any]:
+     import numpy as np
+     import torch
+     from gpytorch.constraints import Interval
+     from gpytorch.likelihoods import GaussianLikelihood
+     from .turbo_gp import TurboGP
+     from .turbo_gp_noisy import TurboGPNoisy
+
+     ls_constr, os_constr = Interval(0.005, 2.0), Interval(0.05, 20.0)
+     if yvar_obs_list is not None:
+         y_var = np.asarray(yvar_obs_list, dtype=float)
+         train_y_var = torch.as_tensor(y_var / (gp_y_std**2), dtype=torch.float64)
+         model = TurboGPNoisy(
+             train_x=train_x,
+             train_y=train_y,
+             train_y_var=train_y_var,
+             lengthscale_constraint=ls_constr,
+             outputscale_constraint=os_constr,
+             ard_dims=num_dim,
+         ).to(dtype=train_x.dtype)
+         return model, model.likelihood
+     noise_constr = Interval(5e-4, 0.2)
+     num_out = int(y.shape[1]) if is_multi else None
+     if is_multi:
+         likelihood = GaussianLikelihood(
+             noise_constraint=noise_constr, batch_shape=torch.Size([num_out])
+         ).to(dtype=train_y.dtype)
+     else:
+         likelihood = GaussianLikelihood(noise_constraint=noise_constr).to(
+             dtype=train_y.dtype
+         )
+     model = TurboGP(
+         train_x=train_x,
+         train_y=train_y,
+         likelihood=likelihood,
+         lengthscale_constraint=ls_constr,
+         outputscale_constraint=os_constr,
+         ard_dims=num_dim,
+     ).to(dtype=train_x.dtype)
+     likelihood.noise = (
+         torch.full((num_out,), 0.005, dtype=train_y.dtype)
+         if is_multi
+         else torch.tensor(0.005, dtype=train_y.dtype)
+     )
+     return model, likelihood
+
+
+ def _init_gp_hyperparams(
+     model: Any, is_multi: bool, num_dim: int, num_out: int | None, dtype: Any
+ ) -> None:
+     import torch
+
+     if is_multi:
+         model.covar_module.outputscale = torch.ones(num_out, dtype=dtype)
+         model.covar_module.base_kernel.lengthscale = torch.full(
+             (num_out, 1, num_dim), 0.5, dtype=dtype
+         )
+     else:
+         model.covar_module.outputscale = torch.tensor(1.0, dtype=dtype)
+         model.covar_module.base_kernel.lengthscale = torch.full(
+             (num_dim,), 0.5, dtype=dtype
+         )
+
+
+ def _train_gp(
+     model: Any, likelihood: Any, train_x: Any, train_y: Any, num_steps: int
+ ) -> None:
+     import torch
+     from gpytorch.mlls import ExactMarginalLogLikelihood
+
+     model.train()
+     likelihood.train()
+     mll = ExactMarginalLogLikelihood(likelihood, model)
+     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+     for _ in range(num_steps):
+         optimizer.zero_grad()
+         loss = -mll(model(train_x), train_y)
+         (loss.sum() if loss.ndim != 0 else loss).backward()
+         optimizer.step()
+     model.eval()
+     likelihood.eval()
+
+
+ def fit_gp(
+     x_obs_list: list[float] | list[list[float]],
+     y_obs_list: list[float] | list[list[float]],
+     num_dim: int,
+     *,
+     yvar_obs_list: list[float] | None = None,
+     num_steps: int = 50,
+ ) -> GPFitResult:
+     import numpy as np
+
+     x = np.asarray(x_obs_list, dtype=float)
+     y = np.asarray(y_obs_list, dtype=float)
+     n, is_multi = x.shape[0], y.ndim == 2 and y.shape[1] > 1
+     if n == 0:
+         return (
+             GPFitResult(
+                 model=None,
+                 likelihood=None,
+                 y_mean=np.zeros(y.shape[1]),
+                 y_std=np.ones(y.shape[1]),
+             )
+             if is_multi
+             else GPFitResult(model=None, likelihood=None, y_mean=0.0, y_std=1.0)
+         )
+     if n == 1 and is_multi:
+         return GPFitResult(
+             model=None,
+             likelihood=None,
+             y_mean=y[0].copy(),
+             y_std=np.ones(int(y.shape[1]), dtype=float),
+         )
+     gp_data = _prepare_gp_data(x_obs_list, y_obs_list, yvar_obs_list)
+     model, likelihood = _build_gp_model(
+         gp_data.train_x,
+         gp_data.train_y,
+         gp_data.is_multi,
+         num_dim,
+         yvar_obs_list=yvar_obs_list,
+         gp_y_std=gp_data.y_std,
+         y=gp_data.y_raw,
+     )
+     _init_gp_hyperparams(
+         model,
+         gp_data.is_multi,
+         num_dim,
+         int(gp_data.y_raw.shape[1]) if gp_data.is_multi else None,
+         gp_data.train_x.dtype,
+     )
+     _train_gp(model, likelihood, gp_data.train_x, gp_data.train_y, num_steps)
+     return GPFitResult(
+         model=model, likelihood=likelihood, y_mean=gp_data.y_mean, y_std=gp_data.y_std
+     )
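The hunk above is the new GP-fitting module (enn/turbo/turbo_gp_fit.py in the file list). Below is a minimal, hypothetical sketch of how fit_gp might be called on a single-output problem, assuming the module is importable as enn.turbo.turbo_gp_fit and that numpy, torch, and gpytorch are installed; the toy data and the rescaling comment are illustrative only, not taken from the package.

# Hypothetical usage sketch for fit_gp (import path assumed from the file list above).
import numpy as np

from enn.turbo.turbo_gp_fit import fit_gp

rng = np.random.default_rng(0)
x_obs = rng.random((8, 3)).tolist()                 # 8 points in the unit cube
y_obs = [float(np.sin(sum(row))) for row in x_obs]  # scalar objective values

result = fit_gp(x_obs, y_obs, num_dim=3, num_steps=50)

# Targets are standardized before training, so model predictions live in
# z-space and would be rescaled with result.y_mean and result.y_std.
print(result.model is not None, result.y_mean, result.y_std)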
@@ -1,5 +1,4 @@
  from __future__ import annotations
-
  from .turbo_gp_base import TurboGPBase
 
 
@@ -0,0 +1,98 @@
+ from __future__ import annotations
+ from typing import TYPE_CHECKING
+ from .types import ObsLists, TellInputs
+
+ if TYPE_CHECKING:
+     import numpy as np
+
+
+ def sobol_seed_for_state(
+     seed_base: int, *, restart_generation: int, n_obs: int, num_arms: int
+ ) -> int:
+     mask64 = (1 << 64) - 1
+     x = int(seed_base) & mask64
+     x ^= (int(restart_generation) + 1) * 0xD1342543DE82EF95 & mask64
+     x ^= (int(n_obs) + 1) * 0x9E3779B97F4A7C15 & mask64
+     x ^= (int(num_arms) + 1) * 0xBF58476D1CE4E5B9 & mask64
+     x = (x + 0x9E3779B97F4A7C15) & mask64
+     z = x
+     z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9 & mask64
+     z = (z ^ (z >> 27)) * 0x94D049BB133111EB & mask64
+     z = z ^ (z >> 31)
+     return int(z & 0xFFFFFFFF)
+
+
+ def reset_timing(opt: object) -> None:
+     setattr(opt, "_dt_fit", 0.0)
+     setattr(opt, "_dt_gen", 0.0)
+     setattr(opt, "_dt_sel", 0.0)
+
+
+ def validate_tell_inputs(
+     x: np.ndarray, y: np.ndarray, y_var: np.ndarray | None, num_dim: int
+ ) -> TellInputs:
+     import numpy as np
+
+     x = np.asarray(x, dtype=float)
+     y = np.asarray(y, dtype=float)
+     if x.ndim != 2 or x.shape[1] != num_dim:
+         raise ValueError(x.shape)
+     if y.ndim == 2:
+         if y.shape[0] != x.shape[0]:
+             raise ValueError((x.shape, y.shape))
+         num_metrics = y.shape[1]
+     elif y.ndim == 1:
+         if y.shape[0] != x.shape[0]:
+             raise ValueError((x.shape, y.shape))
+         num_metrics = 1
+     else:
+         raise ValueError(y.shape)
+     if y_var is not None:
+         y_var = np.asarray(y_var, dtype=float)
+         if y_var.shape != y.shape:
+             raise ValueError((y.shape, y_var.shape))
+     return TellInputs(x=x, y=y, y_var=y_var, num_metrics=num_metrics)
+
+
+ def trim_trailing_observations(
+     x_obs_list: list,
+     y_obs_list: list,
+     y_tr_list: list,
+     yvar_obs_list: list,
+     *,
+     trailing_obs: int,
+     incumbent_indices: np.ndarray,
+ ) -> ObsLists:
+     import numpy as np
+
+     num_total = len(x_obs_list)
+     if num_total <= trailing_obs:
+         return ObsLists(
+             x_obs=x_obs_list,
+             y_obs=y_obs_list,
+             y_tr=y_tr_list,
+             yvar_obs=yvar_obs_list,
+         )
+     start_idx = max(0, num_total - trailing_obs)
+     recent_indices = set(range(start_idx, num_total))
+     keep_indices = set(incumbent_indices.tolist()) | recent_indices
+     if len(keep_indices) > trailing_obs:
+         keep_indices = set(incumbent_indices.tolist())
+         remaining_slots = trailing_obs - len(keep_indices)
+         if remaining_slots > 0:
+             recent_non_incumbent = [
+                 i for i in range(num_total - 1, -1, -1) if i not in keep_indices
+             ][:remaining_slots]
+             keep_indices.update(recent_non_incumbent)
+     indices = np.array(sorted(keep_indices), dtype=int)
+     x_array = np.asarray(x_obs_list, dtype=float)
+     y_obs_array = np.asarray(y_obs_list, dtype=float)
+     y_tr_array = np.asarray(y_tr_list, dtype=float)
+     new_x = x_array[indices].tolist()
+     new_y_obs = y_obs_array[indices].tolist()
+     new_y_tr = y_tr_array[indices].tolist() if y_tr_array.size > 0 else []
+     new_yvar = yvar_obs_list
+     if len(yvar_obs_list) == len(y_obs_array):
+         yvar_array = np.asarray(yvar_obs_list, dtype=float)
+         new_yvar = yvar_array[indices].tolist()
+     return ObsLists(x_obs=new_x, y_obs=new_y_obs, y_tr=new_y_tr, yvar_obs=new_yvar)
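This hunk corresponds to the new enn/turbo/turbo_optimizer_utils.py helpers. A small sketch of two of them follows, with made-up inputs and the import path assumed from the file list: sobol_seed_for_state derives a deterministic 32-bit seed from optimizer state via a SplitMix64-style mix, and trim_trailing_observations caps the stored history while always keeping the incumbent indices.

# Hypothetical sketch; inputs are illustrative, not from the package.
import numpy as np

from enn.turbo.turbo_optimizer_utils import (
    sobol_seed_for_state,
    trim_trailing_observations,
)

# Identical optimizer state always yields the same 32-bit Sobol seed.
seed = sobol_seed_for_state(7, restart_generation=0, n_obs=12, num_arms=4)
assert seed == sobol_seed_for_state(7, restart_generation=0, n_obs=12, num_arms=4)
assert 0 <= seed < 2**32

# A history of 10 points trimmed to trailing_obs=5: the incumbent (index 0) is
# kept, and the remaining slots are filled with the most recent points.
x = [[i / 10.0] for i in range(10)]
y = [float(i) for i in range(10)]
obs = trim_trailing_observations(
    x, y, y, [], trailing_obs=5, incumbent_indices=np.array([0])
)
print(len(obs.x_obs))  # 5: indices {0, 6, 7, 8, 9} survive the trim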
@@ -1,55 +1,107 @@
  from __future__ import annotations
-
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from typing import TYPE_CHECKING, Any
+ from .tr_helpers import ScalarIncumbentMixin
 
  if TYPE_CHECKING:
      import numpy as np
      from numpy.random import Generator
-     from scipy.stats._qmc import QMCEngine
+     from .components.incumbent_selector import IncumbentSelector
+     from .config.turbo_tr_config import TurboTRConfig
 
 
  @dataclass
- class TurboTrustRegion:
+ class TurboTrustRegion(ScalarIncumbentMixin):
+     config: TurboTRConfig
      num_dim: int
-     num_arms: int
-     length: float = 0.8
-     length_init: float = 0.8
-     length_min: float = 0.5**7
-     length_max: float = 1.6
+     length: float = field(init=False)
      failure_counter: int = 0
      success_counter: int = 0
      best_value: float = -float("inf")
      prev_num_obs: int = 0
+     incumbent_selector: IncumbentSelector = field(default=None, repr=False)
+     _num_arms: int | None = field(default=None, repr=False)
+     _failure_tolerance: int | None = field(default=None, repr=False)
 
      def __post_init__(self) -> None:
+         from .components.incumbent_selector import ScalarIncumbentSelector
+
+         self.length = self.config.length_init
+         self.success_tolerance = 3
+         if self.incumbent_selector is None:
+             self.incumbent_selector = ScalarIncumbentSelector(noise_aware=False)
+
+     @property
+     def length_init(self) -> float:
+         return self.config.length_init
+
+     @property
+     def length_min(self) -> float:
+         return self.config.length_min
+
+     @property
+     def length_max(self) -> float:
+         return self.config.length_max
+
+     @property
+     def num_metrics(self) -> int:
+         return 1
+
+     def _ensure_initialized(self, num_arms: int) -> None:
          import numpy as np
 
-         self.failure_tolerance = int(
-             np.ceil(
-                 max(
-                     4.0 / float(self.num_arms),
-                     float(self.num_dim) / float(self.num_arms),
+         if self._num_arms is None:
+             self._num_arms = num_arms
+             self._failure_tolerance = int(
+                 np.ceil(
+                     max(
+                         4.0 / float(num_arms),
+                         float(self.num_dim) / float(num_arms),
+                     )
                  )
              )
-         )
-         self.success_tolerance = 3
+         elif num_arms != self._num_arms:
+             raise ValueError(
+                 f"num_arms changed from {self._num_arms} to {num_arms}; "
+                 "must be consistent across ask() calls"
+             )
+
+     @property
+     def failure_tolerance(self) -> int:
+         if self._failure_tolerance is None:
+             raise RuntimeError("failure_tolerance not initialized; call ask() first")
+         return self._failure_tolerance
 
-     def update(self, values: np.ndarray | Any) -> None:
+     def _coerce_y_obs_1d(self, y_obs: np.ndarray | Any) -> np.ndarray:
          import numpy as np
 
-         if values.ndim != 1:
-             raise ValueError(values.shape)
-         if values.size == 0:
-             return
-         new_values = values[self.prev_num_obs :]
-         if new_values.size == 0:
-             return
-         if not np.isfinite(self.best_value):
-             self.best_value = float(np.max(new_values))
-             self.prev_num_obs = values.size
-             return
-         improved = np.max(new_values) > self.best_value + 1e-3 * np.abs(self.best_value)
+         y_obs = np.asarray(y_obs, dtype=float)
+         if y_obs.ndim == 2:
+             if y_obs.shape[1] != 1:
+                 raise ValueError(f"TurboTrustRegion expects m=1, got {y_obs.shape}")
+             return y_obs[:, 0]
+         if y_obs.ndim != 1:
+             raise ValueError(y_obs.shape)
+         return y_obs
+
+     def _coerce_y_incumbent_value(self, y_incumbent: np.ndarray | Any) -> float:
+         import numpy as np
+
+         y_incumbent = np.asarray(y_incumbent, dtype=float).reshape(-1)
+         if y_incumbent.shape != (self.num_metrics,):
+             raise ValueError(
+                 f"y_incumbent must have shape ({self.num_metrics},), got {y_incumbent.shape}"
+             )
+         return float(y_incumbent[0])
+
+     def _improvement_scale(self, prev_values: np.ndarray) -> float:
+         import numpy as np
+
+         if prev_values.size == 0:
+             return 0.0
+         return float(np.max(prev_values) - np.min(prev_values))
+
+     def _update_counters_and_length(self, *, improved: bool) -> None:
          if improved:
              self.success_counter += 1
              self.failure_counter = 0
@@ -59,34 +111,52 @@ class TurboTrustRegion:
          if self.success_counter >= self.success_tolerance:
              self.length = min(2.0 * self.length, self.length_max)
              self.success_counter = 0
-         elif self.failure_counter >= self.failure_tolerance:
+         elif (
+             self._failure_tolerance is not None
+             and self.failure_counter >= self._failure_tolerance
+         ):
              self.length = 0.5 * self.length
              self.failure_counter = 0
 
-         self.best_value = max(self.best_value, float(np.max(new_values)))
-         self.prev_num_obs = values.size
+     def update(self, y_obs: np.ndarray | Any, y_incumbent: np.ndarray | Any) -> None:
+         if self._failure_tolerance is None:
+             return
+         y_obs = self._coerce_y_obs_1d(y_obs)
+         n = int(y_obs.size)
+         if n <= 0:
+             return
+         if n < self.prev_num_obs:
+             raise ValueError((n, self.prev_num_obs))
+         if n == self.prev_num_obs:
+             return
+         y_incumbent_value = self._coerce_y_incumbent_value(y_incumbent)
+         import math
+
+         if not math.isfinite(self.best_value):
+             self.best_value = y_incumbent_value
+             self.prev_num_obs = n
+             return
+         prev_values = y_obs[: self.prev_num_obs]
+         scale = self._improvement_scale(prev_values)
+         improved = y_incumbent_value > self.best_value + 1e-3 * scale
+         self._update_counters_and_length(improved=improved)
+         self.best_value = max(self.best_value, y_incumbent_value)
+         self.prev_num_obs = n
 
      def needs_restart(self) -> bool:
          return self.length < self.length_min
 
-     def restart(self) -> None:
+     def restart(self, rng: Any | None = None) -> None:
          self.length = self.length_init
          self.failure_counter = 0
          self.success_counter = 0
          self.best_value = -float("inf")
          self.prev_num_obs = 0
+         self._num_arms = None
+         self._failure_tolerance = None
 
      def validate_request(self, num_arms: int, *, is_fallback: bool = False) -> None:
-         if is_fallback:
-             if num_arms > self.num_arms:
-                 raise ValueError(
-                     f"num_arms {num_arms} > configured num_arms {self.num_arms}"
-                 )
-         else:
-             if num_arms != self.num_arms:
-                 raise ValueError(
-                     f"num_arms {num_arms} != configured num_arms {self.num_arms}"
-                 )
+         self._ensure_initialized(num_arms)
 
      def compute_bounds_1d(
          self, x_center: np.ndarray | Any, lengthscales: np.ndarray | None = None
@@ -96,26 +166,24 @@ class TurboTrustRegion:
          if lengthscales is None:
              half_length = 0.5 * self.length
          else:
+             lengthscales = np.asarray(lengthscales, dtype=float).reshape(-1)
+             if lengthscales.shape != (self.num_dim,):
+                 raise ValueError(
+                     f"lengthscales must have shape ({self.num_dim},), got {lengthscales.shape}"
+                 )
+             if not np.all(np.isfinite(lengthscales)):
+                 raise ValueError("lengthscales must be finite")
              half_length = lengthscales * self.length / 2.0
          lb = np.clip(x_center - half_length, 0.0, 1.0)
          ub = np.clip(x_center + half_length, 0.0, 1.0)
          return lb, ub
 
-     def generate_candidates(
+     def get_incumbent_indices(
          self,
-         x_center: np.ndarray,
-         lengthscales: np.ndarray | None,
-         num_candidates: int,
+         y: np.ndarray | Any,
          rng: Generator,
-         sobol_engine: QMCEngine,
+         mu: np.ndarray | None = None,
      ) -> np.ndarray:
-         from .turbo_utils import generate_trust_region_candidates
-
-         return generate_trust_region_candidates(
-             x_center,
-             lengthscales,
-             num_candidates,
-             compute_bounds_1d=self.compute_bounds_1d,
-             rng=rng,
-             sobol_engine=sobol_engine,
-         )
+         import numpy as np
+
+         return np.array([self.get_incumbent_index(y, rng, mu=mu)])
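The three hunks above rework TurboTrustRegion (enn/turbo/turbo_trust_region.py): num_arms and failure_tolerance are no longer constructor fields but are derived lazily on the first validate_request, and update() now takes both the observation history and an explicit incumbent value, comparing it against best_value with a tolerance scaled by the spread of previously seen values instead of by |best_value|. The standalone sketch below paraphrases that expand/shrink rule for illustration; it is not an import of the package class, and the 0.8/1.6 length constants mirror the dataclass defaults removed in this diff (in 0.1.7 they come from TurboTRConfig).

# Standalone paraphrase of the TuRBO-style expand/shrink rule implemented by
# update() and _update_counters_and_length() above (illustrative constants).
import numpy as np


def tr_step(state: dict, y_obs: np.ndarray, y_incumbent: float) -> None:
    prev = y_obs[: state["prev_num_obs"]]
    scale = float(prev.max() - prev.min()) if prev.size else 0.0
    improved = y_incumbent > state["best_value"] + 1e-3 * scale
    if improved:
        state["success_counter"] += 1
        state["failure_counter"] = 0
    else:
        state["failure_counter"] += 1
        state["success_counter"] = 0
    if state["success_counter"] >= 3:                        # success_tolerance = 3
        state["length"] = min(2.0 * state["length"], 1.6)    # expand, capped at length_max
        state["success_counter"] = 0
    elif state["failure_counter"] >= state["failure_tolerance"]:
        state["length"] = 0.5 * state["length"]              # shrink on repeated failures
        state["failure_counter"] = 0
    state["best_value"] = max(state["best_value"], y_incumbent)
    state["prev_num_obs"] = int(y_obs.size)


state = {
    "length": 0.8, "failure_counter": 0, "success_counter": 0,
    "best_value": 0.0, "prev_num_obs": 1, "failure_tolerance": 2,
}
tr_step(state, np.array([0.0, -0.1, -0.2]), y_incumbent=0.0)        # no improvement
tr_step(state, np.array([0.0, -0.1, -0.2, -0.3]), y_incumbent=0.0)  # second failure
print(state["length"])  # 0.4: halved after failure_tolerance consecutive failures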