ennbo 0.1.2__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. enn/__init__.py +25 -13
  2. enn/benchmarks/__init__.py +3 -0
  3. enn/benchmarks/ackley.py +5 -0
  4. enn/benchmarks/ackley_class.py +17 -0
  5. enn/benchmarks/ackley_core.py +12 -0
  6. enn/benchmarks/double_ackley.py +24 -0
  7. enn/enn/candidates.py +14 -0
  8. enn/enn/conditional_posterior_draw_internals.py +15 -0
  9. enn/enn/draw_internals.py +15 -0
  10. enn/enn/enn.py +16 -269
  11. enn/enn/enn_class.py +423 -0
  12. enn/enn/enn_conditional.py +325 -0
  13. enn/enn/enn_fit.py +69 -70
  14. enn/enn/enn_hash.py +79 -0
  15. enn/enn/enn_index.py +92 -0
  16. enn/enn/enn_like_protocol.py +35 -0
  17. enn/enn/enn_normal.py +0 -1
  18. enn/enn/enn_params.py +3 -22
  19. enn/enn/enn_params_class.py +24 -0
  20. enn/enn/enn_util.py +60 -46
  21. enn/enn/neighbor_data.py +14 -0
  22. enn/enn/neighbors.py +14 -0
  23. enn/enn/posterior_flags.py +8 -0
  24. enn/enn/weighted_stats.py +14 -0
  25. enn/turbo/components/__init__.py +41 -0
  26. enn/turbo/components/acquisition.py +13 -0
  27. enn/turbo/components/acquisition_optimizer_protocol.py +19 -0
  28. enn/turbo/components/builder.py +22 -0
  29. enn/turbo/components/chebyshev_incumbent_selector.py +76 -0
  30. enn/turbo/components/enn_surrogate.py +115 -0
  31. enn/turbo/components/gp_surrogate.py +144 -0
  32. enn/turbo/components/hnr_acq_optimizer.py +83 -0
  33. enn/turbo/components/incumbent_selector.py +11 -0
  34. enn/turbo/components/incumbent_selector_protocol.py +16 -0
  35. enn/turbo/components/no_incumbent_selector.py +21 -0
  36. enn/turbo/components/no_surrogate.py +49 -0
  37. enn/turbo/components/pareto_acq_optimizer.py +49 -0
  38. enn/turbo/components/posterior_result.py +12 -0
  39. enn/turbo/components/protocols.py +13 -0
  40. enn/turbo/components/random_acq_optimizer.py +21 -0
  41. enn/turbo/components/scalar_incumbent_selector.py +39 -0
  42. enn/turbo/components/surrogate_protocol.py +32 -0
  43. enn/turbo/components/surrogate_result.py +12 -0
  44. enn/turbo/components/surrogates.py +5 -0
  45. enn/turbo/components/thompson_acq_optimizer.py +49 -0
  46. enn/turbo/components/trust_region_protocol.py +24 -0
  47. enn/turbo/components/ucb_acq_optimizer.py +49 -0
  48. enn/turbo/config/__init__.py +87 -0
  49. enn/turbo/config/acq_type.py +8 -0
  50. enn/turbo/config/acquisition.py +26 -0
  51. enn/turbo/config/base.py +4 -0
  52. enn/turbo/config/candidate_gen_config.py +49 -0
  53. enn/turbo/config/candidate_rv.py +7 -0
  54. enn/turbo/config/draw_acquisition_config.py +14 -0
  55. enn/turbo/config/enn_index_driver.py +6 -0
  56. enn/turbo/config/enn_surrogate_config.py +42 -0
  57. enn/turbo/config/enums.py +7 -0
  58. enn/turbo/config/factory.py +118 -0
  59. enn/turbo/config/gp_surrogate_config.py +14 -0
  60. enn/turbo/config/hnr_optimizer_config.py +7 -0
  61. enn/turbo/config/init_config.py +17 -0
  62. enn/turbo/config/init_strategies/__init__.py +9 -0
  63. enn/turbo/config/init_strategies/hybrid_init.py +23 -0
  64. enn/turbo/config/init_strategies/init_strategy.py +19 -0
  65. enn/turbo/config/init_strategies/lhd_only_init.py +24 -0
  66. enn/turbo/config/morbo_tr_config.py +82 -0
  67. enn/turbo/config/nds_optimizer_config.py +7 -0
  68. enn/turbo/config/no_surrogate_config.py +14 -0
  69. enn/turbo/config/no_tr_config.py +31 -0
  70. enn/turbo/config/optimizer_config.py +72 -0
  71. enn/turbo/config/pareto_acquisition_config.py +14 -0
  72. enn/turbo/config/raasp_driver.py +6 -0
  73. enn/turbo/config/raasp_optimizer_config.py +7 -0
  74. enn/turbo/config/random_acquisition_config.py +14 -0
  75. enn/turbo/config/rescalarize.py +7 -0
  76. enn/turbo/config/surrogate.py +12 -0
  77. enn/turbo/config/trust_region.py +34 -0
  78. enn/turbo/config/turbo_tr_config.py +71 -0
  79. enn/turbo/config/ucb_acquisition_config.py +14 -0
  80. enn/turbo/config/validation.py +45 -0
  81. enn/turbo/hypervolume.py +30 -0
  82. enn/turbo/impl_helpers.py +68 -0
  83. enn/turbo/morbo_trust_region.py +131 -70
  84. enn/turbo/no_trust_region.py +32 -39
  85. enn/turbo/optimizer.py +300 -0
  86. enn/turbo/optimizer_config.py +8 -0
  87. enn/turbo/proposal.py +36 -38
  88. enn/turbo/sampling.py +21 -0
  89. enn/turbo/strategies/__init__.py +9 -0
  90. enn/turbo/strategies/lhd_only_strategy.py +36 -0
  91. enn/turbo/strategies/optimization_strategy.py +19 -0
  92. enn/turbo/strategies/turbo_hybrid_strategy.py +124 -0
  93. enn/turbo/tr_helpers.py +202 -0
  94. enn/turbo/turbo_gp.py +0 -1
  95. enn/turbo/turbo_gp_base.py +0 -1
  96. enn/turbo/turbo_gp_fit.py +187 -0
  97. enn/turbo/turbo_gp_noisy.py +0 -1
  98. enn/turbo/turbo_optimizer_utils.py +98 -0
  99. enn/turbo/turbo_trust_region.py +126 -58
  100. enn/turbo/turbo_utils.py +98 -161
  101. enn/turbo/types/__init__.py +7 -0
  102. enn/turbo/types/appendable_array.py +85 -0
  103. enn/turbo/types/gp_data_prep.py +13 -0
  104. enn/turbo/types/gp_fit_result.py +11 -0
  105. enn/turbo/types/obs_lists.py +10 -0
  106. enn/turbo/types/prepare_ask_result.py +14 -0
  107. enn/turbo/types/tell_inputs.py +14 -0
  108. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/METADATA +18 -11
  109. ennbo-0.1.7.dist-info/RECORD +111 -0
  110. enn/enn/__init__.py +0 -4
  111. enn/turbo/__init__.py +0 -11
  112. enn/turbo/base_turbo_impl.py +0 -144
  113. enn/turbo/lhd_only_impl.py +0 -49
  114. enn/turbo/turbo_config.py +0 -72
  115. enn/turbo/turbo_enn_impl.py +0 -201
  116. enn/turbo/turbo_mode.py +0 -10
  117. enn/turbo/turbo_mode_impl.py +0 -76
  118. enn/turbo/turbo_one_impl.py +0 -302
  119. enn/turbo/turbo_optimizer.py +0 -525
  120. enn/turbo/turbo_zero_impl.py +0 -29
  121. ennbo-0.1.2.dist-info/RECORD +0 -29
  122. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/WHEEL +0 -0
  123. {ennbo-0.1.2.dist-info → ennbo-0.1.7.dist-info}/licenses/LICENSE +0 -0
enn/__init__.py CHANGED
@@ -1,24 +1,36 @@
1
1
  from __future__ import annotations
2
-
3
- from .enn import EpistemicNearestNeighbors, enn_fit
4
-
5
- _LAZY_IMPORTS = ("TurboMode", "TurboOptimizer", "Turbo", "Telemetry")
6
-
7
-
8
- def _lazy_load(name: str):
9
- from . import turbo
10
-
11
- return getattr(turbo, name)
2
+ import importlib
3
+ from .enn.enn_class import EpistemicNearestNeighbors
4
+ from .enn.enn_fit import enn_fit
5
+
6
+ _LAZY_ATTRS: dict[str, tuple[str, str]] = {
7
+ "create_optimizer": (".turbo.optimizer", "create_optimizer"),
8
+ "Telemetry": (".turbo.turbo_utils", "Telemetry"),
9
+ "OptimizerConfig": (".turbo.optimizer_config", "OptimizerConfig"),
10
+ "turbo_one_config": (".turbo.optimizer_config", "turbo_one_config"),
11
+ "turbo_zero_config": (".turbo.optimizer_config", "turbo_zero_config"),
12
+ "turbo_enn_config": (".turbo.optimizer_config", "turbo_enn_config"),
13
+ "lhd_only_config": (".turbo.optimizer_config", "lhd_only_config"),
14
+ "TurboTRConfig": (".turbo.config.trust_region", "TurboTRConfig"),
15
+ "MorboTRConfig": (".turbo.config.trust_region", "MorboTRConfig"),
16
+ "NoTRConfig": (".turbo.config.trust_region", "NoTRConfig"),
17
+ "CandidateRV": (".turbo.optimizer_config", "CandidateRV"),
18
+ "InitStrategy": (".turbo.optimizer_config", "InitStrategy"),
19
+ "AcqType": (".turbo.optimizer_config", "AcqType"),
20
+ }
12
21
 
13
22
 
14
23
  def __getattr__(name: str):
15
- if name in _LAZY_IMPORTS:
16
- return _lazy_load(name)
24
+ spec = _LAZY_ATTRS.get(name)
25
+ if spec is not None:
26
+ module_name, attr_name = spec
27
+ module = importlib.import_module(module_name, __package__)
28
+ return getattr(module, attr_name)
17
29
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
18
30
 
19
31
 
20
32
  __all__: list[str] = [
21
33
  "EpistemicNearestNeighbors",
22
34
  "enn_fit",
23
- *_LAZY_IMPORTS,
35
+ *_LAZY_ATTRS.keys(),
24
36
  ]
enn/benchmarks/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .ackley import Ackley, DoubleAckley
2
+
3
+ __all__ = ["Ackley", "DoubleAckley"]
enn/benchmarks/ackley.py ADDED
@@ -0,0 +1,5 @@
1
+ from .ackley_class import Ackley
2
+ from .ackley_core import ackley_core
3
+ from .double_ackley import DoubleAckley
4
+
5
+ __all__ = ["Ackley", "DoubleAckley", "ackley_core"]
enn/benchmarks/ackley_class.py ADDED
@@ -0,0 +1,17 @@
1
+ import numpy as np
2
+ from numpy.random import Generator
3
+ from .ackley_core import ackley_core
4
+
5
+
6
+ class Ackley:
7
+ def __init__(self, noise: float, rng: Generator):
8
+ self.noise = noise
9
+ self.rng = rng
10
+ self.bounds = [-32.768, 32.768]
11
+
12
+ def __call__(self, x: np.ndarray) -> np.ndarray:
13
+ x = np.asarray(x, dtype=float)
14
+ if x.ndim == 1:
15
+ x = x[None, :]
16
+ y = -ackley_core(x) + self.noise * self.rng.normal(size=(x.shape[0],))
17
+ return y if y.ndim > 0 else float(y)
enn/benchmarks/ackley_core.py ADDED
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+
3
+
4
+ def ackley_core(
5
+ x: np.ndarray, a: float = 20.0, b: float = 0.2, c: float = 2 * np.pi
6
+ ) -> np.ndarray:
7
+ if x.ndim == 1:
8
+ x = x[None, :]
9
+ x = x - 1
10
+ term1 = -a * np.exp(-b * np.sqrt((x**2).mean(axis=1)))
11
+ term2 = -np.exp(np.cos(c * x).mean(axis=1))
12
+ return term1 + term2 + a + np.e
enn/benchmarks/double_ackley.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ from numpy.random import Generator
3
+ from .ackley_core import ackley_core
4
+
5
+
6
+ class DoubleAckley:
7
+ def __init__(self, noise: float, rng: Generator):
8
+ self.noise = noise
9
+ self.rng = rng
10
+ self.bounds = [-32.768, 32.768]
11
+
12
+ def __call__(self, x: np.ndarray) -> np.ndarray:
13
+ x = np.asarray(x, dtype=float)
14
+ if x.ndim == 1:
15
+ x = x[None, :]
16
+ n, d = x.shape
17
+ if d % 2 != 0:
18
+ raise ValueError("num_dim must be even for DoubleAckley")
19
+ mid = d // 2
20
+ x1 = x[:, :mid]
21
+ x2 = x[:, mid:]
22
+ y1 = -ackley_core(x1) + self.noise * self.rng.normal(size=n)
23
+ y2 = -ackley_core(x2) + self.noise * self.rng.normal(size=n)
24
+ return np.stack([y1, y2], axis=1)
enn/enn/candidates.py ADDED
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class Candidates:
11
+ dist2: np.ndarray
12
+ ids: np.ndarray
13
+ y: np.ndarray
14
+ yvar: np.ndarray | None
enn/enn/conditional_posterior_draw_internals.py ADDED
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class ConditionalPosteriorDrawInternals:
11
+ idx: np.ndarray
12
+ w_normalized: np.ndarray
13
+ l2: np.ndarray
14
+ mu: np.ndarray
15
+ se: np.ndarray
enn/enn/draw_internals.py ADDED
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class DrawInternals:
11
+ idx: np.ndarray
12
+ w_normalized: np.ndarray
13
+ l2: np.ndarray
14
+ mu: np.ndarray
15
+ se: np.ndarray
enn/enn/enn.py CHANGED
@@ -1,269 +1,16 @@
1
- from __future__ import annotations
2
-
3
- from typing import TYPE_CHECKING, Any
4
-
5
- if TYPE_CHECKING:
6
- import numpy as np
7
-
8
- from .enn_normal import ENNNormal
9
- from .enn_params import ENNParams
10
-
11
-
12
- class EpistemicNearestNeighbors:
13
- def __init__(
14
- self,
15
- train_x: np.ndarray,
16
- train_y: np.ndarray,
17
- train_yvar: np.ndarray | None = None,
18
- *,
19
- scale_x: bool = False,
20
- ) -> None:
21
- import numpy as np
22
-
23
- train_x = np.asarray(train_x, dtype=float)
24
- train_y = np.asarray(train_y, dtype=float)
25
- if train_x.ndim != 2:
26
- raise ValueError(train_x.shape)
27
- if train_y.ndim != 2:
28
- raise ValueError(train_y.shape)
29
- if train_x.shape[0] != train_y.shape[0]:
30
- raise ValueError((train_x.shape, train_y.shape))
31
- if train_yvar is not None:
32
- train_yvar = np.asarray(train_yvar, dtype=float)
33
- if train_yvar.ndim != 2:
34
- raise ValueError(train_yvar.shape)
35
- if train_y.shape != train_yvar.shape:
36
- raise ValueError((train_y.shape, train_yvar.shape))
37
-
38
- self._train_x = train_x
39
- self._train_y = train_y
40
- self._train_yvar = train_yvar
41
- self._num_obs, self._num_dim = self._train_x.shape
42
- _, self._num_metrics = self._train_y.shape
43
- self._eps_var = 1e-9
44
- self._scale_x = bool(scale_x)
45
- if self._scale_x:
46
- if len(self._train_x) < 2:
47
- x_scale = np.ones((1, self._num_dim), dtype=float)
48
- else:
49
- x_scale = np.std(self._train_x, axis=0, keepdims=True).astype(float)
50
- x_scale = np.where(
51
- np.isfinite(x_scale) & (x_scale > 1e-12),
52
- x_scale,
53
- 1.0,
54
- )
55
- self._x_scale = x_scale
56
- self._train_x_scaled = self._train_x / self._x_scale
57
- else:
58
- self._x_scale = np.ones((1, self._num_dim), dtype=float)
59
- self._train_x_scaled = self._train_x
60
- if len(self._train_y) < 2:
61
- self._y_scale = np.ones(shape=(1, self._num_metrics), dtype=float)
62
- else:
63
- y_scale = np.std(self._train_y, axis=0, keepdims=True).astype(float)
64
- self._y_scale = np.where(
65
- np.isfinite(y_scale) & (y_scale > 0.0), y_scale, 1.0
66
- )
67
-
68
- self._index: Any | None = None
69
- self._build_index()
70
-
71
- @property
72
- def train_x(self) -> np.ndarray:
73
- return self._train_x
74
-
75
- @property
76
- def train_y(self) -> np.ndarray:
77
- return self._train_y
78
-
79
- @property
80
- def train_yvar(self) -> np.ndarray | None:
81
- return self._train_yvar
82
-
83
- @property
84
- def num_outputs(self) -> int:
85
- return self._num_metrics
86
-
87
- def __len__(self) -> int:
88
- return self._num_obs
89
-
90
- def _build_index(self) -> None:
91
- import faiss
92
- import numpy as np
93
-
94
- if self._num_obs == 0:
95
- return
96
- x_f32 = self._train_x_scaled.astype(np.float32, copy=False)
97
- index = faiss.IndexFlatL2(self._num_dim)
98
- index.add(x_f32)
99
- self._index = index
100
-
101
- def _search_index(
102
- self,
103
- x: np.ndarray,
104
- *,
105
- search_k: int,
106
- exclude_nearest: bool,
107
- ) -> tuple[np.ndarray, np.ndarray]:
108
- import numpy as np
109
-
110
- search_k = int(search_k)
111
- if search_k <= 0:
112
- raise ValueError(search_k)
113
- x = np.asarray(x, dtype=float)
114
- if x.ndim != 2 or x.shape[1] != self._num_dim:
115
- raise ValueError(x.shape)
116
- if self._index is None:
117
- raise RuntimeError("index is not initialized")
118
-
119
- x_scaled = x / self._x_scale if self._scale_x else x
120
- x_f32 = x_scaled.astype(np.float32, copy=False)
121
- dist2s_full, idx_full = self._index.search(x_f32, search_k)
122
- dist2s_full = dist2s_full.astype(float)
123
- idx_full = idx_full.astype(int)
124
- if exclude_nearest:
125
- dist2s_full = dist2s_full[:, 1:]
126
- idx_full = idx_full[:, 1:]
127
- return dist2s_full, idx_full
128
-
129
- def posterior(
130
- self,
131
- x: np.ndarray,
132
- *,
133
- params: ENNParams,
134
- exclude_nearest: bool = False,
135
- observation_noise: bool = False,
136
- ) -> ENNNormal:
137
- from .enn_normal import ENNNormal
138
-
139
- post_batch = self.batch_posterior(
140
- x,
141
- [params],
142
- exclude_nearest=exclude_nearest,
143
- observation_noise=observation_noise,
144
- )
145
- mu = post_batch.mu[0]
146
- se = post_batch.se[0]
147
- return ENNNormal(mu, se)
148
-
149
- def batch_posterior(
150
- self,
151
- x: np.ndarray,
152
- paramss: list[ENNParams],
153
- *,
154
- exclude_nearest: bool = False,
155
- observation_noise: bool = False,
156
- ) -> ENNNormal:
157
- import numpy as np
158
-
159
- from .enn_normal import ENNNormal
160
-
161
- x = np.asarray(x, dtype=float)
162
- if x.ndim != 2:
163
- raise ValueError(x.shape)
164
- if x.shape[1] != self._num_dim:
165
- raise ValueError(x.shape)
166
- if len(paramss) == 0:
167
- raise ValueError("paramss must be non-empty")
168
- batch_size = x.shape[0]
169
- num_params = len(paramss)
170
- if len(self) == 0:
171
- mu = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
172
- se = np.ones((num_params, batch_size, self._num_metrics), dtype=float)
173
- return ENNNormal(mu, se)
174
- max_k = max(params.k for params in paramss)
175
- if exclude_nearest:
176
- if len(self) <= 1:
177
- raise ValueError(len(self))
178
- search_k = int(min(max_k + 1, len(self)))
179
- else:
180
- search_k = int(min(max_k, len(self)))
181
- dist2s_full, idx_full = self._search_index(
182
- x, search_k=search_k, exclude_nearest=exclude_nearest
183
- )
184
- mu_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
185
- se_all = np.zeros((num_params, batch_size, self._num_metrics), dtype=float)
186
- available_k = search_k - 1 if exclude_nearest else search_k
187
- for i, params in enumerate(paramss):
188
- k = min(params.k, available_k)
189
- if k > dist2s_full.shape[1]:
190
- raise RuntimeError(
191
- f"k={k} exceeds available columns={dist2s_full.shape[1]}"
192
- )
193
- if k == 0:
194
- mu_all[i] = np.zeros((batch_size, self._num_metrics), dtype=float)
195
- se_all[i] = np.ones((batch_size, self._num_metrics), dtype=float)
196
- continue
197
- dist2s = dist2s_full[:, :k]
198
- idx = idx_full[:, :k]
199
- y_neighbors = self._train_y[idx]
200
-
201
- dist2s_expanded = dist2s[..., np.newaxis]
202
- var_component = (
203
- params.ale_homoscedastic_scale + params.epi_var_scale * dist2s_expanded
204
- )
205
- if self._train_yvar is not None:
206
- yvar_neighbors = self._train_yvar[idx] / self._y_scale**2
207
- var_component = var_component + yvar_neighbors
208
- else:
209
- yvar_neighbors = None
210
-
211
- w = 1.0 / (self._eps_var + var_component)
212
- norm = np.sum(w, axis=1)
213
- mu_all[i] = np.sum(w * y_neighbors, axis=1) / norm
214
- epistemic_var = 1.0 / norm
215
- vvar = epistemic_var
216
- if observation_noise:
217
- vvar = vvar + params.ale_homoscedastic_scale
218
- if yvar_neighbors is not None:
219
- ale_heteroscedastic = np.sum(w * yvar_neighbors, axis=1) / norm
220
- vvar = vvar + ale_heteroscedastic
221
- vvar = np.maximum(vvar, self._eps_var)
222
- se_all[i] = np.sqrt(vvar) * self._y_scale
223
- return ENNNormal(mu_all, se_all)
224
-
225
- def neighbors(
226
- self,
227
- x: np.ndarray,
228
- k: int,
229
- *,
230
- exclude_nearest: bool = False,
231
- ) -> list[tuple[np.ndarray, np.ndarray]]:
232
- import numpy as np
233
-
234
- x = np.asarray(x, dtype=float)
235
- if x.ndim == 1:
236
- x = x[np.newaxis, :]
237
- if x.ndim != 2:
238
- raise ValueError(f"x must be 1D or 2D, got shape {x.shape}")
239
- if x.shape[0] != 1:
240
- raise ValueError(f"x must be a single point, got shape {x.shape}")
241
- if x.shape[1] != self._num_dim:
242
- raise ValueError(
243
- f"x must have {self._num_dim} dimensions, got {x.shape[1]}"
244
- )
245
- if k < 0:
246
- raise ValueError(f"k must be non-negative, got {k}")
247
- if len(self) == 0:
248
- return []
249
- if exclude_nearest:
250
- if len(self) <= 1:
251
- raise ValueError(
252
- f"exclude_nearest=True requires at least 2 observations, got {len(self)}"
253
- )
254
- search_k = int(min(k + 1, len(self)))
255
- else:
256
- search_k = int(min(k, len(self)))
257
- if search_k == 0:
258
- return []
259
- dist2s_full, idx_full = self._search_index(
260
- x, search_k=search_k, exclude_nearest=exclude_nearest
261
- )
262
- actual_k = min(k, len(idx_full[0]))
263
- idx = idx_full[0, :actual_k]
264
- result = []
265
- for i in idx:
266
- x_neighbor = self._train_x[i].copy()
267
- y_neighbor = self._train_y[i].copy()
268
- result.append((x_neighbor, y_neighbor))
269
- return result
1
+ from .draw_internals import DrawInternals
2
+ from .neighbor_data import NeighborData
3
+ from .weighted_stats import WeightedStats
4
+
5
+ _DrawInternals = DrawInternals
6
+ _NeighborData = NeighborData
7
+ _WeightedStats = WeightedStats
8
+
9
+ __all__ = [
10
+ "DrawInternals",
11
+ "NeighborData",
12
+ "WeightedStats",
13
+ "_DrawInternals",
14
+ "_NeighborData",
15
+ "_WeightedStats",
16
+ ]