ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
enn/enn/enn_class.py ADDED
@@ -0,0 +1,423 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING
3
+ import numpy as np
4
+ from .draw_internals import DrawInternals
5
+ from .neighbor_data import NeighborData
6
+ from .weighted_stats import WeightedStats
7
+ from enn.turbo.config.enums import ENNIndexDriver
8
+
9
+ if TYPE_CHECKING:
10
+ from .enn_normal import ENNNormal
11
+ from .enn_params import ENNParams, PosteriorFlags
12
+
13
+
14
def _compute_conditional_y_scale(
    model: EpistemicNearestNeighbors, y_whatif: np.ndarray
):
    """Compute the target scale as if ``y_whatif`` had also been observed.

    The hypothetical targets (coerced to float) are stacked under
    ``model.train_y`` along the row axis. The model's own ``_compute_scale``
    is then applied with a floor of ``0.0``. Callers pass the result as
    ``y_scale`` to the conditional-posterior routines.
    """
    hypothetical_y = np.asarray(y_whatif, dtype=float)
    augmented_y = np.concatenate((model.train_y, hypothetical_y), axis=0)
    return model._compute_scale(augmented_y, 0.0)
22
+
23
+
24
def _draw_from_internals(
    model: EpistemicNearestNeighbors,
    internals: DrawInternals,
    *,
    function_seeds: np.ndarray | list[int],
) -> np.ndarray:
    """Turn precomputed posterior internals into seeded function draws.

    For every seed, hashed normal noise is generated per neighbour index and
    combined with the normalised neighbour weights. The weighted sum is
    divided by the L2 norm of those weights (floored at 1e-12) so it has unit
    scale. It is then multiplied by the posterior SE and added to the mean.

    When there are no neighbours (``k == 0``), every draw is simply a copy of
    the posterior mean with shape ``(len(function_seeds), n, num_outputs)``.
    """
    from .enn_hash import normal_hash_batch_multi_seed_fast

    seeds = np.asarray(function_seeds, dtype=np.int64)
    num_points = internals.idx.shape[0]
    num_neighbors = internals.idx.shape[1]
    num_outputs = model.num_outputs

    if num_neighbors == 0:
        # broadcast_to returns a read-only view; copy so callers get a real array.
        mean_only = np.broadcast_to(
            internals.mu, (len(seeds), num_points, num_outputs)
        )
        return mean_only.copy()

    noise = normal_hash_batch_multi_seed_fast(seeds, internals.idx, num_outputs)
    weights = internals.w_normalized[np.newaxis, :, :, :]
    combined_noise = np.sum(weights * noise, axis=2)
    safe_l2 = np.maximum(internals.l2, 1e-12)
    scaled_noise = (
        internals.se[np.newaxis, :, :] * combined_noise / safe_l2[np.newaxis, :, :]
    )
    return internals.mu[np.newaxis, :, :] + scaled_noise
43
+
44
+
45
class EpistemicNearestNeighbors:
    """Epistemic nearest-neighbour (ENN) regression model.

    The posterior at a query point is an inverse-variance weighted average of
    the targets of its nearest training points. Each neighbour is given the
    variance ``epistemic_variance_scale * dist2 + aleatoric_variance_scale``
    (plus its own ``train_yvar / y_scale**2`` when noise variances were
    supplied), so neighbours with smaller ``dist2`` receive more weight.
    Neighbour lookup is delegated to an ``ENNIndex`` built over the
    (optionally scaled) training inputs.
    """

    # Variance floor: keeps the inverse-variance weights finite when the
    # distance and aleatoric terms are both zero, and keeps the reported SE
    # strictly positive.
    _EPS_VAR = 1e-9

    @staticmethod
    def _validate_inputs(train_x, train_y, train_yvar):
        """Coerce inputs to float arrays and check their shapes.

        ``train_x`` and ``train_y`` must be 2-D with the same number of rows.
        ``train_yvar``, when given, must be 2-D with exactly ``train_y``'s
        shape.

        Returns:
            ``(train_x, train_y, train_yvar)`` as float arrays; ``train_yvar``
            stays ``None`` if it was ``None``.

        Raises:
            ValueError: If any shape check fails.
        """
        train_x, train_y = (
            np.asarray(train_x, dtype=float),
            np.asarray(train_y, dtype=float),
        )
        if (
            train_x.ndim != 2
            or train_y.ndim != 2
            or train_x.shape[0] != train_y.shape[0]
        ):
            raise ValueError((train_x.shape, train_y.shape))
        if train_yvar is not None:
            train_yvar = np.asarray(train_yvar, dtype=float)
            if train_yvar.ndim != 2 or train_y.shape != train_yvar.shape:
                raise ValueError((train_y.shape, train_yvar.shape))
        return train_x, train_y, train_yvar

    @staticmethod
    def _compute_scale(data, min_val=0.0):
        """Return per-column standard deviations of ``data`` as shape ``(1, ncols)``.

        A column whose std is non-finite or ``<= min_val`` gets a scale of
        1.0. Every column gets 1.0 when ``data`` has fewer than two rows.
        """
        if len(data) < 2:
            return np.ones((1, data.shape[1]), dtype=float)
        scale = np.std(data, axis=0, keepdims=True).astype(float)
        return np.where(np.isfinite(scale) & (scale > min_val), scale, 1.0)

    def __init__(
        self,
        train_x: np.ndarray,
        train_y: np.ndarray,
        train_yvar: np.ndarray | None = None,
        *,
        scale_x: bool = False,
        index_driver: ENNIndexDriver = ENNIndexDriver.FLAT,
    ) -> None:
        """Store the training data and build the neighbour index.

        Args:
            train_x: Training inputs, shape ``(n, num_dim)``.
            train_y: Training targets, shape ``(n, num_outputs)``.
            train_yvar: Optional observation-noise variances, same shape as
                ``train_y``.
            scale_x: If True, inputs are divided by their per-dimension std
                before they are handed to the neighbour index.
            index_driver: Backend selector forwarded to ``ENNIndex``.

        Raises:
            ValueError: If the input shapes are inconsistent.
        """
        self._train_x, self._train_y, self._train_yvar = self._validate_inputs(
            train_x, train_y, train_yvar
        )
        self._num_obs, self._num_dim = self._train_x.shape
        _, self._num_metrics = self._train_y.shape
        self._scale_x = bool(scale_x)
        self._x_scale = (
            self._compute_scale(self._train_x, 1e-12)
            if scale_x
            else np.ones((1, self._num_dim), dtype=float)
        )
        self._train_x_scaled = (
            self._train_x / self._x_scale if scale_x else self._train_x
        )
        # Target scale is computed once here; add() does not recompute it.
        self._y_scale = self._compute_scale(self._train_y, 0.0)
        from .enn_index import ENNIndex

        self._enn_index = ENNIndex(
            self._train_x_scaled,
            self._num_dim,
            self._x_scale,
            self._scale_x,
            driver=index_driver,
        )

    def add(
        self,
        x: np.ndarray,
        y: np.ndarray,
        yvar: np.ndarray | None = None,
    ) -> None:
        """Append observations to the model and to its neighbour index.

        All validation happens before any state is modified, so a rejected
        call leaves the model unchanged.

        Args:
            x: New inputs, shape ``(n_new, num_dim)``.
            y: New targets, shape ``(n_new, num_outputs)``.
            yvar: Optional noise variances, same shape as ``y``. Required when
                the model already stores ``train_yvar``.

        Raises:
            ValueError: On shape mismatches, or when ``yvar`` is omitted for a
                model that already has noise variances.

        Note:
            The target scale computed in ``__init__`` is not recomputed.
        """
        x, y, yvar = self._validate_inputs(x, y, yvar)
        # Check column counts up front; previously np.concatenate caught this
        # only after train_x had already been extended.
        if x.shape[1] != self._num_dim or y.shape[1] != self._num_metrics:
            raise ValueError(
                (x.shape, y.shape, (self._num_dim, self._num_metrics))
            )

        if yvar is None:
            if self._train_yvar is not None:
                raise ValueError("yvar must be provided if model has existing yvar")
            new_yvar = None
        elif self._train_yvar is None:
            # Earlier observations had no reported noise. Give them zero
            # variance so train_yvar stays row-aligned with train_y; with a zero
            # yvar their var_ale reduces to aleatoric_variance_scale, matching
            # how they were treated when train_yvar was None.
            previous = np.zeros(
                (self._train_y.shape[0], self._num_metrics), dtype=float
            )
            new_yvar = np.concatenate([previous, yvar], axis=0)
        else:
            new_yvar = np.concatenate([self._train_yvar, yvar], axis=0)

        self._train_x = np.concatenate([self._train_x, x], axis=0)
        self._train_y = np.concatenate([self._train_y, y], axis=0)
        self._train_yvar = new_yvar
        # Keep the scaled copy in sync with train_x (it was left stale before).
        self._train_x_scaled = (
            self._train_x / self._x_scale if self._scale_x else self._train_x
        )
        self._num_obs = self._train_x.shape[0]
        self._enn_index.add(x)

    @property
    def train_x(self) -> np.ndarray:
        """Training inputs (unscaled), shape ``(n, num_dim)``."""
        return self._train_x

    @property
    def train_y(self) -> np.ndarray:
        """Training targets, shape ``(n, num_outputs)``."""
        return self._train_y

    @property
    def train_yvar(self) -> np.ndarray | None:
        """Observation-noise variances aligned with ``train_y``, or ``None``."""
        return self._train_yvar

    @property
    def num_outputs(self) -> int:
        """Number of target columns."""
        return self._num_metrics

    def __len__(self) -> int:
        """Number of stored observations."""
        return self._num_obs

    def posterior(
        self,
        x: np.ndarray,
        *,
        params: ENNParams,
        flags: PosteriorFlags | None = None,
    ) -> ENNNormal:
        """Return the posterior mean and SE at ``x`` for a single ``params``.

        Thin wrapper around :meth:`batch_posterior` with a one-element list.
        """
        from .enn_normal import ENNNormal
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        post_batch = self.batch_posterior(x, [params], flags=flags)
        return ENNNormal(post_batch.mu[0], post_batch.se[0])

    def _empty_posterior_internals(self, batch_size: int) -> DrawInternals:
        """Fallback internals when no neighbour contributes: mu=0, se=1, k=0."""
        m = self._num_metrics
        return DrawInternals(
            idx=np.zeros((batch_size, 0), dtype=int),
            w_normalized=np.zeros((batch_size, 0, m), dtype=float),
            l2=np.ones((batch_size, m), dtype=float),
            mu=np.zeros((batch_size, m), dtype=float),
            se=np.ones((batch_size, m), dtype=float),
        )

    def _get_neighbor_data(
        self, x: np.ndarray, params: ENNParams, exclude_nearest: bool
    ) -> NeighborData | None:
        """Look up up to ``params.k_num_neighbors`` neighbours for each row of ``x``.

        With ``exclude_nearest`` one extra neighbour is requested and the
        index is asked to exclude the nearest one. Presumably this lets a query
        that is itself a training point ignore itself; TODO confirm against
        ``ENNIndex.search``.

        Returns:
            ``NeighborData`` holding the first ``k`` neighbours, or ``None``
            when the effective ``k`` is 0.

        Raises:
            ValueError: If ``exclude_nearest`` is set with at most one
                observation.
            RuntimeError: If the index returned fewer columns than ``k``.
        """
        if exclude_nearest:
            if len(self) <= 1:
                raise ValueError(len(self))
            search_k = int(min(params.k_num_neighbors + 1, len(self)))
        else:
            search_k = int(min(params.k_num_neighbors, len(self)))
        dist2s_full, idx_full = self._enn_index.search(
            x, search_k=search_k, exclude_nearest=exclude_nearest
        )
        available_k = search_k - 1 if exclude_nearest else search_k
        k = min(params.k_num_neighbors, available_k)
        if k > dist2s_full.shape[1]:
            raise RuntimeError(
                f"k={k} exceeds available columns={dist2s_full.shape[1]}"
            )
        if k == 0:
            return None
        return NeighborData(
            dist2s=dist2s_full[:, :k],
            idx=idx_full[:, :k],
            y_neighbors=self._train_y[idx_full[:, :k]],
            k=k,
        )

    def _compute_weighted_posterior(
        self,
        dist2s: np.ndarray,
        idx: np.ndarray,
        y_neighbors: np.ndarray,
        params: ENNParams,
        observation_noise: bool,
    ) -> DrawInternals:
        """Compute weighted statistics for given neighbours and pack them as ``DrawInternals``.

        Neighbour noise variances are gathered from ``train_yvar`` when the
        model has them.
        """
        yvar_neighbors = None
        if self._train_yvar is not None:
            yvar_neighbors = self._train_yvar[idx]
        stats = self._compute_weighted_stats(
            dist2s,
            y_neighbors,
            yvar_neighbors=yvar_neighbors,
            params=params,
            observation_noise=observation_noise,
        )
        return DrawInternals(
            idx=idx,
            w_normalized=stats.w_normalized,
            l2=stats.l2,
            mu=stats.mu,
            se=stats.se,
        )

    def _compute_weighted_stats(
        self,
        dist2s: np.ndarray,
        y_neighbors: np.ndarray,
        *,
        yvar_neighbors: np.ndarray | None,
        params: ENNParams,
        observation_noise: bool,
        y_scale: np.ndarray | None = None,
    ) -> WeightedStats:
        """Inverse-variance weighted mean and standard error over neighbours.

        Args:
            dist2s: Neighbour distance values from the index, shape
                ``(batch, k)``; the name suggests squared distances.
            y_neighbors: Neighbour targets, shape ``(batch, k, num_outputs)``.
            yvar_neighbors: Optional neighbour noise variances (same shape as
                ``y_neighbors``), normalised here by ``y_scale**2``.
            params: Supplies ``epistemic_variance_scale`` and
                ``aleatoric_variance_scale``.
            observation_noise: If True, the weighted aleatoric variance is
                added to the reported variance.
            y_scale: Target scale; defaults to the one computed in ``__init__``.

        Returns:
            ``WeightedStats`` with the normalised weights, their L2 norm over
            neighbours, the weighted mean (in target units) and the SE
            (rescaled by ``y_scale``).
        """
        if y_scale is None:
            y_scale = self._y_scale
        dist2s_expanded = dist2s[..., np.newaxis]
        var_epi = params.epistemic_variance_scale * dist2s_expanded
        var_ale = params.aleatoric_variance_scale
        if yvar_neighbors is not None:
            var_ale = var_ale + yvar_neighbors / y_scale**2
        w = 1.0 / (self._EPS_VAR + var_epi + var_ale)
        norm = np.sum(w, axis=1, keepdims=True)
        w_normalized = w / norm
        l2 = np.sqrt(np.sum(w_normalized**2, axis=1))
        mu = np.sum(w_normalized * y_neighbors, axis=1)
        epistemic_var = 1.0 / norm.squeeze(axis=1)
        if observation_noise:
            if np.isscalar(var_ale):
                aleatoric_var = np.full_like(epistemic_var, var_ale)
            else:
                aleatoric_var = np.sum(w_normalized * var_ale, axis=1)
        else:
            aleatoric_var = 0.0
        se = np.sqrt(np.maximum(epistemic_var + aleatoric_var, self._EPS_VAR)) * y_scale
        return WeightedStats(w_normalized=w_normalized, l2=l2, mu=mu, se=se)

    def conditional_posterior(
        self,
        x_whatif: np.ndarray,
        y_whatif: np.ndarray,
        x: np.ndarray,
        *,
        params: ENNParams,
        flags: PosteriorFlags | None = None,
    ) -> ENNNormal:
        """Posterior at ``x`` as if ``(x_whatif, y_whatif)`` had been observed.

        The target scale is recomputed over ``train_y`` plus ``y_whatif``. The
        computation is delegated to ``enn_conditional.compute_conditional_posterior``.
        """
        from .enn_conditional import compute_conditional_posterior
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        y_scale = _compute_conditional_y_scale(self, y_whatif)
        return compute_conditional_posterior(
            self, x_whatif, y_whatif, x, params=params, flags=flags, y_scale=y_scale
        )

    def _compute_posterior_internals(
        self,
        x: np.ndarray,
        params: ENNParams,
        flags: PosteriorFlags,
    ) -> DrawInternals:
        """Validate ``x`` and compute posterior internals for one parameter set.

        Falls back to ``_empty_posterior_internals`` when the model is empty or
        no neighbour is selected.

        Raises:
            ValueError: If ``x`` is not 2-D with ``num_dim`` columns.
        """
        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        batch_size = x.shape[0]
        if len(self) == 0:
            return self._empty_posterior_internals(batch_size)
        neighbor_data = self._get_neighbor_data(x, params, flags.exclude_nearest)
        if neighbor_data is None:
            return self._empty_posterior_internals(batch_size)
        return self._compute_weighted_posterior(
            neighbor_data.dist2s,
            neighbor_data.idx,
            neighbor_data.y_neighbors,
            params,
            flags.observation_noise,
        )

    def batch_posterior(
        self,
        x: np.ndarray,
        paramss: list[ENNParams],
        *,
        flags: PosteriorFlags | None = None,
    ) -> ENNNormal:
        """Evaluate the posterior at ``x`` for several parameter sets.

        Returns:
            ``ENNNormal`` whose ``mu`` and ``se`` have shape
            ``(len(paramss), batch, num_outputs)``.

        Raises:
            ValueError: If ``x`` has the wrong shape or ``paramss`` is empty.
        """
        from .enn_normal import ENNNormal
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.shape[1] != self._num_dim:
            raise ValueError(x.shape)
        if not paramss:
            raise ValueError("paramss must be non-empty")
        batch_size, num_params = x.shape[0], len(paramss)
        mu_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
        se_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
        k_values = {p.k_num_neighbors for p in paramss}
        if len(k_values) == 1 and len(self) > 0:
            # All parameter sets share k: run the neighbour search once.
            neighbor_data = self._get_neighbor_data(
                x, paramss[0], flags.exclude_nearest
            )
            if neighbor_data is None:
                # k == 0: no neighbour contributes. Use the same fallback
                # (mu=0, se=1) as _compute_posterior_internals instead of
                # se=0, which would report zero uncertainty.
                empty = self._empty_posterior_internals(batch_size)
                mu_all[:] = empty.mu
                se_all[:] = empty.se
                return ENNNormal(mu_all, se_all)
            for i, params in enumerate(paramss):
                internals = self._compute_weighted_posterior(
                    neighbor_data.dist2s,
                    neighbor_data.idx,
                    neighbor_data.y_neighbors,
                    params,
                    flags.observation_noise,
                )
                mu_all[i], se_all[i] = internals.mu, internals.se
        else:
            for i, params in enumerate(paramss):
                internals = self._compute_posterior_internals(x, params, flags)
                mu_all[i], se_all[i] = internals.mu, internals.se
        return ENNNormal(mu_all, se_all)

    def neighbors(
        self, x: np.ndarray, k: int, *, exclude_nearest: bool = False
    ) -> list[tuple[np.ndarray, np.ndarray]]:
        """Return copies of ``(train_x[i], train_y[i])`` for the neighbours of one point.

        Args:
            x: A single query point, shape ``(num_dim,)`` or ``(1, num_dim)``.
            k: Maximum number of neighbours to return.
            exclude_nearest: Forwarded to the index search; requires at least
                two observations.

        Raises:
            ValueError: For a malformed ``x``, negative ``k``, or
                ``exclude_nearest`` with fewer than two observations.
        """
        x = np.asarray(x, dtype=float)
        if x.ndim == 1:
            x = x[np.newaxis, :]
        if x.ndim != 2 or x.shape[0] != 1 or x.shape[1] != self._num_dim:
            raise ValueError(
                f"x must be single point with {self._num_dim} dims, got {x.shape}"
            )
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        if len(self) == 0:
            return []
        if exclude_nearest and len(self) <= 1:
            raise ValueError(
                f"exclude_nearest=True requires at least 2 observations, got {len(self)}"
            )
        search_k = int(min(k + 1 if exclude_nearest else k, len(self)))
        if search_k == 0:
            return []
        _, idx_full = self._enn_index.search(
            x, search_k=search_k, exclude_nearest=exclude_nearest
        )
        idx = idx_full[0, : min(k, len(idx_full[0]))]
        return [(self._train_x[i].copy(), self._train_y[i].copy()) for i in idx]

    def posterior_function_draw(
        self,
        x: np.ndarray,
        params: ENNParams,
        *,
        function_seeds: np.ndarray | list[int],
        flags: PosteriorFlags | None = None,
    ) -> np.ndarray:
        """Draw one seeded posterior function sample at ``x`` per seed.

        With no neighbours the result has shape
        ``(len(function_seeds), batch, num_outputs)`` and equals the mean.
        """
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        internals = self._compute_posterior_internals(x, params, flags)
        return _draw_from_internals(self, internals, function_seeds=function_seeds)

    def conditional_posterior_function_draw(
        self,
        x_whatif: np.ndarray,
        y_whatif: np.ndarray,
        x: np.ndarray,
        *,
        params: ENNParams,
        function_seeds: np.ndarray | list[int],
        flags: PosteriorFlags | None = None,
    ) -> np.ndarray:
        """Seeded function draws at ``x`` conditioned on ``(x_whatif, y_whatif)``.

        An empty ``x_whatif`` falls back to :meth:`posterior_function_draw`.

        Raises:
            ValueError: If ``x_whatif`` is not 2-D with ``num_dim`` columns.
        """
        from .enn_conditional import compute_conditional_posterior_draw_internals
        from .enn_params import PosteriorFlags

        if flags is None:
            flags = PosteriorFlags()
        x_whatif = np.asarray(x_whatif, dtype=float)
        if x_whatif.ndim != 2 or x_whatif.shape[1] != self._num_dim:
            raise ValueError(x_whatif.shape)
        if x_whatif.shape[0] == 0:
            return self.posterior_function_draw(
                x,
                params,
                function_seeds=function_seeds,
                flags=flags,
            )
        y_scale = _compute_conditional_y_scale(self, y_whatif)
        internals = compute_conditional_posterior_draw_internals(
            self, x_whatif, y_whatif, x, params=params, flags=flags, y_scale=y_scale
        )
        # Re-pack into a plain DrawInternals for the shared draw routine.
        return _draw_from_internals(
            self,
            DrawInternals(
                idx=internals.idx,
                w_normalized=internals.w_normalized,
                l2=internals.l2,
                mu=internals.mu,
                se=internals.se,
            ),
            function_seeds=function_seeds,
        )