ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/__init__.py +25 -13
- enn/benchmarks/__init__.py +3 -0
- enn/benchmarks/ackley.py +5 -0
- enn/benchmarks/ackley_class.py +17 -0
- enn/benchmarks/ackley_core.py +12 -0
- enn/benchmarks/double_ackley.py +24 -0
- enn/enn/candidates.py +14 -0
- enn/enn/conditional_posterior_draw_internals.py +15 -0
- enn/enn/draw_internals.py +15 -0
- enn/enn/enn.py +16 -269
- enn/enn/enn_class.py +423 -0
- enn/enn/enn_conditional.py +325 -0
- enn/enn/enn_fit.py +69 -70
- enn/enn/enn_hash.py +79 -0
- enn/enn/enn_index.py +92 -0
- enn/enn/enn_like_protocol.py +35 -0
- enn/enn/enn_normal.py +0 -1
- enn/enn/enn_params.py +3 -22
- enn/enn/enn_params_class.py +24 -0
- enn/enn/enn_util.py +60 -46
- enn/enn/neighbor_data.py +14 -0
- enn/enn/neighbors.py +14 -0
- enn/enn/posterior_flags.py +8 -0
- enn/enn/weighted_stats.py +14 -0
- enn/turbo/components/__init__.py +41 -0
- enn/turbo/components/acquisition.py +13 -0
- enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
- enn/turbo/components/builder.py +22 -0
- enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
- enn/turbo/components/enn_surrogate.py +115 -0
- enn/turbo/components/gp_surrogate.py +144 -0
- enn/turbo/components/hnr_acq_optimizer.py +83 -0
- enn/turbo/components/incumbent_selector.py +11 -0
- enn/turbo/components/incumbent_selector_protocol.py +16 -0
- enn/turbo/components/no_incumbent_selector.py +21 -0
- enn/turbo/components/no_surrogate.py +49 -0
- enn/turbo/components/pareto_acq_optimizer.py +49 -0
- enn/turbo/components/posterior_result.py +12 -0
- enn/turbo/components/protocols.py +13 -0
- enn/turbo/components/random_acq_optimizer.py +21 -0
- enn/turbo/components/scalar_incumbent_selector.py +39 -0
- enn/turbo/components/surrogate_protocol.py +32 -0
- enn/turbo/components/surrogate_result.py +12 -0
- enn/turbo/components/surrogates.py +5 -0
- enn/turbo/components/thompson_acq_optimizer.py +49 -0
- enn/turbo/components/trust_region_protocol.py +24 -0
- enn/turbo/components/ucb_acq_optimizer.py +49 -0
- enn/turbo/config/__init__.py +87 -0
- enn/turbo/config/acq_type.py +8 -0
- enn/turbo/config/acquisition.py +26 -0
- enn/turbo/config/base.py +4 -0
- enn/turbo/config/candidate_gen_config.py +49 -0
- enn/turbo/config/candidate_rv.py +7 -0
- enn/turbo/config/draw_acquisition_config.py +14 -0
- enn/turbo/config/enn_index_driver.py +6 -0
- enn/turbo/config/enn_surrogate_config.py +42 -0
- enn/turbo/config/enums.py +7 -0
- enn/turbo/config/factory.py +118 -0
- enn/turbo/config/gp_surrogate_config.py +14 -0
- enn/turbo/config/hnr_optimizer_config.py +7 -0
- enn/turbo/config/init_config.py +17 -0
- enn/turbo/config/init_strategies/__init__.py +9 -0
- enn/turbo/config/init_strategies/hybrid_init.py +23 -0
- enn/turbo/config/init_strategies/init_strategy.py +19 -0
- enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
- enn/turbo/config/morbo_tr_config.py +82 -0
- enn/turbo/config/nds_optimizer_config.py +7 -0
- enn/turbo/config/no_surrogate_config.py +14 -0
- enn/turbo/config/no_tr_config.py +31 -0
- enn/turbo/config/optimizer_config.py +72 -0
- enn/turbo/config/pareto_acquisition_config.py +14 -0
- enn/turbo/config/raasp_driver.py +6 -0
- enn/turbo/config/raasp_optimizer_config.py +7 -0
- enn/turbo/config/random_acquisition_config.py +14 -0
- enn/turbo/config/rescalarize.py +7 -0
- enn/turbo/config/surrogate.py +12 -0
- enn/turbo/config/trust_region.py +34 -0
- enn/turbo/config/turbo_tr_config.py +71 -0
- enn/turbo/config/ucb_acquisition_config.py +14 -0
- enn/turbo/config/validation.py +45 -0
- enn/turbo/hypervolume.py +30 -0
- enn/turbo/impl_helpers.py +68 -0
- enn/turbo/morbo_trust_region.py +131 -70
- enn/turbo/no_trust_region.py +32 -39
- enn/turbo/optimizer.py +300 -0
- enn/turbo/optimizer_config.py +8 -0
- enn/turbo/proposal.py +36 -38
- enn/turbo/sampling.py +21 -0
- enn/turbo/strategies/__init__.py +9 -0
- enn/turbo/strategies/lhd_only_strategy.py +36 -0
- enn/turbo/strategies/optimization_strategy.py +19 -0
- enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
- enn/turbo/tr_helpers.py +202 -0
- enn/turbo/turbo_gp.py +0 -1
- enn/turbo/turbo_gp_base.py +0 -1
- enn/turbo/turbo_gp_fit.py +187 -0
- enn/turbo/turbo_gp_noisy.py +0 -1
- enn/turbo/turbo_optimizer_utils.py +98 -0
- enn/turbo/turbo_trust_region.py +126 -58
- enn/turbo/turbo_utils.py +98 -161
- enn/turbo/types/__init__.py +7 -0
- enn/turbo/types/appendable_array.py +85 -0
- enn/turbo/types/gp_data_prep.py +13 -0
- enn/turbo/types/gp_fit_result.py +11 -0
- enn/turbo/types/obs_lists.py +10 -0
- enn/turbo/types/prepare_ask_result.py +14 -0
- enn/turbo/types/tell_inputs.py +14 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
- ennbo-0.1.7.dist-info/RECORD +111 -0
- enn/enn/__init__.py +0 -4
- enn/turbo/__init__.py +0 -11
- enn/turbo/base_turbo_impl.py +0 -144
- enn/turbo/lhd_only_impl.py +0 -49
- enn/turbo/turbo_config.py +0 -72
- enn/turbo/turbo_enn_impl.py +0 -201
- enn/turbo/turbo_mode.py +0 -10
- enn/turbo/turbo_mode_impl.py +0 -76
- enn/turbo/turbo_one_impl.py +0 -302
- enn/turbo/turbo_optimizer.py +0 -525
- enn/turbo/turbo_zero_impl.py +0 -29
- ennbo-0.1.2.dist-info/RECORD +0 -29
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
- {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import numpy as np
|
|
3
|
+
from .candidates import Candidates
|
|
4
|
+
from .conditional_posterior_draw_internals import ConditionalPosteriorDrawInternals
|
|
5
|
+
from .enn_like_protocol import ENNLike
|
|
6
|
+
from .enn_params import ENNParams, PosteriorFlags
|
|
7
|
+
from .neighbors import Neighbors
|
|
8
|
+
|
|
9
|
+
_ENNLike = ENNLike
|
|
10
|
+
_Candidates = Candidates
|
|
11
|
+
_Neighbors = Neighbors
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _pairwise_sq_l2(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
|
15
|
+
a = np.asarray(a, dtype=float)
|
|
16
|
+
b = np.asarray(b, dtype=float)
|
|
17
|
+
aa = np.sum(a * a, axis=1, keepdims=True)
|
|
18
|
+
bb = np.sum(b * b, axis=1, keepdims=True).T
|
|
19
|
+
dist2 = aa + bb - 2.0 * (a @ b.T)
|
|
20
|
+
return np.maximum(dist2, 0.0)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _validate_x(enn: ENNLike, x: np.ndarray) -> np.ndarray:
|
|
24
|
+
x = np.asarray(x, dtype=float)
|
|
25
|
+
if x.ndim != 2 or x.shape[1] != enn._num_dim:
|
|
26
|
+
raise ValueError(x.shape)
|
|
27
|
+
return x
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _validate_whatif(
|
|
31
|
+
enn: ENNLike, x_whatif: np.ndarray, y_whatif: np.ndarray
|
|
32
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
33
|
+
x_whatif = np.asarray(x_whatif, dtype=float)
|
|
34
|
+
y_whatif = np.asarray(y_whatif, dtype=float)
|
|
35
|
+
if x_whatif.ndim != 2 or x_whatif.shape[1] != enn._num_dim:
|
|
36
|
+
raise ValueError(x_whatif.shape)
|
|
37
|
+
if y_whatif.ndim != 2 or y_whatif.shape[1] != enn._num_metrics:
|
|
38
|
+
raise ValueError(y_whatif.shape)
|
|
39
|
+
if x_whatif.shape[0] != y_whatif.shape[0]:
|
|
40
|
+
raise ValueError((x_whatif.shape, y_whatif.shape))
|
|
41
|
+
return x_whatif, y_whatif
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _scale_x_if_needed(enn: ENNLike, x: np.ndarray) -> np.ndarray:
|
|
45
|
+
return x / enn._x_scale if enn._scale_x else x
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _compute_total_n(enn: ENNLike, num_whatif: int, flags: PosteriorFlags) -> int:
|
|
49
|
+
total_n = len(enn) + int(num_whatif)
|
|
50
|
+
if flags.exclude_nearest and total_n <= 1:
|
|
51
|
+
raise ValueError(total_n)
|
|
52
|
+
return total_n
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _compute_search_k(params: ENNParams, flags: PosteriorFlags, total_n: int) -> int:
|
|
56
|
+
return int(
|
|
57
|
+
min(params.k_num_neighbors + (1 if flags.exclude_nearest else 0), total_n)
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _get_train_candidates(enn: ENNLike, x: np.ndarray, *, search_k: int) -> Candidates:
    """Look up training-set neighbor candidates for each row of x.

    Returns empty (batch, 0, ...) arrays when there is no training data or
    nothing was requested; otherwise queries the model's index and gathers
    the matching targets (and observation variances, when present).
    """
    batch_size = x.shape[0]
    num_train = len(enn)
    if num_train == 0 or search_k == 0:
        empty_yvar = None
        if enn._train_yvar is not None:
            empty_yvar = np.zeros((batch_size, 0, enn._num_metrics), dtype=float)
        return Candidates(
            dist2=np.zeros((batch_size, 0), dtype=float),
            ids=np.zeros((batch_size, 0), dtype=int),
            y=np.zeros((batch_size, 0, enn._num_metrics), dtype=float),
            yvar=empty_yvar,
        )
    effective_k = int(min(search_k, num_train))
    dist2, idx = enn._enn_index.search(
        x, search_k=effective_k, exclude_nearest=False
    )
    yvar = None if enn._train_yvar is None else enn._train_yvar[idx]
    return Candidates(dist2=dist2, ids=idx, y=enn._train_y[idx], yvar=yvar)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _get_whatif_candidates(
    enn: ENNLike,
    x: np.ndarray,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Distances from x to the what-if points, plus batched what-if targets.

    Distances are computed in the model's (optionally scaled) input space;
    y_whatif is broadcast to one copy per query row.
    """
    # Scale both sides consistently before measuring distances.
    if enn._scale_x:
        query = x / enn._x_scale
        whatif = x_whatif / enn._x_scale
    else:
        query, whatif = x, x_whatif
    dist2 = _pairwise_sq_l2(query, whatif)
    batch_size = x.shape[0]
    y_batched = np.broadcast_to(
        y_whatif[None, :, :], (batch_size,) + y_whatif.shape
    )
    return dist2, y_batched
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
_WhatifCandidateTuple = tuple[np.ndarray, np.ndarray, np.ndarray]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _merge_candidates(
    enn: ENNLike,
    *,
    train: Candidates,
    whatif: _WhatifCandidateTuple,
) -> Candidates:
    """Concatenate train and what-if candidates along the neighbor axis.

    What-if points get zero observation variance whenever the training data
    carries yvar, i.e. they are treated as noiseless observations.
    """
    whatif_dist2, whatif_ids, whatif_y = whatif
    dist2 = np.concatenate([train.dist2, whatif_dist2], axis=1)
    ids = np.concatenate([train.ids, whatif_ids], axis=1)
    y = np.concatenate([train.y, whatif_y], axis=1)
    if train.yvar is None:
        return Candidates(dist2=dist2, ids=ids, y=y, yvar=None)
    zero_var = np.zeros(
        (dist2.shape[0], whatif_dist2.shape[1], enn._num_metrics)
    )
    yvar = np.concatenate([train.yvar, zero_var], axis=1)
    return Candidates(dist2=dist2, ids=ids, y=y, yvar=yvar)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _select_sorted_candidates(dist2_all: np.ndarray, *, search_k: int) -> np.ndarray:
|
|
122
|
+
batch_size, num_candidates = dist2_all.shape
|
|
123
|
+
if search_k < num_candidates:
|
|
124
|
+
sel = np.argpartition(dist2_all, kth=search_k - 1, axis=1)[:, :search_k]
|
|
125
|
+
else:
|
|
126
|
+
sel = np.broadcast_to(np.arange(num_candidates), (batch_size, num_candidates))
|
|
127
|
+
sel_dist2 = np.take_along_axis(dist2_all, sel, axis=1)
|
|
128
|
+
sel_order = np.argsort(sel_dist2, axis=1)
|
|
129
|
+
return np.take_along_axis(sel, sel_order, axis=1)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _take_along_axis_3d(a: np.ndarray, idx_2d: np.ndarray) -> np.ndarray:
|
|
133
|
+
return np.take_along_axis(a, idx_2d[:, :, np.newaxis], axis=1)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _make_empty_normal(enn: ENNLike, batch_size: int):
    """Posterior for the degenerate no-neighbor case, wrapped as an ENNNormal."""
    # Local import avoids a circular dependency at module load time.
    from .enn_normal import ENNNormal

    empty = enn._empty_posterior_internals(batch_size)
    return ENNNormal(empty.mu, empty.se)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _build_candidates(
    enn: ENNLike,
    x: np.ndarray,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    *,
    search_k: int,
) -> Candidates:
    """Assemble the merged train + what-if candidate set for each query row.

    What-if points receive synthetic ids continuing after the training
    indices so downstream consumers can address all candidates uniformly.
    """
    train = _get_train_candidates(enn, x, search_k=search_k)
    dist2_whatif, y_whatif_batched = _get_whatif_candidates(enn, x, x_whatif, y_whatif)
    offset = int(len(enn))
    whatif_ids = np.broadcast_to(
        offset + np.arange(x_whatif.shape[0], dtype=int), dist2_whatif.shape
    )
    return _merge_candidates(
        enn,
        train=train,
        whatif=(dist2_whatif, whatif_ids, y_whatif_batched),
    )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _select_effective_neighbors(
    candidates: Candidates,
    *,
    search_k: int,
    k: int,
    exclude_nearest: bool,
) -> Neighbors | None:
    """Pick the final k nearest neighbors, optionally dropping the closest.

    Returns None when no neighbors survive selection (empty candidate set,
    or exclusion removed the only candidate).
    """
    sel = _select_sorted_candidates(candidates.dist2, search_k=search_k)
    if exclude_nearest:
        # Column 0 is the closest candidate; drop it.
        sel = sel[:, 1:]
    keep = int(min(k, sel.shape[1]))
    sel = sel[:, :keep]
    if sel.shape[1] == 0:
        return None
    dist2 = np.take_along_axis(candidates.dist2, sel, axis=1)
    ids = np.take_along_axis(candidates.ids, sel, axis=1)
    y = _take_along_axis_3d(candidates.y, sel)
    if candidates.yvar is None:
        yvar = None
    else:
        yvar = _take_along_axis_3d(candidates.yvar, sel)
    return Neighbors(dist2=dist2, ids=ids, y=y, yvar=yvar)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _compute_mu_se(
    enn: ENNLike,
    neighbors: Neighbors,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Weighted posterior mean and standard error over the given neighbors."""
    weighted = enn._compute_weighted_stats(
        neighbors.dist2,
        neighbors.y,
        yvar_neighbors=neighbors.yvar,
        params=params,
        observation_noise=flags.observation_noise,
        y_scale=y_scale,
    )
    return weighted.mu, weighted.se
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _compute_draw_internals(
    enn: ENNLike,
    neighbors: Neighbors,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
) -> ConditionalPosteriorDrawInternals:
    """Package weighted-stat outputs as conditional posterior draw internals."""
    weighted = enn._compute_weighted_stats(
        neighbors.dist2,
        neighbors.y,
        yvar_neighbors=neighbors.yvar,
        params=params,
        observation_noise=flags.observation_noise,
        y_scale=y_scale,
    )
    return ConditionalPosteriorDrawInternals(
        idx=neighbors.ids.astype(int, copy=False),
        w_normalized=weighted.w_normalized,
        l2=weighted.l2,
        mu=weighted.mu,
        se=weighted.se,
    )
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _conditional_neighbors_nonempty_whatif(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
) -> tuple[int, int, Neighbors | None]:
    """Select neighbors for the conditional posterior (what-if set non-empty).

    Returns (batch_size, search_k, neighbors); neighbors is None when the
    search is degenerate (nothing to request or nothing survives selection).
    """
    batch_size = x.shape[0]
    total_n = _compute_total_n(enn, x_whatif.shape[0], flags)
    search_k = _compute_search_k(params, flags, total_n)
    if search_k == 0:
        return batch_size, search_k, None
    candidates = _build_candidates(enn, x, x_whatif, y_whatif, search_k=search_k)
    neighbors = _select_effective_neighbors(
        candidates,
        search_k=search_k,
        k=params.k_num_neighbors,
        exclude_nearest=flags.exclude_nearest,
    )
    return batch_size, search_k, neighbors
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _compute_conditional_posterior_impl(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
):
    """Posterior at x conditioned on hypothetical (x_whatif, y_whatif) data.

    Falls back to the unconditional posterior when there are no what-if
    points, and to an empty posterior when no neighbors are available.
    """
    from .enn_normal import ENNNormal

    x = _validate_x(enn, x)
    x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
    if x_whatif.shape[0] == 0:
        return enn.posterior(x, params=params, flags=flags)
    batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
        enn, x_whatif, y_whatif, x, params=params, flags=flags
    )
    if neighbors is None or search_k == 0:
        return _make_empty_normal(enn, batch_size)
    mu, se = _compute_mu_se(
        enn, neighbors, params=params, flags=flags, y_scale=y_scale
    )
    return ENNNormal(mu, se)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def compute_conditional_posterior(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
):
    """Public entry point for the conditional posterior.

    Thin, stable wrapper over the private implementation so the internals
    can change without touching the public surface.
    """
    return _compute_conditional_posterior_impl(
        enn,
        x_whatif,
        y_whatif,
        x,
        params=params,
        flags=flags,
        y_scale=y_scale,
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def compute_conditional_posterior_draw_internals(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
) -> ConditionalPosteriorDrawInternals:
    """Draw internals for the conditional posterior.

    Unlike compute_conditional_posterior, an empty what-if set is an error
    here; the degenerate no-neighbor case is filled from the model's empty
    posterior internals.
    """
    x = _validate_x(enn, x)
    x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
    if x_whatif.shape[0] == 0:
        raise ValueError("x_whatif must be non-empty for conditional draw internals")
    batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
        enn, x_whatif, y_whatif, x, params=params, flags=flags
    )
    if neighbors is None or search_k == 0:
        empty = enn._empty_posterior_internals(batch_size)
        return ConditionalPosteriorDrawInternals(
            idx=empty.idx,
            w_normalized=empty.w_normalized,
            l2=empty.l2,
            mu=empty.mu,
            se=empty.se,
        )
    return _compute_draw_internals(
        enn, neighbors, params=params, flags=flags, y_scale=y_scale
    )
|
enn/enn/enn_fit.py
CHANGED
|
@@ -1,24 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
from typing import TYPE_CHECKING, Any
|
|
4
3
|
|
|
5
4
|
if TYPE_CHECKING:
|
|
6
5
|
import numpy as np
|
|
7
6
|
from numpy.random import Generator
|
|
8
|
-
|
|
9
|
-
from .enn import EpistemicNearestNeighbors
|
|
7
|
+
from .enn_class import EpistemicNearestNeighbors
|
|
10
8
|
from .enn_params import ENNParams
|
|
11
9
|
|
|
12
10
|
|
|
13
|
-
def
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
y: np.ndarray | Any,
|
|
17
|
-
*,
|
|
18
|
-
paramss: list[ENNParams] | list[Any],
|
|
19
|
-
P: int = 10,
|
|
20
|
-
rng: Generator | Any,
|
|
21
|
-
) -> list[float]:
|
|
11
|
+
def _validate_subsample_inputs(
|
|
12
|
+
x: np.ndarray | Any, y: np.ndarray | Any, P: int, paramss: list
|
|
13
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
22
14
|
import numpy as np
|
|
23
15
|
|
|
24
16
|
x_array = np.asarray(x, dtype=float)
|
|
@@ -35,64 +27,67 @@ def subsample_loglik(
|
|
|
35
27
|
raise ValueError(P)
|
|
36
28
|
if len(paramss) == 0:
|
|
37
29
|
raise ValueError("paramss must be non-empty")
|
|
30
|
+
return x_array, y_array
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _compute_single_loglik(
|
|
34
|
+
y_scaled: np.ndarray, mu_i: np.ndarray, se_i: np.ndarray
|
|
35
|
+
) -> float:
|
|
36
|
+
import numpy as np
|
|
37
|
+
|
|
38
|
+
if not np.isfinite(mu_i).all() or not np.isfinite(se_i).all():
|
|
39
|
+
return 0.0
|
|
40
|
+
if np.any(se_i <= 0.0):
|
|
41
|
+
return 0.0
|
|
42
|
+
var_scaled = se_i**2
|
|
43
|
+
loglik = -0.5 * np.sum(
|
|
44
|
+
np.log(2.0 * np.pi * var_scaled) + (y_scaled - mu_i) ** 2 / var_scaled
|
|
45
|
+
)
|
|
46
|
+
return float(loglik) if np.isfinite(loglik) else 0.0
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def subsample_loglik(
    model: EpistemicNearestNeighbors | Any,
    x: np.ndarray | Any,
    y: np.ndarray | Any,
    *,
    paramss: list[ENNParams] | list[Any],
    P: int = 10,
    rng: Generator | Any,
) -> list[float]:
    """Score each parameter set by Gaussian log-likelihood on a subsample.

    Up to P points are drawn (all of them when P >= n), evaluated under the
    model's leave-one-out posterior with observation noise, scaled by the
    per-output std of the full data set, and scored per parameter set.
    Degenerate situations (no data, model with <= 1 point, non-finite
    targets) return 0.0 for every parameter set.
    """
    import numpy as np

    x_array, y_array = _validate_subsample_inputs(x, y, P, paramss)
    num_points = x_array.shape[0]
    neutral = [0.0] * len(paramss)
    if num_points == 0 or len(model) <= 1:
        return neutral
    P_actual = min(P, num_points)
    if P_actual == num_points:
        indices = np.arange(num_points, dtype=int)
    else:
        indices = rng.permutation(num_points)[:P_actual]
    x_sel = x_array[indices]
    y_sel = y_array[indices]
    if not np.isfinite(y_sel).all():
        return neutral
    from .enn_params import PosteriorFlags

    post = model.batch_posterior(
        x_sel,
        paramss,
        flags=PosteriorFlags(exclude_nearest=True, observation_noise=True),
    )
    num_params = len(paramss)
    expected_shape = (num_params, P_actual, y_sel.shape[1])
    if post.mu.shape != expected_shape or post.se.shape != expected_shape:
        raise ValueError((post.mu.shape, post.se.shape, expected_shape))
    # Scale by the per-output std of the full data set so outputs with very
    # different magnitudes contribute comparably; guard against zero/NaN std.
    y_std = np.std(y_array, axis=0, keepdims=True).astype(float)
    y_std = np.where(np.isfinite(y_std) & (y_std > 0.0), y_std, 1.0)
    y_scaled = y_sel / y_std
    mu_scaled = post.mu / y_std
    se_scaled = post.se / y_std
    return [
        _compute_single_loglik(y_scaled, mu_scaled[i], se_scaled[i])
        for i in range(num_params)
    ]
|
|
96
91
|
|
|
97
92
|
|
|
98
93
|
def enn_fit(
|
|
@@ -118,22 +113,26 @@ def enn_fit(
|
|
|
118
113
|
ale_homoscedastic_values = 10**ale_homoscedastic_log_values
|
|
119
114
|
paramss = [
|
|
120
115
|
ENNParams(
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
116
|
+
k_num_neighbors=k,
|
|
117
|
+
epistemic_variance_scale=float(epi_val),
|
|
118
|
+
aleatoric_variance_scale=float(ale_val),
|
|
124
119
|
)
|
|
125
120
|
for epi_val, ale_val in zip(epi_var_scale_values, ale_homoscedastic_values)
|
|
126
121
|
]
|
|
127
122
|
if params_warm_start is not None:
|
|
128
123
|
paramss.append(
|
|
129
124
|
ENNParams(
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
125
|
+
k_num_neighbors=k,
|
|
126
|
+
epistemic_variance_scale=params_warm_start.epistemic_variance_scale,
|
|
127
|
+
aleatoric_variance_scale=params_warm_start.aleatoric_variance_scale,
|
|
133
128
|
)
|
|
134
129
|
)
|
|
135
130
|
if len(paramss) == 0:
|
|
136
|
-
return ENNParams(
|
|
131
|
+
return ENNParams(
|
|
132
|
+
k_num_neighbors=k,
|
|
133
|
+
epistemic_variance_scale=1.0,
|
|
134
|
+
aleatoric_variance_scale=0.0,
|
|
135
|
+
)
|
|
137
136
|
import numpy as np
|
|
138
137
|
|
|
139
138
|
logliks = subsample_loglik(
|
enn/enn/enn_hash.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def normal_hash_batch_multi_seed(
    function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
) -> np.ndarray:
    """Deterministic standard-normal draws keyed by (seed, index, metric).

    Each (function seed, unique data index, metric) triple is combined into
    one integer seed, fed to a Philox generator for a single uniform draw,
    and mapped to a normal value via the inverse normal CDF. Equal data
    indices therefore always receive equal values for a given seed.
    """
    import numpy as np
    from scipy.special import ndtri

    num_seeds = len(function_seeds)
    unique_indices, inverse = np.unique(data_indices, return_inverse=True)
    num_unique = len(unique_indices)
    grids = np.meshgrid(
        function_seeds.astype(np.uint64),
        unique_indices.astype(np.uint64),
        np.arange(num_metrics, dtype=np.uint64),
        indexing="ij",
    )
    seed_flat, idx_flat, metric_flat = (g.ravel() for g in grids)
    prime = np.uint64(1_000_003)
    combined_seeds = (seed_flat * prime + idx_flat) * prime + metric_flat
    uniform_vals = np.empty(len(combined_seeds), dtype=float)
    # One generator per combined seed keeps the draw fully deterministic.
    for pos, seed in enumerate(combined_seeds):
        gen = np.random.Generator(np.random.Philox(int(seed)))
        uniform_vals[pos] = gen.random()
    uniform_vals = np.clip(uniform_vals, 1e-10, 1.0 - 1e-10)
    normal_vals = ndtri(uniform_vals).reshape(num_seeds, num_unique, num_metrics)
    return normal_vals[:, inverse.ravel(), :].reshape(
        num_seeds, *data_indices.shape, num_metrics
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def normal_hash_batch_multi_seed_fast(
    function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
) -> np.ndarray:
    """Vectorized deterministic normal draws keyed by (seed, index, metric).

    Hashes each (seed, unique index, metric) combination with the
    splitmix64 finalizer into two 53-bit uniforms, then applies the
    Box-Muller transform. Equal data indices always receive equal values
    for a given seed.
    """
    import numpy as np

    function_seeds = np.asarray(function_seeds, dtype=np.int64)
    data_indices = np.asarray(data_indices)
    if num_metrics <= 0:
        raise ValueError(num_metrics)
    num_seeds = len(function_seeds)
    unique_indices, inverse = np.unique(data_indices, return_inverse=True)

    def _mix(v: np.ndarray) -> np.ndarray:
        # splitmix64 finalizer; uint64 wraparound is intentional.
        with np.errstate(over="ignore"):
            v = v + np.uint64(0x9E3779B97F4A7C15)
            v = (v ^ (v >> np.uint64(30))) * np.uint64(0xBF58476D1CE4E5B9)
            v = (v ^ (v >> np.uint64(27))) * np.uint64(0x94D049BB133111EB)
            return v ^ (v >> np.uint64(31))

    seeds_u64 = function_seeds.astype(np.uint64, copy=False)
    unique_u64 = unique_indices.astype(np.uint64, copy=False)
    metric_u64 = np.arange(num_metrics, dtype=np.uint64)
    normal_vals = np.empty((num_seeds, unique_indices.size, num_metrics), dtype=float)
    prime = np.uint64(1_000_003)
    # 2**-53: maps the top 53 bits of a hash to a uniform in [0, 1).
    inv_2p53 = 1.0 / 9007199254740992.0
    for seed_pos, seed in enumerate(seeds_u64):
        with np.errstate(over="ignore"):
            base = (seed * prime + unique_u64) * prime
            combined = base[:, None] + metric_u64[None, :]
            hash_a = _mix(combined)
            hash_b = _mix(combined ^ np.uint64(0xD2B74407B1CE6E93))
        u1 = (hash_a >> np.uint64(11)).astype(np.float64) * inv_2p53
        u2 = (hash_b >> np.uint64(11)).astype(np.float64) * inv_2p53
        u1 = np.clip(u1, 1e-12, 1.0 - 1e-12)
        normal_vals[seed_pos, :, :] = np.sqrt(-2.0 * np.log(u1)) * np.cos(
            2.0 * np.pi * u2
        )
    return normal_vals[:, inverse.ravel(), :].reshape(
        num_seeds, *data_indices.shape, num_metrics
    )
|
enn/enn/enn_index.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ENNIndex:
    """Nearest-neighbor index over (optionally scaled) training inputs.

    Thin wrapper around a FAISS index (flat L2 or HNSW) that applies the
    model's per-dimension scaling before adding or querying points. FAISS
    is imported lazily so the package can be imported without it.
    """

    def __init__(
        self,
        train_x_scaled: np.ndarray,
        num_dim: int,
        x_scale: np.ndarray,
        scale_x: bool,
        driver: Any = None,
    ) -> None:
        from enn.turbo.config.enums import ENNIndexDriver

        if driver is None:
            driver = ENNIndexDriver.FLAT
        self._train_x_scaled = train_x_scaled
        self._num_dim = num_dim
        self._x_scale = x_scale
        self._scale_x = scale_x
        self._driver = driver
        self._index: Any | None = None
        self._build_index()

    def _new_index(self) -> Any:
        # Single place mapping the driver enum to a FAISS index instance;
        # previously this dispatch was duplicated in _build_index and add.
        import faiss

        from enn.turbo.config.enums import ENNIndexDriver

        if self._driver == ENNIndexDriver.FLAT:
            return faiss.IndexFlatL2(self._num_dim)
        if self._driver == ENNIndexDriver.HNSW:
            # TODO: Make M configurable
            return faiss.IndexHNSWFlat(self._num_dim, 32)
        raise ValueError(f"Unknown driver: {self._driver}")

    def _build_index(self) -> None:
        """Build the index from the stored training data; no-op when empty."""
        import numpy as np

        if len(self._train_x_scaled) == 0:
            return
        x_f32 = self._train_x_scaled.astype(np.float32, copy=False)
        index = self._new_index()
        index.add(x_f32)
        self._index = index

    def add(self, x: np.ndarray) -> None:
        """Validate, scale, and append new points, creating the index lazily."""
        import numpy as np

        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        x_scaled = x / self._x_scale if self._scale_x else x
        x_f32 = x_scaled.astype(np.float32, copy=False)
        if self._index is None:
            self._index = self._new_index()
        self._index.add(x_f32)

    def search(
        self,
        x: np.ndarray,
        *,
        search_k: int,
        exclude_nearest: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return (squared distances, indices) of the search_k nearest points.

        When exclude_nearest is set, the closest hit of each query row is
        dropped (one fewer column is returned).

        Raises ValueError for a non-positive search_k or a badly shaped x,
        and RuntimeError when the index has never been populated.
        """
        import numpy as np

        search_k = int(search_k)
        if search_k <= 0:
            raise ValueError(search_k)
        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        if self._index is None:
            raise RuntimeError("index is not initialized")
        x_scaled = x / self._x_scale if self._scale_x else x
        x_f32 = x_scaled.astype(np.float32, copy=False)
        dist2s_full, idx_full = self._index.search(x_f32, search_k)
        dist2s_full = dist2s_full.astype(float)
        idx_full = idx_full.astype(int)
        if exclude_nearest:
            dist2s_full = dist2s_full[:, 1:]
            idx_full = idx_full[:, 1:]
        return dist2s_full, idx_full
|