ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/__init__.py +25 -13
- enn/benchmarks/__init__.py +3 -0
- enn/benchmarks/ackley.py +5 -0
- enn/benchmarks/ackley_class.py +17 -0
- enn/benchmarks/ackley_core.py +12 -0
- enn/benchmarks/double_ackley.py +24 -0
- enn/enn/candidates.py +14 -0
- enn/enn/conditional_posterior_draw_internals.py +15 -0
- enn/enn/draw_internals.py +15 -0
- enn/enn/enn.py +16 -269
- enn/enn/enn_class.py +423 -0
- enn/enn/enn_conditional.py +325 -0
- enn/enn/enn_fit.py +69 -70
- enn/enn/enn_hash.py +79 -0
- enn/enn/enn_index.py +92 -0
- enn/enn/enn_like_protocol.py +35 -0
- enn/enn/enn_normal.py +0 -1
- enn/enn/enn_params.py +3 -22
- enn/enn/enn_params_class.py +24 -0
- enn/enn/enn_util.py +60 -46
- enn/enn/neighbor_data.py +14 -0
- enn/enn/neighbors.py +14 -0
- enn/enn/posterior_flags.py +8 -0
- enn/enn/weighted_stats.py +14 -0
- enn/turbo/components/__init__.py +41 -0
- enn/turbo/components/acquisition.py +13 -0
- enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
- enn/turbo/components/builder.py +22 -0
- enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
- enn/turbo/components/enn_surrogate.py +115 -0
- enn/turbo/components/gp_surrogate.py +144 -0
- enn/turbo/components/hnr_acq_optimizer.py +83 -0
- enn/turbo/components/incumbent_selector.py +11 -0
- enn/turbo/components/incumbent_selector_protocol.py +16 -0
- enn/turbo/components/no_incumbent_selector.py +21 -0
- enn/turbo/components/no_surrogate.py +49 -0
- enn/turbo/components/pareto_acq_optimizer.py +49 -0
- enn/turbo/components/posterior_result.py +12 -0
- enn/turbo/components/protocols.py +13 -0
- enn/turbo/components/random_acq_optimizer.py +21 -0
- enn/turbo/components/scalar_incumbent_selector.py +39 -0
- enn/turbo/components/surrogate_protocol.py +32 -0
- enn/turbo/components/surrogate_result.py +12 -0
- enn/turbo/components/surrogates.py +5 -0
- enn/turbo/components/thompson_acq_optimizer.py +49 -0
- enn/turbo/components/trust_region_protocol.py +24 -0
- enn/turbo/components/ucb_acq_optimizer.py +49 -0
- enn/turbo/config/__init__.py +87 -0
- enn/turbo/config/acq_type.py +8 -0
- enn/turbo/config/acquisition.py +26 -0
- enn/turbo/config/base.py +4 -0
- enn/turbo/config/candidate_gen_config.py +49 -0
- enn/turbo/config/candidate_rv.py +7 -0
- enn/turbo/config/draw_acquisition_config.py +14 -0
- enn/turbo/config/enn_index_driver.py +6 -0
- enn/turbo/config/enn_surrogate_config.py +42 -0
- enn/turbo/config/enums.py +7 -0
- enn/turbo/config/factory.py +118 -0
- enn/turbo/config/gp_surrogate_config.py +14 -0
- enn/turbo/config/hnr_optimizer_config.py +7 -0
- enn/turbo/config/init_config.py +17 -0
- enn/turbo/config/init_strategies/__init__.py +9 -0
- enn/turbo/config/init_strategies/hybrid_init.py +23 -0
- enn/turbo/config/init_strategies/init_strategy.py +19 -0
- enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
- enn/turbo/config/morbo_tr_config.py +82 -0
- enn/turbo/config/nds_optimizer_config.py +7 -0
- enn/turbo/config/no_surrogate_config.py +14 -0
- enn/turbo/config/no_tr_config.py +31 -0
- enn/turbo/config/optimizer_config.py +72 -0
- enn/turbo/config/pareto_acquisition_config.py +14 -0
- enn/turbo/config/raasp_driver.py +6 -0
- enn/turbo/config/raasp_optimizer_config.py +7 -0
- enn/turbo/config/random_acquisition_config.py +14 -0
- enn/turbo/config/rescalarize.py +7 -0
- enn/turbo/config/surrogate.py +12 -0
- enn/turbo/config/trust_region.py +34 -0
- enn/turbo/config/turbo_tr_config.py +71 -0
- enn/turbo/config/ucb_acquisition_config.py +14 -0
- enn/turbo/config/validation.py +45 -0
- enn/turbo/hypervolume.py +30 -0
- enn/turbo/impl_helpers.py +68 -0
- enn/turbo/morbo_trust_region.py +131 -70
- enn/turbo/no_trust_region.py +32 -39
- enn/turbo/optimizer.py +300 -0
- enn/turbo/optimizer_config.py +8 -0
- enn/turbo/proposal.py +36 -38
- enn/turbo/sampling.py +21 -0
- enn/turbo/strategies/__init__.py +9 -0
- enn/turbo/strategies/lhd_only_strategy.py +36 -0
- enn/turbo/strategies/optimization_strategy.py +19 -0
- enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
- enn/turbo/tr_helpers.py +202 -0
- enn/turbo/turbo_gp.py +0 -1
- enn/turbo/turbo_gp_base.py +0 -1
- enn/turbo/turbo_gp_fit.py +187 -0
- enn/turbo/turbo_gp_noisy.py +0 -1
- enn/turbo/turbo_optimizer_utils.py +98 -0
- enn/turbo/turbo_trust_region.py +126 -58
- enn/turbo/turbo_utils.py +98 -161
- enn/turbo/types/__init__.py +7 -0
- enn/turbo/types/appendable_array.py +85 -0
- enn/turbo/types/gp_data_prep.py +13 -0
- enn/turbo/types/gp_fit_result.py +11 -0
- enn/turbo/types/obs_lists.py +10 -0
- enn/turbo/types/prepare_ask_result.py +14 -0
- enn/turbo/types/tell_inputs.py +14 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
- ennbo-0.1.7.dist-info/RECORD +111 -0
- enn/enn/__init__.py +0 -4
- enn/turbo/__init__.py +0 -11
- enn/turbo/base_turbo_impl.py +0 -144
- enn/turbo/lhd_only_impl.py +0 -49
- enn/turbo/turbo_config.py +0 -72
- enn/turbo/turbo_enn_impl.py +0 -201
- enn/turbo/turbo_mode.py +0 -10
- enn/turbo/turbo_mode_impl.py +0 -76
- enn/turbo/turbo_one_impl.py +0 -302
- enn/turbo/turbo_optimizer.py +0 -525
- enn/turbo/turbo_zero_impl.py +0 -29
- ennbo-0.1.2.dist-info/RECORD +0 -29
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
enn/enn/enn_class.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
import numpy as np
|
|
4
|
+
from .draw_internals import DrawInternals
|
|
5
|
+
from .neighbor_data import NeighborData
|
|
6
|
+
from .weighted_stats import WeightedStats
|
|
7
|
+
from enn.turbo.config.enums import ENNIndexDriver
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from .enn_normal import ENNNormal
|
|
11
|
+
from .enn_params import ENNParams, PosteriorFlags
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _compute_conditional_y_scale(
|
|
15
|
+
model: EpistemicNearestNeighbors, y_whatif: np.ndarray
|
|
16
|
+
):
|
|
17
|
+
y_whatif = np.asarray(y_whatif, dtype=float)
|
|
18
|
+
return model._compute_scale(
|
|
19
|
+
np.concatenate([model.train_y, y_whatif], axis=0),
|
|
20
|
+
0.0,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _draw_from_internals(
    model: EpistemicNearestNeighbors,
    internals: DrawInternals,
    *,
    function_seeds: np.ndarray | list[int],
) -> np.ndarray:
    """Materialize seeded posterior function draws from precomputed internals.

    Returns an array of shape ``(num_seeds, num_points, num_outputs)``.  Each
    seed indexes a deterministic pseudo-random function: hashed normal noise is
    generated per (seed, neighbor index, output), combined with the normalized
    posterior weights, rescaled by the weight L2 norm, and added to the
    posterior mean scaled by the posterior standard error.
    """
    from .enn_hash import normal_hash_batch_multi_seed_fast

    seeds = np.asarray(function_seeds, dtype=np.int64)
    num_points = internals.idx.shape[0]
    num_neighbors = internals.idx.shape[1]
    num_outputs = model.num_outputs
    if num_neighbors == 0:
        # No neighbors: every draw collapses onto the prior mean.
        target_shape = (len(seeds), num_points, num_outputs)
        return np.broadcast_to(internals.mu, target_shape).copy()
    noise = normal_hash_batch_multi_seed_fast(seeds, internals.idx, num_outputs)
    combined = np.sum(internals.w_normalized[np.newaxis, :, :, :] * noise, axis=2)
    # Guard against a zero weight norm before dividing.
    denom = np.maximum(internals.l2, 1e-12)[np.newaxis, :, :]
    return internals.mu[np.newaxis, :, :] + internals.se[np.newaxis, :, :] * combined / denom
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class EpistemicNearestNeighbors:
    """k-nearest-neighbors surrogate with epistemic and aleatoric uncertainty.

    Training observations ``(x, y)`` — optionally with per-point observation
    variances ``yvar`` — are held in memory and indexed by an ``ENNIndex`` for
    nearest-neighbor search.  Posterior queries inverse-variance-weight the k
    nearest neighbors of each query point: a neighbor's epistemic variance
    grows with its squared distance to the query, so far-away neighbors
    contribute less.  The class also supports deterministic, seed-indexed
    "function draws" and posteriors conditioned on hypothetical ("what-if")
    observations.
    """

    # Variance floor applied both when inverting variances into weights and
    # under the final sqrt, guarding against division by zero for neighbors
    # at zero distance from the query.
    _EPS_VAR = 1e-9

    @staticmethod
    def _validate_inputs(train_x, train_y, train_yvar):
        """Coerce inputs to float arrays and check their shapes.

        ``train_x`` and ``train_y`` must be 2-D with the same number of rows;
        ``train_yvar``, when provided, must match ``train_y``'s shape exactly.
        Raises ValueError carrying the offending shapes otherwise.
        """
        train_x, train_y = (
            np.asarray(train_x, dtype=float),
            np.asarray(train_y, dtype=float),
        )
        if (
            train_x.ndim != 2
            or train_y.ndim != 2
            or train_x.shape[0] != train_y.shape[0]
        ):
            raise ValueError((train_x.shape, train_y.shape))
        if train_yvar is not None:
            train_yvar = np.asarray(train_yvar, dtype=float)
            if train_yvar.ndim != 2 or train_y.shape != train_yvar.shape:
                raise ValueError((train_y.shape, train_yvar.shape))
        return train_x, train_y, train_yvar

    @staticmethod
    def _compute_scale(data, min_val=0.0):
        """Per-column standard deviation of ``data`` as a ``(1, d)`` row.

        Columns whose std is non-finite or not strictly greater than
        ``min_val`` fall back to 1.0; with fewer than two rows the std is
        undefined, so all ones are returned.
        """
        if len(data) < 2:
            return np.ones((1, data.shape[1]), dtype=float)
        scale = np.std(data, axis=0, keepdims=True).astype(float)
        return np.where(np.isfinite(scale) & (scale > min_val), scale, 1.0)

    def __init__(
        self,
        train_x: np.ndarray,
        train_y: np.ndarray,
        train_yvar: np.ndarray | None = None,
        *,
        scale_x: bool = False,
        index_driver: ENNIndexDriver = ENNIndexDriver.FLAT,
    ) -> None:
        """Store the training data and build the nearest-neighbor index.

        Args:
            train_x: ``(n, d)`` input points.
            train_y: ``(n, m)`` observed outputs.
            train_yvar: optional ``(n, m)`` per-point observation variances.
            scale_x: when True, inputs are divided per-dimension by their std
                before indexing, so distances are computed in scaled space.
            index_driver: backend selector forwarded to ``ENNIndex``.
        """
        self._train_x, self._train_y, self._train_yvar = self._validate_inputs(
            train_x, train_y, train_yvar
        )
        self._num_obs, self._num_dim = self._train_x.shape
        _, self._num_metrics = self._train_y.shape
        self._scale_x = bool(scale_x)
        # 1e-12 threshold keeps near-degenerate dimensions from exploding the
        # scaled coordinates; without scaling the scale is identically 1.
        self._x_scale = (
            self._compute_scale(self._train_x, 1e-12)
            if scale_x
            else np.ones((1, self._num_dim), dtype=float)
        )
        self._train_x_scaled = (
            self._train_x / self._x_scale if scale_x else self._train_x
        )
        self._y_scale = self._compute_scale(self._train_y, 0.0)
        # Imported lazily to avoid a module-level import cycle.
        from .enn_index import ENNIndex

        self._enn_index = ENNIndex(
            self._train_x_scaled,
            self._num_dim,
            self._x_scale,
            self._scale_x,
            driver=index_driver,
        )

    def add(
        self,
        x: np.ndarray,
        y: np.ndarray,
        yvar: np.ndarray | None = None,
    ) -> None:
        """Append new observations and extend the neighbor index.

        Note that ``_y_scale``, ``_x_scale`` and ``_train_x_scaled`` are NOT
        recomputed here — they keep their construction-time values.
        ``x`` is passed to ``ENNIndex.add`` unscaled; presumably the index
        applies the stored x-scale internally — TODO confirm.
        """
        x, y, yvar = self._validate_inputs(x, y, yvar)
        self._train_x = np.concatenate([self._train_x, x], axis=0)
        self._train_y = np.concatenate([self._train_y, y], axis=0)
        if yvar is not None:
            if self._train_yvar is None:
                # NOTE(review): if the model previously had no yvar, this
                # stores variances for the NEW rows only, leaving _train_yvar
                # shorter than _train_y — confirm this misalignment is
                # intended (later indexing uses neighbor indices into the
                # full training set).
                self._train_yvar = yvar
            else:
                self._train_yvar = np.concatenate([self._train_yvar, yvar], axis=0)
        elif self._train_yvar is not None:
            # The model tracks per-point noise; silently mixing rows with and
            # without variances would corrupt posterior weighting, so demand
            # consistency from the caller.
            raise ValueError("yvar must be provided if model has existing yvar")

        self._num_obs = self._train_x.shape[0]
        self._enn_index.add(x)

    @property
    def train_x(self) -> np.ndarray:
        """``(n, d)`` training inputs (unscaled)."""
        return self._train_x

    @property
    def train_y(self) -> np.ndarray:
        """``(n, m)`` training outputs."""
        return self._train_y

    @property
    def train_yvar(self) -> np.ndarray | None:
        """Per-point observation variances, or None when not tracked."""
        return self._train_yvar

    @property
    def num_outputs(self) -> int:
        """Number of output metrics m."""
        return self._num_metrics

    def __len__(self) -> int:
        """Number of stored observations."""
        return self._num_obs

    def posterior(
        self,
        x: np.ndarray,
        *,
        params: ENNParams,
        flags: PosteriorFlags | None = None,
    ) -> ENNNormal:
        """Gaussian posterior at ``x`` for a single parameter set.

        Convenience wrapper: delegates to ``batch_posterior`` with a
        one-element params list and strips the leading params axis.
        """
        from .enn_normal import ENNNormal
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        post_batch = self.batch_posterior(x, [params], flags=flags)
        return ENNNormal(post_batch.mu[0], post_batch.se[0])

    def _empty_posterior_internals(self, batch_size: int) -> DrawInternals:
        """Prior fallback used when there are no usable neighbors.

        Zero neighbors, zero mean, unit standard error per metric.
        """
        m = self._num_metrics
        return DrawInternals(
            idx=np.zeros((batch_size, 0), dtype=int),
            w_normalized=np.zeros((batch_size, 0, m), dtype=float),
            l2=np.ones((batch_size, m), dtype=float),
            mu=np.zeros((batch_size, m), dtype=float),
            se=np.ones((batch_size, m), dtype=float),
        )

    def _get_neighbor_data(
        self, x: np.ndarray, params: ENNParams, exclude_nearest: bool
    ) -> NeighborData | None:
        """k-nearest-neighbor lookup for a batch of query points.

        With ``exclude_nearest`` the single closest point is dropped
        (leave-one-out style), so one extra neighbor is requested from the
        index.  Returns None when no neighbors are available (k == 0).
        """
        if exclude_nearest:
            if len(self) <= 1:
                raise ValueError(len(self))
            search_k = int(min(params.k_num_neighbors + 1, len(self)))
        else:
            search_k = int(min(params.k_num_neighbors, len(self)))
        dist2s_full, idx_full = self._enn_index.search(
            x, search_k=search_k, exclude_nearest=exclude_nearest
        )
        available_k = search_k - 1 if exclude_nearest else search_k
        k = min(params.k_num_neighbors, available_k)
        if k > dist2s_full.shape[1]:
            # Defensive: the index returned fewer columns than promised.
            raise RuntimeError(
                f"k={k} exceeds available columns={dist2s_full.shape[1]}"
            )
        if k == 0:
            return None
        return NeighborData(
            dist2s=dist2s_full[:, :k],
            idx=idx_full[:, :k],
            y_neighbors=self._train_y[idx_full[:, :k]],
            k=k,
        )

    def _compute_weighted_posterior(
        self,
        dist2s: np.ndarray,
        idx: np.ndarray,
        y_neighbors: np.ndarray,
        params: ENNParams,
        observation_noise: bool,
    ) -> DrawInternals:
        """Weighted posterior over a fixed neighbor set, packed for draws.

        Gathers per-neighbor observation variances (when tracked) and wraps
        the resulting weighted statistics together with the neighbor indices
        into a ``DrawInternals``.
        """
        yvar_neighbors = None
        if self._train_yvar is not None:
            yvar_neighbors = self._train_yvar[idx]
        stats = self._compute_weighted_stats(
            dist2s,
            y_neighbors,
            yvar_neighbors=yvar_neighbors,
            params=params,
            observation_noise=observation_noise,
        )
        return DrawInternals(
            idx=idx,
            w_normalized=stats.w_normalized,
            l2=stats.l2,
            mu=stats.mu,
            se=stats.se,
        )

    def _compute_weighted_stats(
        self,
        dist2s: np.ndarray,
        y_neighbors: np.ndarray,
        *,
        yvar_neighbors: np.ndarray | None,
        params: ENNParams,
        observation_noise: bool,
        y_scale: np.ndarray | None = None,
    ) -> WeightedStats:
        """Core inverse-variance-weighted posterior statistics.

        Each neighbor's total variance is ``epistemic_variance_scale *
        dist²`` plus an aleatoric term (``aleatoric_variance_scale``, plus the
        neighbor's own yvar normalized by ``y_scale²`` when available).
        Weights are the floored inverse of that variance, normalized per
        query point.

        Returns:
            WeightedStats with normalized weights, their L2 norm (used to
            scale hashed noise in function draws), the weighted mean ``mu``,
            and the standard error ``se`` rescaled by ``y_scale``.
        """
        if y_scale is None:
            y_scale = self._y_scale
        dist2s_expanded = dist2s[..., np.newaxis]
        var_epi = params.epistemic_variance_scale * dist2s_expanded
        # var_ale stays a scalar unless per-neighbor variances are supplied.
        var_ale = params.aleatoric_variance_scale
        if yvar_neighbors is not None:
            var_ale = var_ale + yvar_neighbors / y_scale**2
        w = 1.0 / (self._EPS_VAR + var_epi + var_ale)
        norm = np.sum(w, axis=1, keepdims=True)
        w_normalized = w / norm
        l2 = np.sqrt(np.sum(w_normalized**2, axis=1))
        mu = np.sum(w_normalized * y_neighbors, axis=1)
        # Total weight is a precision; its inverse is the epistemic variance.
        epistemic_var = 1.0 / norm.squeeze(axis=1)
        if observation_noise:
            if np.isscalar(var_ale):
                aleatoric_var = np.full_like(epistemic_var, var_ale)
            else:
                aleatoric_var = np.sum(w_normalized * var_ale, axis=1)
        else:
            aleatoric_var = 0.0
        se = np.sqrt(np.maximum(epistemic_var + aleatoric_var, self._EPS_VAR)) * y_scale
        return WeightedStats(w_normalized=w_normalized, l2=l2, mu=mu, se=se)

    def conditional_posterior(
        self,
        x_whatif: np.ndarray,
        y_whatif: np.ndarray,
        x: np.ndarray,
        *,
        params: ENNParams,
        flags: PosteriorFlags | None = None,
    ) -> ENNNormal:
        """Posterior at ``x`` conditioned on hypothetical observations.

        The y scale is recomputed over the union of real and what-if targets
        so the conditional posterior is scaled consistently.
        """
        from .enn_conditional import compute_conditional_posterior
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        y_scale = _compute_conditional_y_scale(self, y_whatif)
        return compute_conditional_posterior(
            self, x_whatif, y_whatif, x, params=params, flags=flags, y_scale=y_scale
        )

    def _compute_posterior_internals(
        self,
        x: np.ndarray,
        params: ENNParams,
        flags: PosteriorFlags,
    ) -> DrawInternals:
        """Validate ``x``, run the neighbor search, and weight the result.

        Falls back to the prior (``_empty_posterior_internals``) when the
        model is empty or no neighbors are available.
        """
        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        batch_size = x.shape[0]
        if len(self) == 0:
            return self._empty_posterior_internals(batch_size)
        neighbor_data = self._get_neighbor_data(x, params, flags.exclude_nearest)
        if neighbor_data is None:
            return self._empty_posterior_internals(batch_size)
        return self._compute_weighted_posterior(
            neighbor_data.dist2s,
            neighbor_data.idx,
            neighbor_data.y_neighbors,
            params,
            flags.observation_noise,
        )

    def batch_posterior(
        self,
        x: np.ndarray,
        paramss: list[ENNParams],
        *,
        flags: PosteriorFlags | None = None,
    ) -> ENNNormal:
        """Posterior at ``x`` for several parameter sets at once.

        Returns an ``ENNNormal`` whose mu/se have shape
        ``(len(paramss), len(x), num_outputs)``.  When every parameter set
        requests the same k, a single kNN search is shared across all of them
        (the search depends only on k, not the variance scales); otherwise
        each parameter set runs the full per-params path.
        """
        from .enn_normal import ENNNormal
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        if not paramss:
            raise ValueError("paramss must be non-empty")
        batch_size, num_params = x.shape[0], len(paramss)
        mu_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
        se_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
        k_values = {p.k_num_neighbors for p in paramss}
        if len(k_values) == 1 and len(self) > 0:
            neighbor_data = self._get_neighbor_data(
                x, paramss[0], flags.exclude_nearest
            )
            if neighbor_data is None:
                # NOTE(review): this fast-path fallback returns se == 0,
                # whereas the generic path below would yield se == 1 via
                # _empty_posterior_internals — confirm which is intended.
                return ENNNormal(mu_all, se_all)
            for i, params in enumerate(paramss):
                internals = self._compute_weighted_posterior(
                    neighbor_data.dist2s,
                    neighbor_data.idx,
                    neighbor_data.y_neighbors,
                    params,
                    flags.observation_noise,
                )
                mu_all[i], se_all[i] = internals.mu, internals.se
        else:
            for i, params in enumerate(paramss):
                internals = self._compute_posterior_internals(x, params, flags)
                mu_all[i], se_all[i] = internals.mu, internals.se
        return ENNNormal(mu_all, se_all)

    def neighbors(self, x: np.ndarray, k: int, *, exclude_nearest: bool = False):
        """Return up to ``k`` stored ``(x, y)`` pairs nearest a single point.

        Accepts a 1-D point or a ``(1, d)`` row.  Returns a list of
        ``(x_i, y_i)`` array copies, presumably ordered by increasing
        distance as produced by the index search — TODO confirm.
        """
        x = np.asarray(x, dtype=float)
        if x.ndim == 1:
            x = x[np.newaxis, :]
        if x.ndim != 2 or x.shape[0] != 1 or x.shape[1] != self._num_dim:
            raise ValueError(
                f"x must be single point with {self._num_dim} dims, got {x.shape}"
            )
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        if len(self) == 0:
            return []
        if exclude_nearest and len(self) <= 1:
            raise ValueError(
                f"exclude_nearest=True requires at least 2 observations, got {len(self)}"
            )
        # Request one extra neighbor when the closest one will be dropped.
        search_k = int(min(k + 1 if exclude_nearest else k, len(self)))
        if search_k == 0:
            return []
        _, idx_full = self._enn_index.search(
            x, search_k=search_k, exclude_nearest=exclude_nearest
        )
        idx = idx_full[0, : min(k, len(idx_full[0]))]
        # Copies so callers cannot mutate the stored training data.
        return [(self._train_x[i].copy(), self._train_y[i].copy()) for i in idx]

    def posterior_function_draw(
        self,
        x: np.ndarray,
        params: ENNParams,
        *,
        function_seeds: np.ndarray | list[int],
        flags: PosteriorFlags | None = None,
    ) -> np.ndarray:
        """Deterministic posterior function draws at ``x``.

        One draw per seed; returns shape ``(len(function_seeds), len(x),
        num_outputs)``.  The same seed always yields the same function.
        """
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        internals = self._compute_posterior_internals(x, params, flags)
        return _draw_from_internals(self, internals, function_seeds=function_seeds)

    def conditional_posterior_function_draw(
        self,
        x_whatif: np.ndarray,
        y_whatif: np.ndarray,
        x: np.ndarray,
        *,
        params: ENNParams,
        function_seeds: np.ndarray | list[int],
        flags: PosteriorFlags | None = None,
    ) -> np.ndarray:
        """Function draws at ``x`` conditioned on what-if observations.

        An empty what-if set degenerates to the unconditional
        ``posterior_function_draw``.  Otherwise the conditional internals are
        computed with a y scale covering real + what-if targets, then
        repackaged as ``DrawInternals`` so ``_draw_from_internals`` can
        consume them.
        """
        from .enn_conditional import compute_conditional_posterior_draw_internals
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        x_whatif = np.asarray(x_whatif, dtype=float)
        if x_whatif.ndim != 2 or x_whatif.shape[1] != self._num_dim:
            raise ValueError(x_whatif.shape)
        if x_whatif.shape[0] == 0:
            return self.posterior_function_draw(
                x,
                params,
                function_seeds=function_seeds,
                flags=flags,
            )
        y_scale = _compute_conditional_y_scale(self, y_whatif)
        internals = compute_conditional_posterior_draw_internals(
            self, x_whatif, y_whatif, x, params=params, flags=flags, y_scale=y_scale
        )
        return _draw_from_internals(
            self,
            DrawInternals(
                idx=internals.idx,
                w_normalized=internals.w_normalized,
                l2=internals.l2,
                mu=internals.mu,
                se=internals.se,
            ),
            function_seeds=function_seeds,
        )
|