ennbo 0.1.0__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enn/__init__.py +25 -13
- enn/benchmarks/__init__.py +3 -0
- enn/benchmarks/ackley.py +5 -0
- enn/benchmarks/ackley_class.py +17 -0
- enn/benchmarks/ackley_core.py +12 -0
- enn/benchmarks/double_ackley.py +24 -0
- enn/enn/candidates.py +14 -0
- enn/enn/conditional_posterior_draw_internals.py +15 -0
- enn/enn/draw_internals.py +15 -0
- enn/enn/enn.py +16 -229
- enn/enn/enn_class.py +423 -0
- enn/enn/enn_conditional.py +325 -0
- enn/enn/enn_fit.py +77 -76
- enn/enn/enn_hash.py +79 -0
- enn/enn/enn_index.py +92 -0
- enn/enn/enn_like_protocol.py +35 -0
- enn/enn/enn_normal.py +3 -3
- enn/enn/enn_params.py +3 -9
- enn/enn/enn_params_class.py +24 -0
- enn/enn/enn_util.py +79 -37
- enn/enn/neighbor_data.py +14 -0
- enn/enn/neighbors.py +14 -0
- enn/enn/posterior_flags.py +8 -0
- enn/enn/weighted_stats.py +14 -0
- enn/turbo/components/__init__.py +41 -0
- enn/turbo/components/acquisition.py +13 -0
- enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
- enn/turbo/components/builder.py +22 -0
- enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
- enn/turbo/components/enn_surrogate.py +115 -0
- enn/turbo/components/gp_surrogate.py +144 -0
- enn/turbo/components/hnr_acq_optimizer.py +83 -0
- enn/turbo/components/incumbent_selector.py +11 -0
- enn/turbo/components/incumbent_selector_protocol.py +16 -0
- enn/turbo/components/no_incumbent_selector.py +21 -0
- enn/turbo/components/no_surrogate.py +49 -0
- enn/turbo/components/pareto_acq_optimizer.py +49 -0
- enn/turbo/components/posterior_result.py +12 -0
- enn/turbo/components/protocols.py +13 -0
- enn/turbo/components/random_acq_optimizer.py +21 -0
- enn/turbo/components/scalar_incumbent_selector.py +39 -0
- enn/turbo/components/surrogate_protocol.py +32 -0
- enn/turbo/components/surrogate_result.py +12 -0
- enn/turbo/components/surrogates.py +5 -0
- enn/turbo/components/thompson_acq_optimizer.py +49 -0
- enn/turbo/components/trust_region_protocol.py +24 -0
- enn/turbo/components/ucb_acq_optimizer.py +49 -0
- enn/turbo/config/__init__.py +87 -0
- enn/turbo/config/acq_type.py +8 -0
- enn/turbo/config/acquisition.py +26 -0
- enn/turbo/config/base.py +4 -0
- enn/turbo/config/candidate_gen_config.py +49 -0
- enn/turbo/config/candidate_rv.py +7 -0
- enn/turbo/config/draw_acquisition_config.py +14 -0
- enn/turbo/config/enn_index_driver.py +6 -0
- enn/turbo/config/enn_surrogate_config.py +42 -0
- enn/turbo/config/enums.py +7 -0
- enn/turbo/config/factory.py +118 -0
- enn/turbo/config/gp_surrogate_config.py +14 -0
- enn/turbo/config/hnr_optimizer_config.py +7 -0
- enn/turbo/config/init_config.py +17 -0
- enn/turbo/config/init_strategies/__init__.py +9 -0
- enn/turbo/config/init_strategies/hybrid_init.py +23 -0
- enn/turbo/config/init_strategies/init_strategy.py +19 -0
- enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
- enn/turbo/config/morbo_tr_config.py +82 -0
- enn/turbo/config/nds_optimizer_config.py +7 -0
- enn/turbo/config/no_surrogate_config.py +14 -0
- enn/turbo/config/no_tr_config.py +31 -0
- enn/turbo/config/optimizer_config.py +72 -0
- enn/turbo/config/pareto_acquisition_config.py +14 -0
- enn/turbo/config/raasp_driver.py +6 -0
- enn/turbo/config/raasp_optimizer_config.py +7 -0
- enn/turbo/config/random_acquisition_config.py +14 -0
- enn/turbo/config/rescalarize.py +7 -0
- enn/turbo/config/surrogate.py +12 -0
- enn/turbo/config/trust_region.py +34 -0
- enn/turbo/config/turbo_tr_config.py +71 -0
- enn/turbo/config/ucb_acquisition_config.py +14 -0
- enn/turbo/config/validation.py +45 -0
- enn/turbo/hypervolume.py +30 -0
- enn/turbo/impl_helpers.py +68 -0
- enn/turbo/morbo_trust_region.py +250 -0
- enn/turbo/no_trust_region.py +58 -0
- enn/turbo/optimizer.py +300 -0
- enn/turbo/optimizer_config.py +8 -0
- enn/turbo/proposal.py +46 -39
- enn/turbo/sampling.py +21 -0
- enn/turbo/strategies/__init__.py +9 -0
- enn/turbo/strategies/lhd_only_strategy.py +36 -0
- enn/turbo/strategies/optimization_strategy.py +19 -0
- enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
- enn/turbo/tr_helpers.py +202 -0
- enn/turbo/turbo_gp.py +9 -2
- enn/turbo/turbo_gp_base.py +0 -1
- enn/turbo/turbo_gp_fit.py +187 -0
- enn/turbo/turbo_gp_noisy.py +0 -1
- enn/turbo/turbo_optimizer_utils.py +98 -0
- enn/turbo/turbo_trust_region.py +129 -63
- enn/turbo/turbo_utils.py +144 -117
- enn/turbo/types/__init__.py +7 -0
- enn/turbo/types/appendable_array.py +85 -0
- enn/turbo/types/gp_data_prep.py +13 -0
- enn/turbo/types/gp_fit_result.py +11 -0
- enn/turbo/types/obs_lists.py +10 -0
- enn/turbo/types/prepare_ask_result.py +14 -0
- enn/turbo/types/tell_inputs.py +14 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/METADATA +22 -14
- ennbo-0.1.7.dist-info/RECORD +111 -0
- enn/enn/__init__.py +0 -4
- enn/turbo/__init__.py +0 -11
- enn/turbo/base_turbo_impl.py +0 -98
- enn/turbo/lhd_only_impl.py +0 -42
- enn/turbo/turbo_config.py +0 -28
- enn/turbo/turbo_enn_impl.py +0 -176
- enn/turbo/turbo_mode.py +0 -10
- enn/turbo/turbo_mode_impl.py +0 -67
- enn/turbo/turbo_one_impl.py +0 -163
- enn/turbo/turbo_optimizer.py +0 -337
- enn/turbo/turbo_zero_impl.py +0 -24
- ennbo-0.1.0.dist-info/RECORD +0 -27
- {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
- {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import numpy as np
|
|
3
|
+
from .candidates import Candidates
|
|
4
|
+
from .conditional_posterior_draw_internals import ConditionalPosteriorDrawInternals
|
|
5
|
+
from .enn_like_protocol import ENNLike
|
|
6
|
+
from .enn_params import ENNParams, PosteriorFlags
|
|
7
|
+
from .neighbors import Neighbors
|
|
8
|
+
|
|
9
|
+
_ENNLike = ENNLike
|
|
10
|
+
_Candidates = Candidates
|
|
11
|
+
_Neighbors = Neighbors
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _pairwise_sq_l2(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
|
15
|
+
a = np.asarray(a, dtype=float)
|
|
16
|
+
b = np.asarray(b, dtype=float)
|
|
17
|
+
aa = np.sum(a * a, axis=1, keepdims=True)
|
|
18
|
+
bb = np.sum(b * b, axis=1, keepdims=True).T
|
|
19
|
+
dist2 = aa + bb - 2.0 * (a @ b.T)
|
|
20
|
+
return np.maximum(dist2, 0.0)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _validate_x(enn: ENNLike, x: np.ndarray) -> np.ndarray:
|
|
24
|
+
x = np.asarray(x, dtype=float)
|
|
25
|
+
if x.ndim != 2 or x.shape[1] != enn._num_dim:
|
|
26
|
+
raise ValueError(x.shape)
|
|
27
|
+
return x
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _validate_whatif(
|
|
31
|
+
enn: ENNLike, x_whatif: np.ndarray, y_whatif: np.ndarray
|
|
32
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
33
|
+
x_whatif = np.asarray(x_whatif, dtype=float)
|
|
34
|
+
y_whatif = np.asarray(y_whatif, dtype=float)
|
|
35
|
+
if x_whatif.ndim != 2 or x_whatif.shape[1] != enn._num_dim:
|
|
36
|
+
raise ValueError(x_whatif.shape)
|
|
37
|
+
if y_whatif.ndim != 2 or y_whatif.shape[1] != enn._num_metrics:
|
|
38
|
+
raise ValueError(y_whatif.shape)
|
|
39
|
+
if x_whatif.shape[0] != y_whatif.shape[0]:
|
|
40
|
+
raise ValueError((x_whatif.shape, y_whatif.shape))
|
|
41
|
+
return x_whatif, y_whatif
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _scale_x_if_needed(enn: ENNLike, x: np.ndarray) -> np.ndarray:
|
|
45
|
+
return x / enn._x_scale if enn._scale_x else x
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _compute_total_n(enn: ENNLike, num_whatif: int, flags: PosteriorFlags) -> int:
|
|
49
|
+
total_n = len(enn) + int(num_whatif)
|
|
50
|
+
if flags.exclude_nearest and total_n <= 1:
|
|
51
|
+
raise ValueError(total_n)
|
|
52
|
+
return total_n
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _compute_search_k(params: ENNParams, flags: PosteriorFlags, total_n: int) -> int:
|
|
56
|
+
return int(
|
|
57
|
+
min(params.k_num_neighbors + (1 if flags.exclude_nearest else 0), total_n)
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _get_train_candidates(enn: ENNLike, x: np.ndarray, *, search_k: int) -> Candidates:
    """Nearest training-set candidates for each query row in ``x``.

    With no training data (or ``search_k == 0``) returns all-empty arrays so
    downstream concatenation with what-if candidates still works.
    """
    n_query = x.shape[0]
    n_train = len(enn)
    if n_train == 0 or search_k == 0:
        empty_yvar = None
        if enn._train_yvar is not None:
            empty_yvar = np.zeros((n_query, 0, enn._num_metrics), dtype=float)
        return Candidates(
            dist2=np.zeros((n_query, 0), dtype=float),
            ids=np.zeros((n_query, 0), dtype=int),
            y=np.zeros((n_query, 0, enn._num_metrics), dtype=float),
            yvar=empty_yvar,
        )
    k_eff = int(min(search_k, n_train))
    dist2, idx = enn._enn_index.search(x, search_k=k_eff, exclude_nearest=False)
    yvar = None if enn._train_yvar is None else enn._train_yvar[idx]
    return Candidates(dist2=dist2, ids=idx, y=enn._train_y[idx], yvar=yvar)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _get_whatif_candidates(
    enn: ENNLike,
    x: np.ndarray,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Distances from queries to what-if points, plus their batched targets.

    Distances are computed in the (optionally scaled) input space; the what-if
    targets are broadcast to one copy per query row without materializing them.
    """
    dist2 = _pairwise_sq_l2(
        _scale_x_if_needed(enn, x), _scale_x_if_needed(enn, x_whatif)
    )
    n_query = x.shape[0]
    tiled_y = np.broadcast_to(y_whatif[None, :, :], (n_query,) + y_whatif.shape)
    return dist2, tiled_y
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
_WhatifCandidateTuple = tuple[np.ndarray, np.ndarray, np.ndarray]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _merge_candidates(
    enn: ENNLike,
    *,
    train: Candidates,
    whatif: _WhatifCandidateTuple,
) -> Candidates:
    """Concatenate training and what-if candidates along the neighbor axis.

    What-if observations are treated as noiseless: when the training pool
    carries observation variances, the what-if columns get zero variance.
    """
    w_dist2, w_ids, w_y = whatif
    dist2 = np.concatenate([train.dist2, w_dist2], axis=1)
    ids = np.concatenate([train.ids, w_ids], axis=1)
    y = np.concatenate([train.y, w_y], axis=1)
    if train.yvar is None:
        return Candidates(dist2=dist2, ids=ids, y=y, yvar=None)
    w_yvar = np.zeros((dist2.shape[0], w_dist2.shape[1], enn._num_metrics))
    yvar = np.concatenate([train.yvar, w_yvar], axis=1)
    return Candidates(dist2=dist2, ids=ids, y=y, yvar=yvar)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _select_sorted_candidates(dist2_all: np.ndarray, *, search_k: int) -> np.ndarray:
|
|
122
|
+
batch_size, num_candidates = dist2_all.shape
|
|
123
|
+
if search_k < num_candidates:
|
|
124
|
+
sel = np.argpartition(dist2_all, kth=search_k - 1, axis=1)[:, :search_k]
|
|
125
|
+
else:
|
|
126
|
+
sel = np.broadcast_to(np.arange(num_candidates), (batch_size, num_candidates))
|
|
127
|
+
sel_dist2 = np.take_along_axis(dist2_all, sel, axis=1)
|
|
128
|
+
sel_order = np.argsort(sel_dist2, axis=1)
|
|
129
|
+
return np.take_along_axis(sel, sel_order, axis=1)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _take_along_axis_3d(a: np.ndarray, idx_2d: np.ndarray) -> np.ndarray:
|
|
133
|
+
return np.take_along_axis(a, idx_2d[:, :, np.newaxis], axis=1)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _make_empty_normal(enn: ENNLike, batch_size: int):
    """Posterior for the no-neighbor case: the model's empty-prior normal."""
    # Local import avoids a circular dependency at module load time.
    from .enn_normal import ENNNormal

    empty = enn._empty_posterior_internals(batch_size)
    return ENNNormal(empty.mu, empty.se)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _build_candidates(
    enn: ENNLike,
    x: np.ndarray,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    *,
    search_k: int,
) -> Candidates:
    """Assemble the merged training + what-if candidate pool for ``x``.

    What-if points receive virtual ids continuing after the training ids, so
    draw internals can distinguish them from real observations.
    """
    train = _get_train_candidates(enn, x, search_k=search_k)
    w_dist2, w_y = _get_whatif_candidates(enn, x, x_whatif, y_whatif)
    offset = int(len(enn))
    w_ids = np.broadcast_to(
        offset + np.arange(x_whatif.shape[0], dtype=int), w_dist2.shape
    )
    return _merge_candidates(enn, train=train, whatif=(w_dist2, w_ids, w_y))
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _select_effective_neighbors(
    candidates: Candidates,
    *,
    search_k: int,
    k: int,
    exclude_nearest: bool,
) -> Neighbors | None:
    """Keep up to ``k`` nearest candidates, optionally dropping the closest one.

    Returns None when nothing survives the slicing (empty pool, or a single
    candidate that was excluded as the query's own nearest point).
    """
    order = _select_sorted_candidates(candidates.dist2, search_k=search_k)
    if exclude_nearest:
        order = order[:, 1:]
    keep = int(min(k, order.shape[1]))
    order = order[:, :keep]
    if order.shape[1] == 0:
        return None
    yvar = None
    if candidates.yvar is not None:
        yvar = _take_along_axis_3d(candidates.yvar, order)
    return Neighbors(
        dist2=np.take_along_axis(candidates.dist2, order, axis=1),
        ids=np.take_along_axis(candidates.ids, order, axis=1),
        y=_take_along_axis_3d(candidates.y, order),
        yvar=yvar,
    )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _compute_mu_se(
    enn: ENNLike,
    neighbors: Neighbors,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Posterior mean and standard error from the weighted neighbor statistics."""
    weighted = enn._compute_weighted_stats(
        neighbors.dist2,
        neighbors.y,
        yvar_neighbors=neighbors.yvar,
        params=params,
        observation_noise=flags.observation_noise,
        y_scale=y_scale,
    )
    return weighted.mu, weighted.se
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _compute_draw_internals(
    enn: ENNLike,
    neighbors: Neighbors,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
) -> ConditionalPosteriorDrawInternals:
    """Package the weighted-stats outputs needed to draw conditional samples."""
    weighted = enn._compute_weighted_stats(
        neighbors.dist2,
        neighbors.y,
        yvar_neighbors=neighbors.yvar,
        params=params,
        observation_noise=flags.observation_noise,
        y_scale=y_scale,
    )
    return ConditionalPosteriorDrawInternals(
        idx=neighbors.ids.astype(int, copy=False),
        w_normalized=weighted.w_normalized,
        l2=weighted.l2,
        mu=weighted.mu,
        se=weighted.se,
    )
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _conditional_neighbors_nonempty_whatif(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
) -> tuple[int, int, Neighbors | None]:
    """Select neighbors for the conditional posterior (what-if set non-empty).

    Returns (batch size, effective search_k, neighbors-or-None); neighbors is
    None when no candidates can be fetched at all.
    """
    n_query = x.shape[0]
    total_n = _compute_total_n(enn, x_whatif.shape[0], flags)
    search_k = _compute_search_k(params, flags, total_n)
    if search_k == 0:
        return n_query, search_k, None
    pool = _build_candidates(enn, x, x_whatif, y_whatif, search_k=search_k)
    chosen = _select_effective_neighbors(
        pool,
        search_k=search_k,
        k=params.k_num_neighbors,
        exclude_nearest=flags.exclude_nearest,
    )
    return n_query, search_k, chosen
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _compute_conditional_posterior_impl(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
):
    """Conditional posterior at ``x`` given hypothetical (x, y) observations."""
    # Local import avoids a circular dependency at module load time.
    from .enn_normal import ENNNormal

    x = _validate_x(enn, x)
    x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
    if len(x_whatif) == 0:
        # Nothing to condition on: fall back to the unconditional posterior.
        return enn.posterior(x, params=params, flags=flags)
    batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
        enn, x_whatif, y_whatif, x, params=params, flags=flags
    )
    if neighbors is None or search_k == 0:
        return _make_empty_normal(enn, batch_size)
    mu, se = _compute_mu_se(
        enn, neighbors, params=params, flags=flags, y_scale=y_scale
    )
    return ENNNormal(mu, se)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def compute_conditional_posterior(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
):
    """Public entry point for the conditional posterior; delegates to the impl."""
    return _compute_conditional_posterior_impl(
        enn,
        x_whatif,
        y_whatif,
        x,
        params=params,
        flags=flags,
        y_scale=y_scale,
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def compute_conditional_posterior_draw_internals(
    enn: ENNLike,
    x_whatif: np.ndarray,
    y_whatif: np.ndarray,
    x: np.ndarray,
    *,
    params: ENNParams,
    flags: PosteriorFlags,
    y_scale: np.ndarray,
) -> ConditionalPosteriorDrawInternals:
    """Internals (neighbor ids, weights, mu/se) for conditional posterior draws.

    Unlike ``compute_conditional_posterior``, this requires a non-empty
    ``x_whatif`` (raises ValueError otherwise).  When no neighbors are
    available, it returns the model's empty-posterior internals instead.
    """
    x = _validate_x(enn, x)
    x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
    if x_whatif.shape[0] == 0:
        raise ValueError("x_whatif must be non-empty for conditional draw internals")
    batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
        enn, x_whatif, y_whatif, x, params=params, flags=flags
    )
    if search_k == 0 or neighbors is None:
        # No usable neighbors: repackage the model's empty posterior internals.
        empty_internals = enn._empty_posterior_internals(batch_size)
        return ConditionalPosteriorDrawInternals(
            idx=empty_internals.idx,
            w_normalized=empty_internals.w_normalized,
            l2=empty_internals.l2,
            mu=empty_internals.mu,
            se=empty_internals.se,
        )
    return _compute_draw_internals(
        enn, neighbors, params=params, flags=flags, y_scale=y_scale
    )
|
enn/enn/enn_fit.py
CHANGED
|
@@ -1,15 +1,49 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
from typing import TYPE_CHECKING, Any
|
|
4
3
|
|
|
5
4
|
if TYPE_CHECKING:
|
|
6
5
|
import numpy as np
|
|
7
6
|
from numpy.random import Generator
|
|
8
|
-
|
|
9
|
-
from .enn import EpistemicNearestNeighbors
|
|
7
|
+
from .enn_class import EpistemicNearestNeighbors
|
|
10
8
|
from .enn_params import ENNParams
|
|
11
9
|
|
|
12
|
-
|
|
10
|
+
|
|
11
|
+
def _validate_subsample_inputs(
|
|
12
|
+
x: np.ndarray | Any, y: np.ndarray | Any, P: int, paramss: list
|
|
13
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
x_array = np.asarray(x, dtype=float)
|
|
17
|
+
if x_array.ndim != 2:
|
|
18
|
+
raise ValueError(x_array.shape)
|
|
19
|
+
y_array = np.asarray(y, dtype=float)
|
|
20
|
+
if y_array.ndim == 1:
|
|
21
|
+
y_array = y_array.reshape(-1, 1)
|
|
22
|
+
if y_array.ndim != 2:
|
|
23
|
+
raise ValueError(y_array.shape)
|
|
24
|
+
if x_array.shape[0] != y_array.shape[0]:
|
|
25
|
+
raise ValueError((x_array.shape, y_array.shape))
|
|
26
|
+
if P <= 0:
|
|
27
|
+
raise ValueError(P)
|
|
28
|
+
if len(paramss) == 0:
|
|
29
|
+
raise ValueError("paramss must be non-empty")
|
|
30
|
+
return x_array, y_array
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _compute_single_loglik(
|
|
34
|
+
y_scaled: np.ndarray, mu_i: np.ndarray, se_i: np.ndarray
|
|
35
|
+
) -> float:
|
|
36
|
+
import numpy as np
|
|
37
|
+
|
|
38
|
+
if not np.isfinite(mu_i).all() or not np.isfinite(se_i).all():
|
|
39
|
+
return 0.0
|
|
40
|
+
if np.any(se_i <= 0.0):
|
|
41
|
+
return 0.0
|
|
42
|
+
var_scaled = se_i**2
|
|
43
|
+
loglik = -0.5 * np.sum(
|
|
44
|
+
np.log(2.0 * np.pi * var_scaled) + (y_scaled - mu_i) ** 2 / var_scaled
|
|
45
|
+
)
|
|
46
|
+
return float(loglik) if np.isfinite(loglik) else 0.0
|
|
13
47
|
|
|
14
48
|
|
|
15
49
|
def subsample_loglik(
|
|
@@ -23,68 +57,37 @@ def subsample_loglik(
|
|
|
23
57
|
) -> list[float]:
|
|
24
58
|
import numpy as np
|
|
25
59
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
if
|
|
29
|
-
raise ValueError(y.shape)
|
|
30
|
-
if x.shape[0] != y.shape[0]:
|
|
31
|
-
raise ValueError((x.shape, y.shape))
|
|
32
|
-
if P <= 0:
|
|
33
|
-
raise ValueError(P)
|
|
34
|
-
if len(paramss) == 0:
|
|
35
|
-
raise ValueError("paramss must be non-empty")
|
|
36
|
-
n = x.shape[0]
|
|
37
|
-
if n == 0:
|
|
38
|
-
return [0.0] * len(paramss)
|
|
39
|
-
if len(model) <= 1:
|
|
60
|
+
x_array, y_array = _validate_subsample_inputs(x, y, P, paramss)
|
|
61
|
+
n = x_array.shape[0]
|
|
62
|
+
if n == 0 or len(model) <= 1:
|
|
40
63
|
return [0.0] * len(paramss)
|
|
41
64
|
P_actual = min(P, n)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
y_selected = y[indices]
|
|
48
|
-
if not np.isfinite(y_selected).all():
|
|
65
|
+
indices = (
|
|
66
|
+
np.arange(n, dtype=int) if P_actual == n else rng.permutation(n)[:P_actual]
|
|
67
|
+
)
|
|
68
|
+
x_sel, y_sel = x_array[indices], y_array[indices]
|
|
69
|
+
if not np.isfinite(y_sel).all():
|
|
49
70
|
return [0.0] * len(paramss)
|
|
50
|
-
|
|
51
|
-
|
|
71
|
+
from .enn_params import PosteriorFlags
|
|
72
|
+
|
|
73
|
+
post = model.batch_posterior(
|
|
74
|
+
x_sel,
|
|
75
|
+
paramss,
|
|
76
|
+
flags=PosteriorFlags(exclude_nearest=True, observation_noise=True),
|
|
52
77
|
)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
if
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
mu_scaled = mu_batch / y_std
|
|
67
|
-
se_scaled = se_batch / y_std
|
|
68
|
-
result = []
|
|
69
|
-
for i in range(num_params):
|
|
70
|
-
mu_i = mu_scaled[i]
|
|
71
|
-
se_i = se_scaled[i]
|
|
72
|
-
if not np.isfinite(mu_i).all() or not np.isfinite(se_i).all():
|
|
73
|
-
result.append(0.0)
|
|
74
|
-
continue
|
|
75
|
-
if np.any(se_i <= 0.0):
|
|
76
|
-
result.append(0.0)
|
|
77
|
-
continue
|
|
78
|
-
diff = y_scaled - mu_i
|
|
79
|
-
var_scaled = se_i**2
|
|
80
|
-
log_term = np.log(2.0 * np.pi * var_scaled)
|
|
81
|
-
quad = diff**2 / var_scaled
|
|
82
|
-
loglik = -0.5 * np.sum(log_term + quad)
|
|
83
|
-
if not np.isfinite(loglik):
|
|
84
|
-
result.append(0.0)
|
|
85
|
-
continue
|
|
86
|
-
result.append(float(loglik))
|
|
87
|
-
return result
|
|
78
|
+
num_params, num_outputs = len(paramss), y_sel.shape[1]
|
|
79
|
+
expected_shape = (num_params, P_actual, num_outputs)
|
|
80
|
+
if post.mu.shape != expected_shape or post.se.shape != expected_shape:
|
|
81
|
+
raise ValueError((post.mu.shape, post.se.shape, expected_shape))
|
|
82
|
+
y_std = np.std(y_array, axis=0, keepdims=True).astype(float)
|
|
83
|
+
y_std = np.where(np.isfinite(y_std) & (y_std > 0.0), y_std, 1.0)
|
|
84
|
+
y_scaled = y_sel / y_std
|
|
85
|
+
mu_scaled = post.mu / y_std
|
|
86
|
+
se_scaled = post.se / y_std
|
|
87
|
+
return [
|
|
88
|
+
_compute_single_loglik(y_scaled, mu_scaled[i], se_scaled[i])
|
|
89
|
+
for i in range(num_params)
|
|
90
|
+
]
|
|
88
91
|
|
|
89
92
|
|
|
90
93
|
def enn_fit(
|
|
@@ -100,12 +103,6 @@ def enn_fit(
|
|
|
100
103
|
|
|
101
104
|
train_x = model.train_x
|
|
102
105
|
train_y = model.train_y
|
|
103
|
-
train_yvar = model.train_yvar
|
|
104
|
-
if train_y.shape[1] != 1:
|
|
105
|
-
raise ValueError(train_y.shape)
|
|
106
|
-
if train_yvar is not None and train_yvar.shape[1] != 1:
|
|
107
|
-
raise ValueError(train_yvar.shape)
|
|
108
|
-
y = train_y[:, 0]
|
|
109
106
|
log_min = -3.0
|
|
110
107
|
log_max = 3.0
|
|
111
108
|
epi_var_scale_log_values = rng.uniform(log_min, log_max, size=num_fit_candidates)
|
|
@@ -116,26 +113,30 @@ def enn_fit(
|
|
|
116
113
|
ale_homoscedastic_values = 10**ale_homoscedastic_log_values
|
|
117
114
|
paramss = [
|
|
118
115
|
ENNParams(
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
116
|
+
k_num_neighbors=k,
|
|
117
|
+
epistemic_variance_scale=float(epi_val),
|
|
118
|
+
aleatoric_variance_scale=float(ale_val),
|
|
122
119
|
)
|
|
123
120
|
for epi_val, ale_val in zip(epi_var_scale_values, ale_homoscedastic_values)
|
|
124
121
|
]
|
|
125
122
|
if params_warm_start is not None:
|
|
126
123
|
paramss.append(
|
|
127
124
|
ENNParams(
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
125
|
+
k_num_neighbors=k,
|
|
126
|
+
epistemic_variance_scale=params_warm_start.epistemic_variance_scale,
|
|
127
|
+
aleatoric_variance_scale=params_warm_start.aleatoric_variance_scale,
|
|
131
128
|
)
|
|
132
129
|
)
|
|
133
130
|
if len(paramss) == 0:
|
|
134
|
-
return ENNParams(
|
|
131
|
+
return ENNParams(
|
|
132
|
+
k_num_neighbors=k,
|
|
133
|
+
epistemic_variance_scale=1.0,
|
|
134
|
+
aleatoric_variance_scale=0.0,
|
|
135
|
+
)
|
|
135
136
|
import numpy as np
|
|
136
137
|
|
|
137
138
|
logliks = subsample_loglik(
|
|
138
|
-
model, train_x,
|
|
139
|
+
model, train_x, train_y, paramss=paramss, P=num_fit_samples, rng=rng
|
|
139
140
|
)
|
|
140
141
|
if len(logliks) == 0:
|
|
141
142
|
return paramss[0]
|
enn/enn/enn_hash.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def normal_hash_batch_multi_seed(
    function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
) -> np.ndarray:
    """Deterministic pseudo-normal draws keyed by (seed, data index, metric).

    Each unique (seed, index, metric) triple is mapped to a fixed uniform via
    a Philox generator and pushed through the inverse normal CDF, so repeated
    calls with the same arguments reproduce the same "function draw".
    Output shape: (num_seeds, *data_indices.shape, num_metrics).
    """
    import numpy as np
    from scipy.special import ndtri

    n_seeds = len(function_seeds)
    uniq, inv = np.unique(data_indices, return_inverse=True)
    grids = np.meshgrid(
        function_seeds.astype(np.uint64),
        uniq.astype(np.uint64),
        np.arange(num_metrics, dtype=np.uint64),
        indexing="ij",
    )
    seed_flat, idx_flat, metric_flat = (g.ravel() for g in grids)
    combined = (
        seed_flat * np.uint64(1_000_003) + idx_flat
    ) * np.uint64(1_000_003) + metric_flat
    uniforms = np.empty(combined.size, dtype=float)
    # One generator per combined key keeps the draw a pure function of the key.
    for pos, key in enumerate(combined):
        uniforms[pos] = np.random.Generator(np.random.Philox(int(key))).random()
    uniforms = np.clip(uniforms, 1e-10, 1.0 - 1e-10)
    z = ndtri(uniforms).reshape(n_seeds, len(uniq), num_metrics)
    return z[:, inv.ravel(), :].reshape(n_seeds, *data_indices.shape, num_metrics)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def normal_hash_batch_multi_seed_fast(
    function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
) -> np.ndarray:
    """Vectorized counterpart of ``normal_hash_batch_multi_seed``.

    Hashes (seed, index, metric) keys with the splitmix64 finalizer and maps
    two hashes to a normal via Box-Muller.  The values differ from the Philox
    variant but are equally deterministic in the inputs.
    Output shape: (num_seeds, *data_indices.shape, num_metrics).
    """
    import numpy as np

    function_seeds = np.asarray(function_seeds, dtype=np.int64)
    data_indices = np.asarray(data_indices)
    if num_metrics <= 0:
        raise ValueError(num_metrics)
    n_seeds = len(function_seeds)
    uniq, inv = np.unique(data_indices, return_inverse=True)

    def _mix(v: np.ndarray) -> np.ndarray:
        # splitmix64 finalizer; uint64 wrap-around is intentional.
        with np.errstate(over="ignore"):
            v = v + np.uint64(0x9E3779B97F4A7C15)
            v = (v ^ (v >> np.uint64(30))) * np.uint64(0xBF58476D1CE4E5B9)
            v = (v ^ (v >> np.uint64(27))) * np.uint64(0x94D049BB133111EB)
            return v ^ (v >> np.uint64(31))

    seed_u = function_seeds.astype(np.uint64, copy=False)
    uniq_u = uniq.astype(np.uint64, copy=False)
    metric_u = np.arange(num_metrics, dtype=np.uint64)
    prime = np.uint64(1_000_003)
    scale = 1.0 / 9007199254740992.0  # 2**-53: top 53 hash bits -> [0, 1)
    out = np.empty((n_seeds, uniq.size, num_metrics), dtype=float)
    for row, s in enumerate(seed_u):
        with np.errstate(over="ignore"):
            key = ((s * prime + uniq_u) * prime)[:, None] + metric_u[None, :]
            h1 = _mix(key)
            h2 = _mix(key ^ np.uint64(0xD2B74407B1CE6E93))
        u1 = (h1 >> np.uint64(11)).astype(np.float64) * scale
        u2 = (h2 >> np.uint64(11)).astype(np.float64) * scale
        u1 = np.clip(u1, 1e-12, 1.0 - 1e-12)  # keep log(u1) finite
        out[row] = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2)
    return out[:, inv.ravel(), :].reshape(n_seeds, *data_indices.shape, num_metrics)
|
enn/enn/enn_index.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ENNIndex:
    """FAISS-backed nearest-neighbor index over (optionally scaled) inputs.

    Wraps either a flat (exact) or HNSW (approximate) squared-L2 index.  All
    incoming points are divided by ``x_scale`` before indexing/searching when
    ``scale_x`` is set, so distances are computed in the scaled space.

    Fix: the driver-dispatch index construction was duplicated verbatim in
    ``_build_index`` and ``add``; it is now centralized in ``_new_index`` so
    the two paths cannot drift apart.
    """

    def __init__(
        self,
        train_x_scaled: np.ndarray,
        num_dim: int,
        x_scale: np.ndarray,
        scale_x: bool,
        driver: Any = None,
    ) -> None:
        from enn.turbo.config.enums import ENNIndexDriver

        if driver is None:
            driver = ENNIndexDriver.FLAT
        self._train_x_scaled = train_x_scaled
        self._num_dim = num_dim
        self._x_scale = x_scale
        self._scale_x = scale_x
        self._driver = driver
        self._index: Any | None = None
        self._build_index()

    def _new_index(self) -> Any:
        """Construct an empty FAISS index for the configured driver."""
        import faiss

        from enn.turbo.config.enums import ENNIndexDriver

        if self._driver == ENNIndexDriver.FLAT:
            return faiss.IndexFlatL2(self._num_dim)
        if self._driver == ENNIndexDriver.HNSW:
            # TODO: Make M configurable
            return faiss.IndexHNSWFlat(self._num_dim, 32)
        raise ValueError(f"Unknown driver: {self._driver}")

    def _build_index(self) -> None:
        """Populate the index from the stored training points, if any."""
        import numpy as np

        if len(self._train_x_scaled) == 0:
            return
        # FAISS operates on float32; downcast without copying when possible.
        x_f32 = self._train_x_scaled.astype(np.float32, copy=False)
        index = self._new_index()
        index.add(x_f32)
        self._index = index

    def add(self, x: np.ndarray) -> None:
        """Scale and append new (unscaled) points, creating the index lazily.

        Raises ValueError if ``x`` is not 2-D with ``self._num_dim`` columns.
        """
        import numpy as np

        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        x_scaled = x / self._x_scale if self._scale_x else x
        x_f32 = x_scaled.astype(np.float32, copy=False)
        if self._index is None:
            self._index = self._new_index()
        self._index.add(x_f32)

    def search(
        self,
        x: np.ndarray,
        *,
        search_k: int,
        exclude_nearest: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return squared-L2 distances and ids of the ``search_k`` nearest points.

        When ``exclude_nearest`` is True the closest hit is dropped (one fewer
        column is returned).  Raises ValueError on a non-positive ``search_k``
        or a shape mismatch, and RuntimeError if the index was never built.
        """
        import numpy as np

        search_k = int(search_k)
        if search_k <= 0:
            raise ValueError(search_k)
        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        if self._index is None:
            raise RuntimeError("index is not initialized")
        x_scaled = x / self._x_scale if self._scale_x else x
        x_f32 = x_scaled.astype(np.float32, copy=False)
        dist2s_full, idx_full = self._index.search(x_f32, search_k)
        dist2s_full = dist2s_full.astype(float)
        idx_full = idx_full.astype(int)
        if exclude_nearest:
            dist2s_full = dist2s_full[:, 1:]
            idx_full = idx_full[:, 1:]
        return dist2s_full, idx_full
|