ennbo-0.1.0-py3-none-any.whl → ennbo-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -229
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +77 -76
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +3 -3
  18. enn/enn/enn_params.py +3 -9
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +79 -37
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +250 -0
  84. enn/turbo/no_trust_region.py +58 -0
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +46 -39
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +9 -2
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +129 -63
  100. enn/turbo/turbo_utils.py +144 -117
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/METADATA +22 -14
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -98
  113. enn/turbo/lhd_only_impl.py +0 -42
  114. enn/turbo/turbo_config.py +0 -28
  115. enn/turbo/turbo_enn_impl.py +0 -176
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -67
  118. enn/turbo/turbo_one_impl.py +0 -163
  119. enn/turbo/turbo_optimizer.py +0 -337
  120. enn/turbo/turbo_zero_impl.py +0 -24
  121. ennbo-0.1.0.dist-info/RECORD +0 -27
  122. {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.0.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
enn/enn/enn_conditional.py ADDED
@@ -0,0 +1,325 @@
+ from __future__ import annotations
+ import numpy as np
+ from .candidates import Candidates
+ from .conditional_posterior_draw_internals import ConditionalPosteriorDrawInternals
+ from .enn_like_protocol import ENNLike
+ from .enn_params import ENNParams, PosteriorFlags
+ from .neighbors import Neighbors
+
+ _ENNLike = ENNLike
+ _Candidates = Candidates
+ _Neighbors = Neighbors
+
+
+ def _pairwise_sq_l2(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     a = np.asarray(a, dtype=float)
+     b = np.asarray(b, dtype=float)
+     aa = np.sum(a * a, axis=1, keepdims=True)
+     bb = np.sum(b * b, axis=1, keepdims=True).T
+     dist2 = aa + bb - 2.0 * (a @ b.T)
+     return np.maximum(dist2, 0.0)
+
+
+ def _validate_x(enn: ENNLike, x: np.ndarray) -> np.ndarray:
+     x = np.asarray(x, dtype=float)
+     if x.ndim != 2 or x.shape[1] != enn._num_dim:
+         raise ValueError(x.shape)
+     return x
+
+
+ def _validate_whatif(
+     enn: ENNLike, x_whatif: np.ndarray, y_whatif: np.ndarray
+ ) -> tuple[np.ndarray, np.ndarray]:
+     x_whatif = np.asarray(x_whatif, dtype=float)
+     y_whatif = np.asarray(y_whatif, dtype=float)
+     if x_whatif.ndim != 2 or x_whatif.shape[1] != enn._num_dim:
+         raise ValueError(x_whatif.shape)
+     if y_whatif.ndim != 2 or y_whatif.shape[1] != enn._num_metrics:
+         raise ValueError(y_whatif.shape)
+     if x_whatif.shape[0] != y_whatif.shape[0]:
+         raise ValueError((x_whatif.shape, y_whatif.shape))
+     return x_whatif, y_whatif
+
+
+ def _scale_x_if_needed(enn: ENNLike, x: np.ndarray) -> np.ndarray:
+     return x / enn._x_scale if enn._scale_x else x
+
+
+ def _compute_total_n(enn: ENNLike, num_whatif: int, flags: PosteriorFlags) -> int:
+     total_n = len(enn) + int(num_whatif)
+     if flags.exclude_nearest and total_n <= 1:
+         raise ValueError(total_n)
+     return total_n
+
+
+ def _compute_search_k(params: ENNParams, flags: PosteriorFlags, total_n: int) -> int:
+     return int(
+         min(params.k_num_neighbors + (1 if flags.exclude_nearest else 0), total_n)
+     )
+
+
+ def _get_train_candidates(enn: ENNLike, x: np.ndarray, *, search_k: int) -> Candidates:
+     batch_size = x.shape[0]
+     if len(enn) == 0 or search_k == 0:
+         return Candidates(
+             dist2=np.zeros((batch_size, 0), dtype=float),
+             ids=np.zeros((batch_size, 0), dtype=int),
+             y=np.zeros((batch_size, 0, enn._num_metrics), dtype=float),
+             yvar=(
+                 np.zeros((batch_size, 0, enn._num_metrics), dtype=float)
+                 if enn._train_yvar is not None
+                 else None
+             ),
+         )
+     train_search_k = int(min(search_k, len(enn)))
+     dist2_train, idx_train = enn._enn_index.search(
+         x, search_k=train_search_k, exclude_nearest=False
+     )
+     y_train = enn._train_y[idx_train]
+     yvar_train = enn._train_yvar[idx_train] if enn._train_yvar is not None else None
+     return Candidates(dist2=dist2_train, ids=idx_train, y=y_train, yvar=yvar_train)
+
+
+ def _get_whatif_candidates(
+     enn: ENNLike,
+     x: np.ndarray,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     x_scaled = _scale_x_if_needed(enn, x)
+     x_whatif_scaled = _scale_x_if_needed(enn, x_whatif)
+     dist2_whatif = _pairwise_sq_l2(x_scaled, x_whatif_scaled)
+     batch_size = x.shape[0]
+     y_whatif_batched = np.broadcast_to(
+         y_whatif[np.newaxis, :, :], (batch_size, y_whatif.shape[0], y_whatif.shape[1])
+     )
+     return dist2_whatif, y_whatif_batched
+
+
+ _WhatifCandidateTuple = tuple[np.ndarray, np.ndarray, np.ndarray]
+
+
+ def _merge_candidates(
+     enn: ENNLike,
+     *,
+     train: Candidates,
+     whatif: _WhatifCandidateTuple,
+ ) -> Candidates:
+     dist2_whatif, ids_whatif, y_whatif_batched = whatif
+     dist2_all = np.concatenate([train.dist2, dist2_whatif], axis=1)
+     ids_all = np.concatenate([train.ids, ids_whatif], axis=1)
+     y_all = np.concatenate([train.y, y_whatif_batched], axis=1)
+     if train.yvar is None:
+         return Candidates(dist2=dist2_all, ids=ids_all, y=y_all, yvar=None)
+     batch_size = dist2_all.shape[0]
+     num_whatif = dist2_whatif.shape[1]
+     yvar_whatif = np.zeros((batch_size, num_whatif, enn._num_metrics))
+     yvar_all = np.concatenate([train.yvar, yvar_whatif], axis=1)
+     return Candidates(dist2=dist2_all, ids=ids_all, y=y_all, yvar=yvar_all)
+
+
+ def _select_sorted_candidates(dist2_all: np.ndarray, *, search_k: int) -> np.ndarray:
+     batch_size, num_candidates = dist2_all.shape
+     if search_k < num_candidates:
+         sel = np.argpartition(dist2_all, kth=search_k - 1, axis=1)[:, :search_k]
+     else:
+         sel = np.broadcast_to(np.arange(num_candidates), (batch_size, num_candidates))
+     sel_dist2 = np.take_along_axis(dist2_all, sel, axis=1)
+     sel_order = np.argsort(sel_dist2, axis=1)
+     return np.take_along_axis(sel, sel_order, axis=1)
+
+
+ def _take_along_axis_3d(a: np.ndarray, idx_2d: np.ndarray) -> np.ndarray:
+     return np.take_along_axis(a, idx_2d[:, :, np.newaxis], axis=1)
+
+
+ def _make_empty_normal(enn: ENNLike, batch_size: int):
+     from .enn_normal import ENNNormal
+
+     internals = enn._empty_posterior_internals(batch_size)
+     return ENNNormal(internals.mu, internals.se)
+
+
+ def _build_candidates(
+     enn: ENNLike,
+     x: np.ndarray,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     *,
+     search_k: int,
+ ) -> Candidates:
+     train_candidates = _get_train_candidates(enn, x, search_k=search_k)
+     dist2_whatif, y_whatif_batched = _get_whatif_candidates(enn, x, x_whatif, y_whatif)
+     n_train = int(len(enn))
+     ids_whatif = np.broadcast_to(
+         n_train + np.arange(x_whatif.shape[0], dtype=int), dist2_whatif.shape
+     )
+     return _merge_candidates(
+         enn,
+         train=train_candidates,
+         whatif=(dist2_whatif, ids_whatif, y_whatif_batched),
+     )
+
+
+ def _select_effective_neighbors(
+     candidates: Candidates,
+     *,
+     search_k: int,
+     k: int,
+     exclude_nearest: bool,
+ ) -> Neighbors | None:
+     sel = _select_sorted_candidates(candidates.dist2, search_k=search_k)
+     if exclude_nearest:
+         sel = sel[:, 1:]
+     sel = sel[:, : int(min(k, sel.shape[1]))]
+     if sel.shape[1] == 0:
+         return None
+     dist2s = np.take_along_axis(candidates.dist2, sel, axis=1)
+     ids = np.take_along_axis(candidates.ids, sel, axis=1)
+     y_neighbors = _take_along_axis_3d(candidates.y, sel)
+     yvar_neighbors = (
+         _take_along_axis_3d(candidates.yvar, sel)
+         if candidates.yvar is not None
+         else None
+     )
+     return Neighbors(dist2=dist2s, ids=ids, y=y_neighbors, yvar=yvar_neighbors)
+
+
+ def _compute_mu_se(
+     enn: ENNLike,
+     neighbors: Neighbors,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     stats = enn._compute_weighted_stats(
+         neighbors.dist2,
+         neighbors.y,
+         yvar_neighbors=neighbors.yvar,
+         params=params,
+         observation_noise=flags.observation_noise,
+         y_scale=y_scale,
+     )
+     return stats.mu, stats.se
+
+
+ def _compute_draw_internals(
+     enn: ENNLike,
+     neighbors: Neighbors,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ) -> ConditionalPosteriorDrawInternals:
+     stats = enn._compute_weighted_stats(
+         neighbors.dist2,
+         neighbors.y,
+         yvar_neighbors=neighbors.yvar,
+         params=params,
+         observation_noise=flags.observation_noise,
+         y_scale=y_scale,
+     )
+     return ConditionalPosteriorDrawInternals(
+         idx=neighbors.ids.astype(int, copy=False),
+         w_normalized=stats.w_normalized,
+         l2=stats.l2,
+         mu=stats.mu,
+         se=stats.se,
+     )
+
+
+ def _conditional_neighbors_nonempty_whatif(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+ ) -> tuple[int, int, Neighbors | None]:
+     batch_size = x.shape[0]
+     search_k = _compute_search_k(
+         params, flags, _compute_total_n(enn, x_whatif.shape[0], flags)
+     )
+     if search_k == 0:
+         return batch_size, search_k, None
+     candidates = _build_candidates(enn, x, x_whatif, y_whatif, search_k=search_k)
+     neighbors = _select_effective_neighbors(
+         candidates,
+         search_k=search_k,
+         k=params.k_num_neighbors,
+         exclude_nearest=flags.exclude_nearest,
+     )
+     return batch_size, search_k, neighbors
+
+
+ def _compute_conditional_posterior_impl(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ):
+     from .enn_normal import ENNNormal
+
+     x = _validate_x(enn, x)
+     x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
+     if x_whatif.shape[0] == 0:
+         return enn.posterior(x, params=params, flags=flags)
+     batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
+         enn, x_whatif, y_whatif, x, params=params, flags=flags
+     )
+     if search_k == 0 or neighbors is None:
+         return _make_empty_normal(enn, batch_size)
+     mu, se = _compute_mu_se(enn, neighbors, params=params, flags=flags, y_scale=y_scale)
+     return ENNNormal(mu, se)
+
+
+ def compute_conditional_posterior(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ):
+     return _compute_conditional_posterior_impl(
+         enn, x_whatif, y_whatif, x, params=params, flags=flags, y_scale=y_scale
+     )
+
+
+ def compute_conditional_posterior_draw_internals(
+     enn: ENNLike,
+     x_whatif: np.ndarray,
+     y_whatif: np.ndarray,
+     x: np.ndarray,
+     *,
+     params: ENNParams,
+     flags: PosteriorFlags,
+     y_scale: np.ndarray,
+ ) -> ConditionalPosteriorDrawInternals:
+     x = _validate_x(enn, x)
+     x_whatif, y_whatif = _validate_whatif(enn, x_whatif, y_whatif)
+     if x_whatif.shape[0] == 0:
+         raise ValueError("x_whatif must be non-empty for conditional draw internals")
+     batch_size, search_k, neighbors = _conditional_neighbors_nonempty_whatif(
+         enn, x_whatif, y_whatif, x, params=params, flags=flags
+     )
+     if search_k == 0 or neighbors is None:
+         empty_internals = enn._empty_posterior_internals(batch_size)
+         return ConditionalPosteriorDrawInternals(
+             idx=empty_internals.idx,
+             w_normalized=empty_internals.w_normalized,
+             l2=empty_internals.l2,
+             mu=empty_internals.mu,
+             se=empty_internals.se,
+         )
+     return _compute_draw_internals(
+         enn, neighbors, params=params, flags=flags, y_scale=y_scale
+     )
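
The new module computes a posterior conditioned on hypothetical ("what-if") observations by appending them to the neighbor candidate set with zero observation variance before the k-nearest selection. Its distance helper relies on the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; below is a small self-contained check of that identity (not part of the package; names taken from the diff above, data made up):

import numpy as np

a = np.array([[0.0, 1.0], [2.0, 3.0]])
b = np.array([[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]])
# Expansion used by _pairwise_sq_l2: ||a||^2 + ||b||^2 - 2 a.b, clipped at 0
aa = np.sum(a * a, axis=1, keepdims=True)
bb = np.sum(b * b, axis=1, keepdims=True).T
fast = np.maximum(aa + bb - 2.0 * (a @ b.T), 0.0)
# Naive pairwise squared distances for comparison
naive = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=2)
assert np.allclose(fast, naive)  # (2, 3) matrix of squared L2 distances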
enn/enn/enn_fit.py CHANGED
@@ -1,15 +1,49 @@
  from __future__ import annotations
-
  from typing import TYPE_CHECKING, Any

  if TYPE_CHECKING:
      import numpy as np
      from numpy.random import Generator
-
-     from .enn import EpistemicNearestNeighbors
+     from .enn_class import EpistemicNearestNeighbors
      from .enn_params import ENNParams

- from .enn_util import standardize_y
+
+ def _validate_subsample_inputs(
+     x: np.ndarray | Any, y: np.ndarray | Any, P: int, paramss: list
+ ) -> tuple[np.ndarray, np.ndarray]:
+     import numpy as np
+
+     x_array = np.asarray(x, dtype=float)
+     if x_array.ndim != 2:
+         raise ValueError(x_array.shape)
+     y_array = np.asarray(y, dtype=float)
+     if y_array.ndim == 1:
+         y_array = y_array.reshape(-1, 1)
+     if y_array.ndim != 2:
+         raise ValueError(y_array.shape)
+     if x_array.shape[0] != y_array.shape[0]:
+         raise ValueError((x_array.shape, y_array.shape))
+     if P <= 0:
+         raise ValueError(P)
+     if len(paramss) == 0:
+         raise ValueError("paramss must be non-empty")
+     return x_array, y_array
+
+
+ def _compute_single_loglik(
+     y_scaled: np.ndarray, mu_i: np.ndarray, se_i: np.ndarray
+ ) -> float:
+     import numpy as np
+
+     if not np.isfinite(mu_i).all() or not np.isfinite(se_i).all():
+         return 0.0
+     if np.any(se_i <= 0.0):
+         return 0.0
+     var_scaled = se_i**2
+     loglik = -0.5 * np.sum(
+         np.log(2.0 * np.pi * var_scaled) + (y_scaled - mu_i) ** 2 / var_scaled
+     )
+     return float(loglik) if np.isfinite(loglik) else 0.0


  def subsample_loglik(
@@ -23,68 +57,37 @@ def subsample_loglik(
  ) -> list[float]:
      import numpy as np

-     if x.ndim != 2:
-         raise ValueError(x.shape)
-     if y.ndim != 1:
-         raise ValueError(y.shape)
-     if x.shape[0] != y.shape[0]:
-         raise ValueError((x.shape, y.shape))
-     if P <= 0:
-         raise ValueError(P)
-     if len(paramss) == 0:
-         raise ValueError("paramss must be non-empty")
-     n = x.shape[0]
-     if n == 0:
-         return [0.0] * len(paramss)
-     if len(model) <= 1:
+     x_array, y_array = _validate_subsample_inputs(x, y, P, paramss)
+     n = x_array.shape[0]
+     if n == 0 or len(model) <= 1:
          return [0.0] * len(paramss)
      P_actual = min(P, n)
-     if P_actual == n:
-         indices = np.arange(n, dtype=int)
-     else:
-         indices = rng.permutation(n)[:P_actual]
-     x_selected = x[indices]
-     y_selected = y[indices]
-     if not np.isfinite(y_selected).all():
+     indices = (
+         np.arange(n, dtype=int) if P_actual == n else rng.permutation(n)[:P_actual]
+     )
+     x_sel, y_sel = x_array[indices], y_array[indices]
+     if not np.isfinite(y_sel).all():
          return [0.0] * len(paramss)
-     post_batch = model.batch_posterior(
-         x_selected, paramss, exclude_nearest=True, observation_noise=True
+     from .enn_params import PosteriorFlags
+
+     post = model.batch_posterior(
+         x_sel,
+         paramss,
+         flags=PosteriorFlags(exclude_nearest=True, observation_noise=True),
      )
-     mu_batch = post_batch.mu
-     se_batch = post_batch.se
-     if mu_batch.shape[2] == 1:
-         mu_batch = mu_batch[:, :, 0]
-         se_batch = se_batch[:, :, 0]
-     num_params = len(paramss)
-     if mu_batch.shape != (num_params, P_actual) or se_batch.shape != (
-         num_params,
-         P_actual,
-     ):
-         raise ValueError((mu_batch.shape, se_batch.shape, (num_params, P_actual)))
-     _, y_std = standardize_y(y)
-     y_scaled = y_selected / y_std
-     mu_scaled = mu_batch / y_std
-     se_scaled = se_batch / y_std
-     result = []
-     for i in range(num_params):
-         mu_i = mu_scaled[i]
-         se_i = se_scaled[i]
-         if not np.isfinite(mu_i).all() or not np.isfinite(se_i).all():
-             result.append(0.0)
-             continue
-         if np.any(se_i <= 0.0):
-             result.append(0.0)
-             continue
-         diff = y_scaled - mu_i
-         var_scaled = se_i**2
-         log_term = np.log(2.0 * np.pi * var_scaled)
-         quad = diff**2 / var_scaled
-         loglik = -0.5 * np.sum(log_term + quad)
-         if not np.isfinite(loglik):
-             result.append(0.0)
-             continue
-         result.append(float(loglik))
-     return result
+     num_params, num_outputs = len(paramss), y_sel.shape[1]
+     expected_shape = (num_params, P_actual, num_outputs)
+     if post.mu.shape != expected_shape or post.se.shape != expected_shape:
+         raise ValueError((post.mu.shape, post.se.shape, expected_shape))
+     y_std = np.std(y_array, axis=0, keepdims=True).astype(float)
+     y_std = np.where(np.isfinite(y_std) & (y_std > 0.0), y_std, 1.0)
+     y_scaled = y_sel / y_std
+     mu_scaled = post.mu / y_std
+     se_scaled = post.se / y_std
+     return [
+         _compute_single_loglik(y_scaled, mu_scaled[i], se_scaled[i])
+         for i in range(num_params)
+     ]


  def enn_fit(
@@ -100,12 +103,6 @@ def enn_fit(

      train_x = model.train_x
      train_y = model.train_y
-     train_yvar = model.train_yvar
-     if train_y.shape[1] != 1:
-         raise ValueError(train_y.shape)
-     if train_yvar is not None and train_yvar.shape[1] != 1:
-         raise ValueError(train_yvar.shape)
-     y = train_y[:, 0]
      log_min = -3.0
      log_max = 3.0
      epi_var_scale_log_values = rng.uniform(log_min, log_max, size=num_fit_candidates)
@@ -116,26 +113,30 @@
      ale_homoscedastic_values = 10**ale_homoscedastic_log_values
      paramss = [
          ENNParams(
-             k=k,
-             epi_var_scale=float(epi_val),
-             ale_homoscedastic_scale=float(ale_val),
+             k_num_neighbors=k,
+             epistemic_variance_scale=float(epi_val),
+             aleatoric_variance_scale=float(ale_val),
          )
          for epi_val, ale_val in zip(epi_var_scale_values, ale_homoscedastic_values)
      ]
      if params_warm_start is not None:
          paramss.append(
              ENNParams(
-                 k=k,
-                 epi_var_scale=params_warm_start.epi_var_scale,
-                 ale_homoscedastic_scale=params_warm_start.ale_homoscedastic_scale,
+                 k_num_neighbors=k,
+                 epistemic_variance_scale=params_warm_start.epistemic_variance_scale,
+                 aleatoric_variance_scale=params_warm_start.aleatoric_variance_scale,
              )
          )
      if len(paramss) == 0:
-         return ENNParams(k=k, epi_var_scale=1.0, ale_homoscedastic_scale=0.0)
+         return ENNParams(
+             k_num_neighbors=k,
+             epistemic_variance_scale=1.0,
+             aleatoric_variance_scale=0.0,
+         )
      import numpy as np

      logliks = subsample_loglik(
-         model, train_x, y, paramss=paramss, P=num_fit_samples, rng=rng
+         model, train_x, train_y, paramss=paramss, P=num_fit_samples, rng=rng
      )
      if len(logliks) == 0:
          return paramss[0]
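
The refactor moves the per-candidate scoring into _compute_single_loglik, which sums the Gaussian log density -0.5 * sum(log(2*pi*se^2) + (y - mu)^2 / se^2) over the subsample. A quick self-contained check with made-up numbers (not from the package):

import numpy as np

y_scaled = np.array([[0.0], [1.0]])
mu_i = np.array([[0.0], [0.5]])
se_i = np.array([[1.0], [1.0]])
# Same formula as _compute_single_loglik in the diff above
manual = -0.5 * np.sum(np.log(2.0 * np.pi * se_i**2) + (y_scaled - mu_i) ** 2 / se_i**2)
print(manual)  # about -1.963 for these values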
enn/enn/enn_hash.py ADDED
@@ -0,0 +1,79 @@
+ from __future__ import annotations
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import numpy as np
+
+
+ def normal_hash_batch_multi_seed(
+     function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
+ ) -> np.ndarray:
+     import numpy as np
+     from scipy.special import ndtri
+
+     num_seeds = len(function_seeds)
+     unique_indices, inverse = np.unique(data_indices, return_inverse=True)
+     num_unique = len(unique_indices)
+     seed_grid, idx_grid, metric_grid = np.meshgrid(
+         function_seeds.astype(np.uint64),
+         unique_indices.astype(np.uint64),
+         np.arange(num_metrics, dtype=np.uint64),
+         indexing="ij",
+     )
+     seed_flat = seed_grid.ravel()
+     idx_flat = idx_grid.ravel()
+     metric_flat = metric_grid.ravel()
+     combined_seeds = (seed_flat * np.uint64(1_000_003) + idx_flat) * np.uint64(
+         1_000_003
+     ) + metric_flat
+     uniform_vals = np.empty(len(combined_seeds), dtype=float)
+     for i, seed in enumerate(combined_seeds):
+         rng = np.random.Generator(np.random.Philox(int(seed)))
+         uniform_vals[i] = rng.random()
+     uniform_vals = np.clip(uniform_vals, 1e-10, 1.0 - 1e-10)
+     normal_vals = ndtri(uniform_vals).reshape(num_seeds, num_unique, num_metrics)
+     return normal_vals[:, inverse.ravel(), :].reshape(
+         num_seeds, *data_indices.shape, num_metrics
+     )
+
+
+ def normal_hash_batch_multi_seed_fast(
+     function_seeds: np.ndarray, data_indices: np.ndarray, num_metrics: int
+ ) -> np.ndarray:
+     import numpy as np
+
+     function_seeds = np.asarray(function_seeds, dtype=np.int64)
+     data_indices = np.asarray(data_indices)
+     if num_metrics <= 0:
+         raise ValueError(num_metrics)
+     num_seeds = len(function_seeds)
+     unique_indices, inverse = np.unique(data_indices, return_inverse=True)
+
+     def _splitmix64(x: np.ndarray) -> np.ndarray:
+         with np.errstate(over="ignore"):
+             x = x + np.uint64(0x9E3779B97F4A7C15)
+             z = x
+             z = (z ^ (z >> np.uint64(30))) * np.uint64(0xBF58476D1CE4E5B9)
+             z = (z ^ (z >> np.uint64(27))) * np.uint64(0x94D049BB133111EB)
+             z = z ^ (z >> np.uint64(31))
+             return z
+
+     seeds_u64 = function_seeds.astype(np.uint64, copy=False)
+     unique_u64 = unique_indices.astype(np.uint64, copy=False)
+     metric_u64 = np.arange(num_metrics, dtype=np.uint64)
+     normal_vals = np.empty((num_seeds, unique_indices.size, num_metrics), dtype=float)
+     p = np.uint64(1_000_003)
+     inv_2p53 = 1.0 / 9007199254740992.0
+     for si, s in enumerate(seeds_u64):
+         with np.errstate(over="ignore"):
+             base = (s * p + unique_u64) * p
+             combined = base[:, None] + metric_u64[None, :]
+             r1 = _splitmix64(combined)
+             r2 = _splitmix64(combined ^ np.uint64(0xD2B74407B1CE6E93))
+             u1 = (r1 >> np.uint64(11)).astype(np.float64) * inv_2p53
+             u2 = (r2 >> np.uint64(11)).astype(np.float64) * inv_2p53
+             u1 = np.clip(u1, 1e-12, 1.0 - 1e-12)
+             normal_vals[si, :, :] = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2)
+     return normal_vals[:, inverse.ravel(), :].reshape(
+         num_seeds, *data_indices.shape, num_metrics
+     )
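
Both functions map each (function seed, data index, metric) triple to a reproducible standard-normal draw; the fast variant replaces the per-seed Philox generators with a vectorized SplitMix64 hash followed by a Box-Muller transform. A minimal usage sketch, assuming the wheel is installed and the module is importable as enn.enn.enn_hash (path inferred from the file list above):

import numpy as np
from enn.enn.enn_hash import normal_hash_batch_multi_seed_fast

seeds = np.array([0, 1])            # function seeds
idx = np.array([[3, 7], [7, 9]])    # data indices, any shape
z = normal_hash_batch_multi_seed_fast(seeds, idx, num_metrics=2)
print(z.shape)                      # (2, 2, 2, 2): (num_seeds, *idx.shape, num_metrics)
# The same (seed, index, metric) triple always yields the same draw:
assert z[1, 0, 1, 0] == z[1, 1, 0, 0]  # both positions hold data index 7 under seed 1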
enn/enn/enn_index.py ADDED
@@ -0,0 +1,92 @@
+ from __future__ import annotations
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     import numpy as np
+
+
+ class ENNIndex:
+     def __init__(
+         self,
+         train_x_scaled: np.ndarray,
+         num_dim: int,
+         x_scale: np.ndarray,
+         scale_x: bool,
+         driver: Any = None,
+     ) -> None:
+         from enn.turbo.config.enums import ENNIndexDriver
+
+         if driver is None:
+             driver = ENNIndexDriver.FLAT
+         self._train_x_scaled = train_x_scaled
+         self._num_dim = num_dim
+         self._x_scale = x_scale
+         self._scale_x = scale_x
+         self._driver = driver
+         self._index: Any | None = None
+         self._build_index()
+
+     def _build_index(self) -> None:
+         import faiss
+         import numpy as np
+         from enn.turbo.config.enums import ENNIndexDriver
+
+         if len(self._train_x_scaled) == 0:
+             return
+         x_f32 = self._train_x_scaled.astype(np.float32, copy=False)
+         if self._driver == ENNIndexDriver.FLAT:
+             index = faiss.IndexFlatL2(self._num_dim)
+         elif self._driver == ENNIndexDriver.HNSW:
+             # TODO: Make M configurable
+             index = faiss.IndexHNSWFlat(self._num_dim, 32)
+         else:
+             raise ValueError(f"Unknown driver: {self._driver}")
+         index.add(x_f32)
+         self._index = index
+
+     def add(self, x: np.ndarray) -> None:
+         import numpy as np
+         from enn.turbo.config.enums import ENNIndexDriver
+
+         x = np.asarray(x, dtype=float)
+         if x.ndim != 2 or x.shape[1] != self._num_dim:
+             raise ValueError(x.shape)
+         x_scaled = x / self._x_scale if self._scale_x else x
+         x_f32 = x_scaled.astype(np.float32, copy=False)
+         if self._index is None:
+             import faiss
+
+             if self._driver == ENNIndexDriver.FLAT:
+                 self._index = faiss.IndexFlatL2(self._num_dim)
+             elif self._driver == ENNIndexDriver.HNSW:
+                 self._index = faiss.IndexHNSWFlat(self._num_dim, 32)
+             else:
+                 raise ValueError(f"Unknown driver: {self._driver}")
+         self._index.add(x_f32)
+
+     def search(
+         self,
+         x: np.ndarray,
+         *,
+         search_k: int,
+         exclude_nearest: bool,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         import numpy as np
+
+         search_k = int(search_k)
+         if search_k <= 0:
+             raise ValueError(search_k)
+         x = np.asarray(x, dtype=float)
+         if x.ndim != 2 or x.shape[1] != self._num_dim:
+             raise ValueError(x.shape)
+         if self._index is None:
+             raise RuntimeError("index is not initialized")
+         x_scaled = x / self._x_scale if self._scale_x else x
+         x_f32 = x_scaled.astype(np.float32, copy=False)
+         dist2s_full, idx_full = self._index.search(x_f32, search_k)
+         dist2s_full = dist2s_full.astype(float)
+         idx_full = idx_full.astype(int)
+         if exclude_nearest:
+             dist2s_full = dist2s_full[:, 1:]
+             idx_full = idx_full[:, 1:]
+         return dist2s_full, idx_full
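
ENNIndex wraps a faiss flat or HNSW index over the (optionally rescaled) training inputs and returns squared L2 distances plus train indices. A minimal usage sketch, assuming faiss is installed and the module is importable as enn.enn.enn_index (path and constructor arguments inferred from the diff above):

import numpy as np
from enn.enn.enn_index import ENNIndex

rng = np.random.default_rng(0)
x_train = rng.random((100, 3))
# FLAT (exact) driver is the default when driver is None
index = ENNIndex(x_train, num_dim=3, x_scale=np.ones(3), scale_x=False)
dist2, ids = index.search(rng.random((5, 3)), search_k=4, exclude_nearest=False)
print(dist2.shape, ids.shape)  # (5, 4) squared distances and neighbor indices
index.add(rng.random((10, 3)))  # append more points to the same index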