ennbo-0.1.2-py3-none-any.whl → ennbo-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
enn/enn/enn_conditional.py ADDED
@@ -0,0 +1,325 @@
+ from __future__ import annotations
+ import numpy as np
+ from .candidates import Candidates
+ from .conditional_posterior_draw_internals import ConditionalPosteriorDrawInternals
+ from .enn_like_protocol import ENNLike
+ from .enn_params import ENNParams, PosteriorFlags
+ from .neighbors import Neighbors
+
+ _ENNLike = ENNLike
+ _Candidates = Candidates
+ _Neighbors = Neighbors
+
+
+ def _pairwise_sq_l2(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     a = np.asarray(a, dtype=float)
+     b = np.asarray(b, dtype=float)
+     aa = np.sum(a * a, axis=1, keepdims=True)
+     bb = np.sum(b * b, axis=1, keepdims=True).T
+     dist2 = aa + bb - 2.0 * (a @ b.T)
+     return np.maximum(dist2, 0.0)
+
+
+ def _validate_x(enn: ENNLike, x: np.ndarray) -> np.ndarray:
+     x = np.asarray(x, dtype=float)
+     if x.ndim != 2 or x.shape[1] != enn._num_dim:
+         raise ValueError(x.shape)
+     return x
+
+
+ def _validate_whatif(
+     enn: ENNLike, x_whatif: np.ndarray, y_whatif: np.ndarray
+ ) -> tuple[np.ndarray, np.ndarray]:
+     x_whatif = np.asarray(x_whatif, dtype=float)
+     y_whatif = np.asarray(y_whatif, dtype=float)
+     if x_whatif.ndim != 2 or x_whatif.shape[1] != enn._num_dim:
+         raise ValueError(x_whatif.shape)
+     if y_whatif.ndim != 2 or y_whatif.shape[1] != enn._num_metrics:
+         raise ValueError(y_whatif.shape)
+     if x_whatif.shape[0] != y_whatif.shape[0]:
+         raise ValueError((x_whatif.shape, y_whatif.shape))
+     return x_whatif, y_whatif
+
+
+ def _scale_x_if_needed(enn: ENNLike, x: np.ndarray) -> np.ndarray:
+     return x / enn._x_scale if enn._scale_x else x
+
+
+ def _compute_total_n(enn: ENNLike, num_whatif: int, flags: PosteriorFlags) -> int:
+     total_n = len(enn) + int(num_whatif)
+     if flags.exclude_nearest and total_n <= 1:
+         raise ValueError(total_n)
+     return total_n
+
+
+ def _compute_search_k(params: ENNParams, flags: PosteriorFlags, total_n: int) -> int:
+     return int(
+         min(params.k_num_neighbors + (1 if flags.exclude_nearest else 0), total_n)
+     )
+
+
+ def _get_train_candidates(enn: ENNLike, x: np.ndarray, *, search_k: int) -> Candidates:
+     batch_size = x.shape[0]
+     if len(enn) == 0 or search_k == 0:
+         return Candidates(
+             dist2=np.zeros((batch_size, 0), dtype=float),
+             ids=np.zeros((batch_size, 0), dtype=int),
+             y=np.zeros((batch_size, 0, enn._num_metrics), dtype=float),
+             yvar=(
+                 np.zeros((batch_size, 0, enn._num_metrics), dtype=float)
+                 if enn._train_yvar is not None
+                 else None
+             ),
+         )
+     train_search_k = int(min(search_k, len(enn)))
+     dist2_train, idx_train = enn._enn_index.search(
+         x, search_k=train_search_k, exclude_nearest=False
+     )
+     y_train = enn._train_y[idx_train]
+     yvar_train = enn._train_yvar[idx_train] if enn._train_yvar is not None else None
+     return Candidates(dist2=dist2_train, ids=idx_train, y=y_train, yvar=yvar_train)
+
+
+ def _get_whatif_candidates(
+     enn: ENNLike,
+     x: np.ndarray,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     x_scaled = _scale_x_if_needed(enn, x)
+     x_whatif_scaled = _scale_x_if_needed(enn, x_whatif)
+     dist2_whatif = _pairwise_sq_l2(x_scaled, x_whatif_scaled)
+     batch_size = x.shape[0]
+     y_whatif_batched = np.broadcast_to(
+         y_whatif[np.newaxis, :, :], (batch_size, y_whatif.shape[0], y_whatif.shape[1])
+     )
+     return dist2_whatif, y_whatif_batched
+
+
+ _WhatifCandidateTuple = tuple[np.ndarray, np.ndarray, np.ndarray]
+
+
+ def _merge_candidates(
+     enn: ENNLike,
+     *,
+     train: Candidates,
+     whatif: _WhatifCandidateTuple,
+ ) -> Candidates:
+     dist2_whatif, ids_whatif, y_whatif_batched = whatif
+     dist2_all = np.concatenate([train.dist2, dist2_whatif], axis=1)
+     ids_all = np.concatenate([train.ids, ids_whatif], axis=1)
+     y_all = np.concatenate([train.y, y_whatif_batched], axis=1)
+     if train.yvar is None:
+         return Candidates(dist2=dist2_all, ids=ids_all, y=y_all, yvar=None)
+     batch_size = dist2_all.shape[0]
+     num_whatif = dist2_whatif.shape[1]
+     yvar_whatif = np.zeros((batch_size, num_whatif, enn._num_metrics))
+     yvar_all = np.concatenate([train.yvar, yvar_whatif], axis=1)
+     return Candidates(dist2=dist2_all, ids=ids_all, y=y_all, yvar=yvar_all)
+
+
+ def _select_sorted_candidates(dist2_all: np.ndarray, *, search_k: int) -> np.ndarray:
+     batch_size, num_candidates = dist2_all.shape
+     if search_k < num_candidates:
+         sel = np.argpartition(dist2_all, kth=search_k - 1, axis=1)[:, :search_k]
+     else:
+         sel = np.broadcast_to(np.arange(num_candidates), (batch_size, num_candidates))
+     sel_dist2 = np.take_along_axis(dist2_all, sel, axis=1)
+     sel_order = np.argsort(sel_dist2, axis=1)
+     return np.take_along_axis(sel, sel_order, axis=1)
+
+
+ def _take_along_axis_3d(a: np.ndarray, idx_2d: np.ndarray) -> np.ndarray:
+     return np.take_along_axis(a, idx_2d[:, :, np.newaxis], axis=1)
+
+
+ def _make_empty_normal(enn: ENNLike, batch_size: int):
+     from .enn_normal import ENNNormal
+
+     internals = enn._empty_posterior_internals(batch_size)
+     return ENNNormal(internals.mu, internals.se)
+
+
+ def _build_candidates(
+     enn: ENNLike,
+     x: np.ndarray,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     *,
+     search_k: int,
+ ) -> Candidates:
+     train_candidates = _get_train_candidates(enn, x, search_k=search_k)
+     dist2_whatif, y_whatif_batched = _get_whatif_candidates(enn, x, x_whatif, y_whatif)
+     n_train = int(len(enn))
+     ids_whatif = np.broadcast_to(
+         n_train + np.arange(x_whatif.shape[0], dtype=int), dist2_whatif.shape
+     )
+     return _merge_candidates(
+         enn,
+         train=train_candidates,
+         whatif=(dist2_whatif, ids_whatif, y_whatif_batched),
+     )
+
+
+ def _select_effective_neighbors(
+     candidates: Candidates,
+     *,
+     search_k: int,
+     k: int,
+     exclude_nearest: bool,
+ ) -> Neighbors | None:
+     sel = _select_sorted_candidates(candidates.dist2, search_k=search_k)
+     if exclude_nearest:
+         sel = sel[:, 1:]
+     sel = sel[:, : int(min(k, sel.shape[1]))]
+     if sel.shape[1] == 0:
+         return None
+     dist2s = np.take_along_axis(candidates.dist2, sel, axis=1)
+     ids = np.take_along_axis(candidates.ids, sel, axis=1)
+     y_neighbors = _take_along_axis_3d(candidates.y, sel)
+     yvar_neighbors = (
+         _take_along_axis_3d(candidates.yvar, sel)
+         if candidates.yvar is not None
+         else None
+     )
+     return Neighbors(dist2=dist2s, ids=ids, y=y_neighbors, yvar=yvar_neighbors)
+
+
+ def _compute_mu_se(
+     enn: ENNLike,
+     neighbors: Neighbors,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     stats = enn._compute_weighted_stats(
+         neighbors.dist2,
+         neighbors.y,
+         yvar_neighbors=neighbors.yvar,
+         params=params,
+         observation_noise=flags.observation_noise,
+         y_scale=y_scale,
+     )
+     return stats.mu, stats.se
+
+
+ def _compute_draw_internals(
+     enn: ENNLike,
+     neighbors: Neighbors,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ) -> ConditionalPosteriorDrawInternals:
+     stats = enn._compute_weighted_stats(
+         neighbors.dist2,
+         neighbors.y,
+         yvar_neighbors=neighbors.yvar,
+         params=params,
+         observation_noise=flags.observation_noise,
+         y_scale=y_scale,
+     )
+     return ConditionalPosteriorDrawInternals(
+         idx=neighbors.ids.astype(int, copy=False),
+         w_normalized=stats.w_normalized,
+         l2=stats.l2,
+         mu=stats.mu,
+         se=stats.se,
+     )
+
+
+ def _conditional_neighbors_nonempty_whatif(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+ ) -> tuple[int, int, Neighbors | None]:
+     batch_size = x.shape[0]
+     search_k = _compute_search_k(
+         params, flags, _compute_total_n(enn, x_whatif.shape[0], flags)
+     )
+     if search_k == 0:
+         return batch_size, search_k, None
+     candidates = _build_candidates(enn, x, x_whatif, y_whatif, search_k=search_k)
+     neighbors = _select_effective_neighbors(
+         candidates,
+         search_k=search_k,
+         k=params.k_num_neighbors,
+         exclude_nearest=flags.exclude_nearest,
+     )
+     return batch_size, search_k, neighbors
+
+
+ def _compute_conditional_posterior_impl(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ):
+     from .enn_normal import ENNNormal
+
+     x = _validate_x(enn, x)
+     x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
+     if x_whatif.shape[0] == 0:
+         return enn.posterior(x, params=params, flags=flags)
+     batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
+         enn, x_whatif, y_whatif, x, params=params, flags=flags
+     )
+     if search_k == 0 or neighbors is None:
+         return _make_empty_normal(enn, batch_size)
+     mu, se = _compute_mu_se(enn, neighbors, params=params, flags=flags, y_scale=y_scale)
+     return ENNNormal(mu, se)
+
+
+ def compute_conditional_posterior(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ):
+     return _compute_conditional_posterior_impl(
+         enn, x_whatif, y_whatif, x, params=params, flags=flags, y_scale=y_scale
+     )
+
+
+ def compute_conditional_posterior_draw_internals(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ) -> ConditionalPosteriorDrawInternals:
+     x = _validate_x(enn, x)
+     x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
+     if x_whatif.shape[0] == 0:
+         raise ValueError("x_whatif must be non-empty for conditional draw internals")
+     batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
+         enn, x_whatif, y_whatif, x, params=params, flags=flags
+     )
+     if search_k == 0 or neighbors is None:
+         empty_internals = enn._empty_posterior_internals(batch_size)
+         return ConditionalPosteriorDrawInternals(
+             idx=empty_internals.idx,
+             w_normalized=empty_internals.w_normalized,
+             l2=empty_internals.l2,
+             mu=empty_internals.mu,
+             se=empty_internals.se,
+         )
+     return _compute_draw_internals(
+         enn, neighbors, params=params, flags=flags, y_scale=y_scale
+     )
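
Note: a minimal usage sketch of the new conditional ("what-if") posterior entry point. The constructor call for EpistemicNearestNeighbors and the shape of y_scale are assumptions made for illustration (enn_class.py is not shown in this section); the ENNParams and PosteriorFlags field names come from the hunks in this diff.

    import numpy as np
    from enn.enn.enn_class import EpistemicNearestNeighbors  # import path assumed
    from enn.enn.enn_params import ENNParams, PosteriorFlags
    from enn.enn.enn_conditional import compute_conditional_posterior

    rng = np.random.default_rng(0)
    x_train = rng.uniform(size=(50, 3))
    y_train = np.sin(x_train.sum(axis=1, keepdims=True))

    # Hypothetical constructor call; the real signature lives in enn_class.py.
    enn = EpistemicNearestNeighbors(x_train, y_train)

    params = ENNParams(
        k_num_neighbors=5,
        epistemic_variance_scale=1.0,
        aleatoric_variance_scale=0.0,
    )
    flags = PosteriorFlags(exclude_nearest=False, observation_noise=True)

    # "What-if" observations: condition the posterior at x_query on two
    # hypothetical (x, y) pairs without adding them to the training data.
    x_whatif = rng.uniform(size=(2, 3))
    y_whatif = np.sin(x_whatif.sum(axis=1, keepdims=True))
    x_query = rng.uniform(size=(4, 3))

    post = compute_conditional_posterior(
        enn, x_whatif, y_whatif, x_query,
        params=params, flags=flags, y_scale=np.ones(1),
    )
    print(post.mu.shape, post.se.shape)  # expected: (4, 1) (4, 1)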
enn/enn/enn_fit.py CHANGED
@@ -1,24 +1,16 @@
  from __future__ import annotations
-
  from typing import TYPE_CHECKING, Any

  if TYPE_CHECKING:
      import numpy as np
      from numpy.random import Generator
-
-     from .enn import EpistemicNearestNeighbors
+     from .enn_class import EpistemicNearestNeighbors
      from .enn_params import ENNParams


- def subsample_loglik(
-     model: EpistemicNearestNeighbors | Any,
-     x: np.ndarray | Any,
-     y: np.ndarray | Any,
-     *,
-     paramss: list[ENNParams] | list[Any],
-     P: int = 10,
-     rng: Generator | Any,
- ) -> list[float]:
+ def _validate_subsample_inputs(
+     x: np.ndarray | Any, y: np.ndarray | Any, P: int, paramss: list
+ ) -> tuple[np.ndarray, np.ndarray]:
      import numpy as np

      x_array = np.asarray(x, dtype=float)
@@ -35,64 +27,67 @@ def subsample_loglik(
          raise ValueError(P)
      if len(paramss) == 0:
          raise ValueError("paramss must be non-empty")
+     return x_array, y_array
+
+
+ def _compute_single_loglik(
+     y_scaled: np.ndarray, mu_i: np.ndarray, se_i: np.ndarray
+ ) -> float:
+     import numpy as np
+
+     if not np.isfinite(mu_i).all() or not np.isfinite(se_i).all():
+         return 0.0
+     if np.any(se_i <= 0.0):
+         return 0.0
+     var_scaled = se_i**2
+     loglik = -0.5 * np.sum(
+         np.log(2.0 * np.pi * var_scaled) + (y_scaled - mu_i) ** 2 / var_scaled
+     )
+     return float(loglik) if np.isfinite(loglik) else 0.0
+
+
+ def subsample_loglik(
+     model: EpistemicNearestNeighbors | Any,
+     x: np.ndarray | Any,
+     y: np.ndarray | Any,
+     *,
+     paramss: list[ENNParams] | list[Any],
+     P: int = 10,
+     rng: Generator | Any,
+ ) -> list[float]:
+     import numpy as np
+
+     x_array, y_array = _validate_subsample_inputs(x, y, P, paramss)
      n = x_array.shape[0]
-     if n == 0:
-         return [0.0] * len(paramss)
-     if len(model) <= 1:
+     if n == 0 or len(model) <= 1:
          return [0.0] * len(paramss)
      P_actual = min(P, n)
-     if P_actual == n:
-         indices = np.arange(n, dtype=int)
-     else:
-         indices = rng.permutation(n)[:P_actual]
-     x_selected = x_array[indices]
-     y_selected = y_array[indices]
-     if not np.isfinite(y_selected).all():
+     indices = (
+         np.arange(n, dtype=int) if P_actual == n else rng.permutation(n)[:P_actual]
+     )
+     x_sel, y_sel = x_array[indices], y_array[indices]
+     if not np.isfinite(y_sel).all():
          return [0.0] * len(paramss)
-     post_batch = model.batch_posterior(
-         x_selected, paramss, exclude_nearest=True, observation_noise=True
+     from .enn_params import PosteriorFlags
+
+     post = model.batch_posterior(
+         x_sel,
+         paramss,
+         flags=PosteriorFlags(exclude_nearest=True, observation_noise=True),
      )
-     mu_batch = post_batch.mu
-     se_batch = post_batch.se
-     num_params = len(paramss)
-     num_outputs = y_selected.shape[1]
-     if mu_batch.shape != (num_params, P_actual, num_outputs) or se_batch.shape != (
-         num_params,
-         P_actual,
-         num_outputs,
-     ):
-         raise ValueError(
-             (
-                 mu_batch.shape,
-                 se_batch.shape,
-                 (num_params, P_actual, num_outputs),
-             )
-         )
+     num_params, num_outputs = len(paramss), y_sel.shape[1]
+     expected_shape = (num_params, P_actual, num_outputs)
+     if post.mu.shape != expected_shape or post.se.shape != expected_shape:
+         raise ValueError((post.mu.shape, post.se.shape, expected_shape))
      y_std = np.std(y_array, axis=0, keepdims=True).astype(float)
      y_std = np.where(np.isfinite(y_std) & (y_std > 0.0), y_std, 1.0)
-     y_scaled = y_selected / y_std
-     mu_scaled = mu_batch / y_std
-     se_scaled = se_batch / y_std
-     result = []
-     for i in range(num_params):
-         mu_i = mu_scaled[i]
-         se_i = se_scaled[i]
-         if not np.isfinite(mu_i).all() or not np.isfinite(se_i).all():
-             result.append(0.0)
-             continue
-         if np.any(se_i <= 0.0):
-             result.append(0.0)
-             continue
-         diff = y_scaled - mu_i
-         var_scaled = se_i**2
-         log_term = np.log(2.0 * np.pi * var_scaled)
-         quad = diff**2 / var_scaled
-         loglik = -0.5 * np.sum(log_term + quad)
-         if not np.isfinite(loglik):
-             result.append(0.0)
-             continue
-         result.append(float(loglik))
-     return result
+     y_scaled = y_sel / y_std
+     mu_scaled = post.mu / y_std
+     se_scaled = post.se / y_std
+     return [
+         _compute_single_loglik(y_scaled, mu_scaled[i], se_scaled[i])
+         for i in range(num_params)
+     ]


  def enn_fit(
@@ -118,22 +113,26 @@ def enn_fit(
      ale_homoscedastic_values = 10**ale_homoscedastic_log_values
      paramss = [
          ENNParams(
-             k=k,
-             epi_var_scale=float(epi_val),
-             ale_homoscedastic_scale=float(ale_val),
+             k_num_neighbors=k,
+             epistemic_variance_scale=float(epi_val),
+             aleatoric_variance_scale=float(ale_val),
          )
          for epi_val, ale_val in zip(epi_var_scale_values, ale_homoscedastic_values)
      ]
      if params_warm_start is not None:
          paramss.append(
              ENNParams(
-                 k=k,
-                 epi_var_scale=params_warm_start.epi_var_scale,
-                 ale_homoscedastic_scale=params_warm_start.ale_homoscedastic_scale,
+                 k_num_neighbors=k,
+                 epistemic_variance_scale=params_warm_start.epistemic_variance_scale,
+                 aleatoric_variance_scale=params_warm_start.aleatoric_variance_scale,
              )
          )
      if len(paramss) == 0:
-         return ENNParams(k=k, epi_var_scale=1.0, ale_homoscedastic_scale=0.0)
+         return ENNParams(
+             k_num_neighbors=k,
+             epistemic_variance_scale=1.0,
+             aleatoric_variance_scale=0.0,
+         )
      import numpy as np

      logliks = subsample_loglik(
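
Note: the refactor extracts the per-parameter score into _compute_single_loglik, which is the summed log-density of independent normals, loglik = -0.5 * sum(log(2*pi*se**2) + (y - mu)**2 / se**2), returning 0.0 when inputs are non-finite or any scale is non-positive. A small self-contained check of that formula (scipy is used here only to verify; nothing beyond the formula above is taken from the diff):

    import numpy as np
    from scipy.stats import norm

    y = np.array([[0.2, -1.0], [1.5, 0.3]])
    mu = np.array([[0.0, -0.8], [1.0, 0.0]])
    se = np.array([[0.5, 1.0], [0.7, 0.9]])

    # Same expression as in _compute_single_loglik (without the guard clauses).
    var = se**2
    loglik = -0.5 * np.sum(np.log(2.0 * np.pi * var) + (y - mu) ** 2 / var)

    # Agrees with the sum of normal log-pdfs.
    assert np.isclose(loglik, norm.logpdf(y, loc=mu, scale=se).sum())
    print(loglik)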
enn/enn/enn_hash.py ADDED
@@ -0,0 +1,79 @@
+ from __future__ import annotations
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import numpy as np
+
+
+ def normal_hash_batch_multi_seed(
+     function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
+ ) -> np.ndarray:
+     import numpy as np
+     from scipy.special import ndtri
+
+     num_seeds = len(function_seeds)
+     unique_indices, inverse = np.unique(data_indices, return_inverse=True)
+     num_unique = len(unique_indices)
+     seed_grid, idx_grid, metric_grid = np.meshgrid(
+         function_seeds.astype(np.uint64),
+         unique_indices.astype(np.uint64),
+         np.arange(num_metrics, dtype=np.uint64),
+         indexing="ij",
+     )
+     seed_flat = seed_grid.ravel()
+     idx_flat = idx_grid.ravel()
+     metric_flat = metric_grid.ravel()
+     combined_seeds = (seed_flat * np.uint64(1_000_003) + idx_flat) * np.uint64(
+         1_000_003
+     ) + metric_flat
+     uniform_vals = np.empty(len(combined_seeds), dtype=float)
+     for i, seed in enumerate(combined_seeds):
+         rng = np.random.Generator(np.random.Philox(int(seed)))
+         uniform_vals[i] = rng.random()
+     uniform_vals = np.clip(uniform_vals, 1e-10, 1.0 - 1e-10)
+     normal_vals = ndtri(uniform_vals).reshape(num_seeds, num_unique, num_metrics)
+     return normal_vals[:, inverse.ravel(), :].reshape(
+         num_seeds, *data_indices.shape, num_metrics
+     )
+
+
+ def normal_hash_batch_multi_seed_fast(
+     function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
+ ) -> np.ndarray:
+     import numpy as np
+
+     function_seeds = np.asarray(function_seeds, dtype=np.int64)
+     data_indices = np.asarray(data_indices)
+     if num_metrics <= 0:
+         raise ValueError(num_metrics)
+     num_seeds = len(function_seeds)
+     unique_indices, inverse = np.unique(data_indices, return_inverse=True)
+
+     def _splitmix64(x: np.ndarray) -> np.ndarray:
+         with np.errstate(over="ignore"):
+             x = x + np.uint64(0x9E3779B97F4A7C15)
+             z = x
+             z = (z ^ (z >> np.uint64(30))) * np.uint64(0xBF58476D1CE4E5B9)
+             z = (z ^ (z >> np.uint64(27))) * np.uint64(0x94D049BB133111EB)
+             z = z ^ (z >> np.uint64(31))
+         return z
+
+     seeds_u64 = function_seeds.astype(np.uint64, copy=False)
+     unique_u64 = unique_indices.astype(np.uint64, copy=False)
+     metric_u64 = np.arange(num_metrics, dtype=np.uint64)
+     normal_vals = np.empty((num_seeds, unique_indices.size, num_metrics), dtype=float)
+     p = np.uint64(1_000_003)
+     inv_2p53 = 1.0 / 9007199254740992.0
+     for si, s in enumerate(seeds_u64):
+         with np.errstate(over="ignore"):
+             base = (s * p + unique_u64) * p
+             combined = base[:, None] + metric_u64[None, :]
+             r1 = _splitmix64(combined)
+             r2 = _splitmix64(combined ^ np.uint64(0xD2B74407B1CE6E93))
+             u1 = (r1 >> np.uint64(11)).astype(np.float64) * inv_2p53
+             u2 = (r2 >> np.uint64(11)).astype(np.float64) * inv_2p53
+             u1 = np.clip(u1, 1e-12, 1.0 - 1e-12)
+             normal_vals[si, :, :] = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2)
+     return normal_vals[:, inverse.ravel(), :].reshape(
+         num_seeds, *data_indices.shape, num_metrics
+     )
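
Note: both functions map each (function seed, data index, metric) triple deterministically to a standard-normal draw; the _fast variant replaces Philox plus ndtri with a splitmix64 hash plus a Box-Muller transform, so the two are interchangeable in distribution but not bit-for-bit identical. A minimal usage sketch (the import path is assumed from this diff):

    import numpy as np
    from enn.enn.enn_hash import normal_hash_batch_multi_seed_fast  # path assumed

    function_seeds = np.array([7, 11])
    data_indices = np.array([[0, 3], [3, 5]])  # index 3 appears twice
    num_metrics = 2

    z = normal_hash_batch_multi_seed_fast(function_seeds, data_indices, num_metrics)
    print(z.shape)  # (2, 2, 2, 2): (num_seeds, *data_indices.shape, num_metrics)

    # Determinism: the same (seed, index, metric) triple always hashes to the
    # same value, so repeated data indices share draws within each seed.
    assert np.array_equal(z[:, 0, 1, :], z[:, 1, 0, :])  # both are data index 3
    assert np.array_equal(
        z, normal_hash_batch_multi_seed_fast(function_seeds, data_indices, num_metrics)
    )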
enn/enn/enn_index.py ADDED
@@ -0,0 +1,92 @@
+ from __future__ import annotations
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     import numpy as np
+
+
+ class ENNIndex:
+     def __init__(
+         self,
+         train_x_scaled: np.ndarray,
+         num_dim: int,
+         x_scale: np.ndarray,
+         scale_x: bool,
+         driver: Any = None,
+     ) -> None:
+         from enn.turbo.config.enums import ENNIndexDriver
+
+         if driver is None:
+             driver = ENNIndexDriver.FLAT
+         self._train_x_scaled = train_x_scaled
+         self._num_dim = num_dim
+         self._x_scale = x_scale
+         self._scale_x = scale_x
+         self._driver = driver
+         self._index: Any | None = None
+         self._build_index()
+
+     def _build_index(self) -> None:
+         import faiss
+         import numpy as np
+         from enn.turbo.config.enums import ENNIndexDriver
+
+         if len(self._train_x_scaled) == 0:
+             return
+         x_f32 = self._train_x_scaled.astype(np.float32, copy=False)
+         if self._driver == ENNIndexDriver.FLAT:
+             index = faiss.IndexFlatL2(self._num_dim)
+         elif self._driver == ENNIndexDriver.HNSW:
+             # TODO: Make M configurable
+             index = faiss.IndexHNSWFlat(self._num_dim, 32)
+         else:
+             raise ValueError(f"Unknown driver: {self._driver}")
+         index.add(x_f32)
+         self._index = index
+
+     def add(self, x: np.ndarray) -> None:
+         import numpy as np
+         from enn.turbo.config.enums import ENNIndexDriver
+
+         x = np.asarray(x, dtype=float)
+         if x.ndim != 2 or x.shape[1] != self._num_dim:
+             raise ValueError(x.shape)
+         x_scaled = x / self._x_scale if self._scale_x else x
+         x_f32 = x_scaled.astype(np.float32, copy=False)
+         if self._index is None:
+             import faiss
+
+             if self._driver == ENNIndexDriver.FLAT:
+                 self._index = faiss.IndexFlatL2(self._num_dim)
+             elif self._driver == ENNIndexDriver.HNSW:
+                 self._index = faiss.IndexHNSWFlat(self._num_dim, 32)
+             else:
+                 raise ValueError(f"Unknown driver: {self._driver}")
+         self._index.add(x_f32)
+
+     def search(
+         self,
+         x: np.ndarray,
+         *,
+         search_k: int,
+         exclude_nearest: bool,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         import numpy as np
+
+         search_k = int(search_k)
+         if search_k <= 0:
+             raise ValueError(search_k)
+         x = np.asarray(x, dtype=float)
+         if x.ndim != 2 or x.shape[1] != self._num_dim:
+             raise ValueError(x.shape)
+         if self._index is None:
+             raise RuntimeError("index is not initialized")
+         x_scaled = x / self._x_scale if self._scale_x else x
+         x_f32 = x_scaled.astype(np.float32, copy=False)
+         dist2s_full, idx_full = self._index.search(x_f32, search_k)
+         dist2s_full = dist2s_full.astype(float)
+         idx_full = idx_full.astype(int)
+         if exclude_nearest:
+             dist2s_full = dist2s_full[:, 1:]
+             idx_full = idx_full[:, 1:]
+         return dist2s_full, idx_full
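
Note: a minimal sketch of exercising the new ENNIndex wrapper, assuming faiss is installed; it uses only the constructor and search signatures visible in this diff, and the import paths are taken on faith from the code above.

    import numpy as np
    from enn.enn.enn_index import ENNIndex  # path assumed from this diff
    from enn.turbo.config.enums import ENNIndexDriver

    rng = np.random.default_rng(0)
    x_train = rng.uniform(size=(100, 4))

    # Exact L2 index over the (already scaled) training inputs.
    index = ENNIndex(x_train, num_dim=4, x_scale=np.ones(4), scale_x=False,
                     driver=ENNIndexDriver.FLAT)

    x_query = rng.uniform(size=(3, 4))
    dist2, idx = index.search(x_query, search_k=5, exclude_nearest=False)
    print(dist2.shape, idx.shape)  # (3, 5) (3, 5)

    # With exclude_nearest=True the closest hit is dropped, so one fewer column.
    dist2_ex, idx_ex = index.search(x_query, search_k=5, exclude_nearest=True)
    print(dist2_ex.shape)  # (3, 4)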