statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1083 @@
1
+ """Bandwidth selection helpers for kernel-based estimators."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ import math
7
+ from typing import Any, Dict, Optional, Union
8
+
9
+ import numpy as np
10
+
11
+ from statgpu.backends import _torch_dev, xp_eye, xp_arange, xp_asarray, xp_astype, xp_maximum
12
+
13
+ from statgpu.nonparametric.kernel_smoothing._kernel_common import (
14
+ _bandwidth_factor,
15
+ _bandwidth_factor_1d_nrd,
16
+ _kernel_values_from_quad,
17
+ _normalize_regression_name,
18
+ _to_float_scalar,
19
+ _to_numpy,
20
+ )
21
+
22
+ _BW_DELMAX = 1000.0
23
+ _SQRT_PI = math.sqrt(math.pi)
24
+ _SQRT_2PI = math.sqrt(2.0 * math.pi)
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class BandwidthSelectionResult:
29
+ """Diagnostic result for automatic bandwidth selection."""
30
+
31
+ factor: float
32
+ method: str
33
+ n_features: int
34
+ n_eff: float
35
+ used_r_selector: bool
36
+ weighted: bool
37
+ weighted_strategy: str
38
+ multivariate_strategy: str
39
+ selector_dimension: int
40
+ details: Dict[str, Any]
41
+
42
+ def to_dict(self) -> Dict[str, Any]:
43
+ return {
44
+ "factor": float(self.factor),
45
+ "method": str(self.method),
46
+ "n_features": int(self.n_features),
47
+ "n_eff": float(self.n_eff),
48
+ "used_r_selector": bool(self.used_r_selector),
49
+ "weighted": bool(self.weighted),
50
+ "weighted_strategy": str(self.weighted_strategy),
51
+ "multivariate_strategy": str(self.multivariate_strategy),
52
+ "selector_dimension": int(self.selector_dimension),
53
+ "details": dict(self.details),
54
+ }
55
+
56
+
57
+ def _normalize_weighted_strategy(strategy: str) -> str:
58
+ name = str(strategy).strip().lower()
59
+ aliases = {
60
+ "quantile_resample": "quantile_resample",
61
+ "quantile": "quantile_resample",
62
+ "resample": "quantile_resample",
63
+ }
64
+ out = aliases.get(name)
65
+ if out is None:
66
+ raise ValueError("weighted_r_selector_strategy must be 'quantile_resample'")
67
+ return out
68
+
69
+
70
+ def _normalize_multivariate_strategy(strategy: str) -> str:
71
+ name = str(strategy).strip().lower()
72
+ aliases = {
73
+ "projection_pca_1d": "projection_pca_1d",
74
+ "projection": "projection_pca_1d",
75
+ "pca": "projection_pca_1d",
76
+ }
77
+ out = aliases.get(name)
78
+ if out is None:
79
+ raise ValueError("multivariate_selector_strategy must be 'projection_pca_1d'")
80
+ return out
81
+
82
+
83
+ def _normalize_estimator_name(estimator: str) -> str:
84
+ name = str(estimator).strip().lower()
85
+ aliases = {
86
+ "kde": "kde",
87
+ "kernel_density": "kde",
88
+ "gaussian_kde": "kde",
89
+ "kernel_regression": "kernel_regression",
90
+ "kreg": "kernel_regression",
91
+ "regression": "kernel_regression",
92
+ }
93
+ out = aliases.get(name)
94
+ if out is None:
95
+ raise ValueError("estimator must be one of: 'kde', 'kernel_regression'")
96
+ return out
97
+
98
+
99
+ # Alias for backward compatibility - delegates to _kernel_common
100
+ _normalize_regression_mode = _normalize_regression_name
101
+
102
+
103
+ def _normalized_weights_numpy(weights: np.ndarray) -> np.ndarray:
104
+ w = np.asarray(weights, dtype=np.float64).reshape(-1)
105
+ if w.size == 0:
106
+ raise ValueError("weights must not be empty")
107
+ if not np.all(np.isfinite(w)):
108
+ raise ValueError("weights must be finite")
109
+ if np.min(w) < 0.0:
110
+ raise ValueError("weights must be non-negative")
111
+
112
+ w_sum = float(np.sum(w))
113
+ if (not np.isfinite(w_sum)) or w_sum <= 0.0:
114
+ raise ValueError("weights must sum to a positive value")
115
+
116
+ return w / w_sum
117
+
118
+
119
+ def _weighted_var_unbiased_1d(x: np.ndarray, w_norm: np.ndarray) -> float:
120
+ x_np = np.asarray(x, dtype=np.float64).reshape(-1)
121
+ w_np = np.asarray(w_norm, dtype=np.float64).reshape(-1)
122
+ if x_np.size != w_np.size:
123
+ raise ValueError("x and weights must have the same length")
124
+
125
+ mean = float(np.sum(w_np * x_np))
126
+ denom = 1.0 - float(np.sum(w_np * w_np))
127
+ if (not np.isfinite(denom)) or denom <= 1e-15:
128
+ return float("nan")
129
+
130
+ var = float(np.sum(w_np * (x_np - mean) ** 2) / denom)
131
+ if (not np.isfinite(var)) or var < 0.0:
132
+ return float("nan")
133
+ return var
134
+
135
+
136
+ def _weighted_quantile_resample_1d(x: np.ndarray, w_norm: np.ndarray, n_target: int) -> np.ndarray:
137
+ x_np = np.asarray(x, dtype=np.float64).reshape(-1)
138
+ w_np = np.asarray(w_norm, dtype=np.float64).reshape(-1)
139
+
140
+ n_t = int(n_target)
141
+ if n_t < 2:
142
+ raise ValueError("n_target must be at least 2 for weighted resampling")
143
+
144
+ idx = np.argsort(x_np)
145
+ x_s = x_np[idx]
146
+ w_s = w_np[idx]
147
+
148
+ cdf = np.cumsum(w_s)
149
+ cdf[-1] = 1.0
150
+
151
+ u = (np.arange(n_t, dtype=np.float64) + 0.5) / float(n_t)
152
+ x_rep = np.interp(u, cdf, x_s)
153
+ return np.asarray(x_rep, dtype=np.float64)
154
+
155
+
156
+ def _project_to_principal_axis(samples_2d, weights_1d) -> tuple[np.ndarray, np.ndarray, float]:
157
+ x = np.asarray(_to_numpy(samples_2d), dtype=np.float64)
158
+ if x.ndim != 2:
159
+ raise ValueError("samples_2d must be a 2D array")
160
+
161
+ w_norm = _normalized_weights_numpy(np.asarray(_to_numpy(weights_1d), dtype=np.float64))
162
+ if x.shape[0] != w_norm.size:
163
+ raise ValueError("weights shape is incompatible with samples")
164
+
165
+ mean = np.sum(x * w_norm[:, None], axis=0)
166
+ xc = x - mean[None, :]
167
+
168
+ cov = (xc.T * w_norm[None, :]) @ xc
169
+ cov = 0.5 * (cov + cov.T)
170
+
171
+ evals, evecs = np.linalg.eigh(cov)
172
+ idx_max = int(np.argmax(evals))
173
+ principal_vec = np.asarray(evecs[:, idx_max], dtype=np.float64).reshape(-1)
174
+ principal_eval = float(evals[idx_max])
175
+
176
+ if (not np.isfinite(principal_eval)) or principal_eval <= 0.0:
177
+ scales = np.sqrt(np.maximum(np.diag(cov), 0.0))
178
+ idx_fallback = int(np.argmax(scales))
179
+ principal_vec = np.zeros(x.shape[1], dtype=np.float64)
180
+ principal_vec[idx_fallback] = 1.0
181
+ principal_eval = float(np.max(scales) ** 2)
182
+
183
+ proj = xc @ principal_vec
184
+ total_var = float(np.sum(np.maximum(evals, 0.0)))
185
+ explained_ratio = float(principal_eval / total_var) if total_var > 0.0 else 1.0
186
+
187
+ return np.asarray(proj, dtype=np.float64), principal_vec, explained_ratio
188
+
189
+
190
+ # _bandwidth_factor, _bandwidth_factor_1d_nrd, _normalize_regression_mode
191
+ # are imported from _kernel_common (see top-level imports)
192
+
193
+
194
+ def _golden_section_minimize(func, lower: float, upper: float, tol: float) -> float:
195
+ a = float(lower)
196
+ b = float(upper)
197
+ if (not np.isfinite(a)) or (not np.isfinite(b)) or a <= 0.0 or b <= a:
198
+ raise ValueError("invalid optimization bounds for bandwidth selection")
199
+
200
+ tol_f = float(tol)
201
+ if (not np.isfinite(tol_f)) or tol_f <= 0.0:
202
+ tol_f = max(1e-8, (b - a) * 1e-4)
203
+
204
+ invphi = (math.sqrt(5.0) - 1.0) / 2.0
205
+ invphi2 = (3.0 - math.sqrt(5.0)) / 2.0
206
+
207
+ c = a + invphi2 * (b - a)
208
+ d = a + invphi * (b - a)
209
+
210
+ def _eval(x: float) -> float:
211
+ y = float(func(float(x)))
212
+ if not np.isfinite(y):
213
+ return float("inf")
214
+ return y
215
+
216
+ fc = _eval(c)
217
+ fd = _eval(d)
218
+
219
+ for _ in range(256):
220
+ if (b - a) <= tol_f:
221
+ break
222
+ if fc < fd:
223
+ b = d
224
+ d = c
225
+ fd = fc
226
+ c = a + invphi2 * (b - a)
227
+ fc = _eval(c)
228
+ else:
229
+ a = c
230
+ c = d
231
+ fc = fd
232
+ d = a + invphi * (b - a)
233
+ fd = _eval(d)
234
+
235
+ return float(c if fc <= fd else d)
236
+
237
+
238
+ def _bisection_root(func, lower: float, upper: float, tol: float) -> float:
239
+ a = float(lower)
240
+ b = float(upper)
241
+ if (not np.isfinite(a)) or (not np.isfinite(b)) or b <= a:
242
+ raise ValueError("invalid bisection interval")
243
+
244
+ fa = float(func(a))
245
+ fb = float(func(b))
246
+ if (not np.isfinite(fa)) or (not np.isfinite(fb)):
247
+ raise ValueError("non-finite values in bisection objective")
248
+ if fa == 0.0:
249
+ return a
250
+ if fb == 0.0:
251
+ return b
252
+ if fa * fb > 0.0:
253
+ raise ValueError("bisection interval does not bracket a root")
254
+
255
+ tol_f = float(tol)
256
+ if (not np.isfinite(tol_f)) or tol_f <= 0.0:
257
+ tol_f = max(1e-8, (b - a) * 1e-4)
258
+
259
+ for _ in range(256):
260
+ mid = 0.5 * (a + b)
261
+ if abs(b - a) <= tol_f:
262
+ return float(mid)
263
+ fm = float(func(mid))
264
+ if not np.isfinite(fm):
265
+ raise ValueError("non-finite values in bisection objective")
266
+ if fm == 0.0:
267
+ return float(mid)
268
+ if fa * fm < 0.0:
269
+ b = mid
270
+ fb = fm
271
+ else:
272
+ a = mid
273
+ fa = fm
274
+
275
+ return float(0.5 * (a + b))
276
+
277
+
278
+ def _bw_pair_distance_counts_1d(x: np.ndarray, nb: int = 1000) -> tuple[float, np.ndarray]:
279
+ x_np = np.asarray(x, dtype=np.float64).reshape(-1)
280
+ if x_np.size < 2:
281
+ raise ValueError("need at least 2 data points for automatic bandwidth")
282
+ if not np.all(np.isfinite(x_np)):
283
+ raise ValueError("samples must be finite for automatic bandwidth")
284
+
285
+ nb_i = int(nb)
286
+ if nb_i <= 1:
287
+ raise ValueError("nb must be greater than 1")
288
+
289
+ xmin = float(np.min(x_np))
290
+ xmax = float(np.max(x_np))
291
+ data_range = xmax - xmin
292
+ if (not np.isfinite(data_range)) or data_range <= 0.0:
293
+ raise ValueError("data are constant in automatic bandwidth calculation")
294
+
295
+ d = float(1.01 * data_range / float(nb_i))
296
+ if (not np.isfinite(d)) or d <= 0.0:
297
+ raise ValueError("invalid bin width in automatic bandwidth calculation")
298
+
299
+ idx = np.floor((x_np - xmin) / d).astype(np.int64)
300
+ idx = np.clip(idx, 0, nb_i - 1)
301
+ hist = np.bincount(idx, minlength=nb_i).astype(np.float64)
302
+
303
+ cnt = np.correlate(hist, hist, mode="full")[nb_i - 1 :].astype(np.float64)
304
+ cnt[0] = np.sum(hist * (hist - 1.0) * 0.5)
305
+ return d, cnt
306
+
307
+
308
+ def _bw_ucv_objective(n: int, d: float, cnt: np.ndarray, h: float) -> float:
309
+ h_f = float(h)
310
+ if (not np.isfinite(h_f)) or h_f <= 0.0:
311
+ return float("inf")
312
+
313
+ idx = np.arange(cnt.size, dtype=np.float64)
314
+ delta = (idx * float(d) / h_f) ** 2
315
+ mask = delta < _BW_DELMAX
316
+ delta_m = delta[mask]
317
+ cnt_m = cnt[mask]
318
+
319
+ term = np.exp(-delta_m / 4.0) - math.sqrt(8.0) * np.exp(-delta_m / 2.0)
320
+ sum_term = float(np.sum(term * cnt_m))
321
+ return float((0.5 + sum_term / float(n)) / (float(n) * h_f * _SQRT_PI))
322
+
323
+
324
+ def _bw_bcv_objective(n: int, d: float, cnt: np.ndarray, h: float) -> float:
325
+ h_f = float(h)
326
+ if (not np.isfinite(h_f)) or h_f <= 0.0:
327
+ return float("inf")
328
+
329
+ idx = np.arange(cnt.size, dtype=np.float64)
330
+ delta = (idx * float(d) / h_f) ** 2
331
+ mask = delta < _BW_DELMAX
332
+ delta_m = delta[mask]
333
+ cnt_m = cnt[mask]
334
+
335
+ term = np.exp(-delta_m / 4.0) * (delta_m * delta_m - 12.0 * delta_m + 12.0)
336
+ sum_term = float(np.sum(term * cnt_m))
337
+ return float((1.0 + sum_term / (32.0 * float(n))) / (2.0 * float(n) * h_f * _SQRT_PI))
338
+
339
+
340
+ def _bw_phi4(n: int, d: float, cnt: np.ndarray, h: float) -> float:
341
+ h_f = float(h)
342
+ if (not np.isfinite(h_f)) or h_f <= 0.0:
343
+ return float("nan")
344
+
345
+ idx = np.arange(cnt.size, dtype=np.float64)
346
+ delta = (idx * float(d) / h_f) ** 2
347
+ mask = delta < _BW_DELMAX
348
+ delta_m = delta[mask]
349
+ cnt_m = cnt[mask]
350
+
351
+ term = np.exp(-delta_m / 2.0) * (delta_m * delta_m - 6.0 * delta_m + 3.0)
352
+ sum_term = float(np.sum(term * cnt_m))
353
+ sum_term = 2.0 * sum_term + 3.0 * float(n)
354
+ denom = float(n * (n - 1)) * (h_f ** 5) * _SQRT_2PI
355
+ if denom <= 0.0:
356
+ return float("nan")
357
+ return float(sum_term / denom)
358
+
359
+
360
+ def _bw_phi6(n: int, d: float, cnt: np.ndarray, h: float) -> float:
361
+ h_f = float(h)
362
+ if (not np.isfinite(h_f)) or h_f <= 0.0:
363
+ return float("nan")
364
+
365
+ idx = np.arange(cnt.size, dtype=np.float64)
366
+ delta = (idx * float(d) / h_f) ** 2
367
+ mask = delta < _BW_DELMAX
368
+ delta_m = delta[mask]
369
+ cnt_m = cnt[mask]
370
+
371
+ term = np.exp(-delta_m / 2.0) * (
372
+ delta_m * delta_m * delta_m - 15.0 * delta_m * delta_m + 45.0 * delta_m - 15.0
373
+ )
374
+ sum_term = float(np.sum(term * cnt_m))
375
+ sum_term = 2.0 * sum_term - 15.0 * float(n)
376
+ denom = float(n * (n - 1)) * (h_f ** 7) * _SQRT_2PI
377
+ if denom <= 0.0:
378
+ return float("nan")
379
+ return float(sum_term / denom)
380
+
381
+
382
+ def _bandwidth_factor_1d_r_selectors(
383
+ method: str,
384
+ *,
385
+ samples_2d,
386
+ weights_1d,
387
+ data_cov,
388
+ weighted_strategy: str = "quantile_resample",
389
+ ) -> float:
390
+ method_n = str(method).strip().lower()
391
+ if method_n == "sj":
392
+ method_n = "sj-ste"
393
+ if method_n not in ("ucv", "bcv", "sj-ste", "sj-dpi"):
394
+ raise ValueError("method must be one of: 'ucv', 'bcv', 'sj', 'sj-ste', 'sj-dpi'")
395
+
396
+ x = np.asarray(_to_numpy(samples_2d[:, 0]), dtype=np.float64).reshape(-1)
397
+ if x.size < 2:
398
+ raise ValueError("need at least 2 samples for automatic bandwidth selection")
399
+ if not np.all(np.isfinite(x)):
400
+ raise ValueError("samples must be finite for automatic bandwidth selection")
401
+
402
+ w = _normalized_weights_numpy(np.asarray(_to_numpy(weights_1d), dtype=np.float64).reshape(-1))
403
+ if w.size != x.size:
404
+ raise ValueError("weights shape is incompatible with samples")
405
+
406
+ x_work = x
407
+ is_weighted = float(np.max(w) - np.min(w)) > 1e-12
408
+ if is_weighted:
409
+ strategy = _normalize_weighted_strategy(weighted_strategy)
410
+ if strategy == "quantile_resample":
411
+ n_eff = float(1.0 / np.sum(w * w))
412
+ n_rep = int(np.clip(round(max(128.0, min(8192.0, n_eff * 8.0))), 128, 8192))
413
+ x_work = _weighted_quantile_resample_1d(x, w, n_rep)
414
+ else:
415
+ raise ValueError("unsupported weighted strategy")
416
+
417
+ n = int(x_work.size)
418
+ d, cnt = _bw_pair_distance_counts_1d(x_work, nb=1000)
419
+
420
+ sample_sd = float(np.std(x_work, ddof=1))
421
+ if (not np.isfinite(sample_sd)) or sample_sd <= 0.0:
422
+ raise ValueError("data are constant in automatic bandwidth calculation")
423
+
424
+ q75, q25 = np.quantile(x_work, [0.75, 0.25])
425
+ robust = float((q75 - q25) / 1.349)
426
+ scale = min(sample_sd, robust) if np.isfinite(robust) and robust > 0.0 else sample_sd
427
+ if (not np.isfinite(scale)) or scale <= 0.0:
428
+ scale = sample_sd
429
+
430
+ if method_n in ("ucv", "bcv"):
431
+ hmax = float(1.144 * sample_sd * (float(n) ** (-1.0 / 5.0)))
432
+ lower = max(hmax * 0.1, float(np.finfo(np.float64).tiny))
433
+ upper = max(hmax, lower * 1.01)
434
+ tol = max(lower * 0.1, 1e-8)
435
+
436
+ obj = _bw_ucv_objective if method_n == "ucv" else _bw_bcv_objective
437
+ bw_abs = _golden_section_minimize(lambda h: obj(n, d, cnt, h), lower, upper, tol)
438
+ else:
439
+ hmax = float(1.144 * scale * (float(n) ** (-1.0 / 5.0)))
440
+ lower = max(hmax * 0.1, float(np.finfo(np.float64).tiny))
441
+ upper = max(hmax, lower * 1.01)
442
+ tol = max(lower * 0.1, 1e-8)
443
+
444
+ c1 = 1.0 / (2.0 * _SQRT_PI * float(n))
445
+ a = float(1.24 * scale * (float(n) ** (-1.0 / 7.0)))
446
+ b = float(1.23 * scale * (float(n) ** (-1.0 / 9.0)))
447
+
448
+ td = -_bw_phi6(n, d, cnt, b)
449
+ if (not np.isfinite(td)) or td <= 0.0:
450
+ raise ValueError("sample is too sparse to find TD for 'sj' bandwidth")
451
+
452
+ if method_n == "sj-dpi":
453
+ h_phi4 = float((2.394 / (float(n) * td)) ** (1.0 / 7.0))
454
+ sd_h = _bw_phi4(n, d, cnt, h_phi4)
455
+ if (not np.isfinite(sd_h)) or sd_h <= 0.0:
456
+ raise ValueError("sample is too sparse to find SD for 'sj-dpi' bandwidth")
457
+ bw_abs = float((c1 / sd_h) ** (1.0 / 5.0))
458
+ else:
459
+ sd_a = _bw_phi4(n, d, cnt, a)
460
+ if (not np.isfinite(sd_a)) or sd_a <= 0.0:
461
+ raise ValueError("sample is too sparse to find SD for 'sj-ste' bandwidth")
462
+
463
+ alph2 = float(1.357 * ((sd_a / td) ** (1.0 / 7.0)))
464
+ if (not np.isfinite(alph2)) or alph2 <= 0.0:
465
+ raise ValueError("sample is too sparse to find alph2 for 'sj-ste' bandwidth")
466
+
467
+ def f_sd(h: float) -> float:
468
+ h_f = float(h)
469
+ sd_term = _bw_phi4(n, d, cnt, alph2 * (h_f ** (5.0 / 7.0)))
470
+ if (not np.isfinite(sd_term)) or sd_term <= 0.0:
471
+ return float("nan")
472
+ return float((c1 / sd_term) ** (1.0 / 5.0) - h_f)
473
+
474
+ fl = float(f_sd(lower))
475
+ fu = float(f_sd(upper))
476
+ itry = 1
477
+ while (not np.isfinite(fl) or not np.isfinite(fu) or (fl * fu > 0.0)) and itry <= 99:
478
+ if itry % 2 == 1:
479
+ upper *= 1.2
480
+ else:
481
+ lower /= 1.2
482
+ lower = max(lower, float(np.finfo(np.float64).tiny))
483
+ fl = float(f_sd(lower))
484
+ fu = float(f_sd(upper))
485
+ itry += 1
486
+
487
+ if (not np.isfinite(fl)) or (not np.isfinite(fu)) or (fl * fu > 0.0):
488
+ raise ValueError("no solution found for 'sj-ste' bandwidth in the search range")
489
+
490
+ bw_abs = _bisection_root(f_sd, lower, upper, tol)
491
+
492
+ if (not np.isfinite(bw_abs)) or bw_abs <= 0.0:
493
+ raise ValueError("automatic bandwidth rule produced a non-positive value")
494
+
495
+ data_sd = math.sqrt(max(_to_float_scalar(data_cov[0, 0]), 0.0))
496
+ if data_sd <= 0.0 or (not np.isfinite(data_sd)):
497
+ data_sd = sample_sd
498
+
499
+ factor = float(bw_abs / data_sd)
500
+ if (not np.isfinite(factor)) or factor <= 0.0:
501
+ raise ValueError("bandwidth factor must be a finite positive scalar")
502
+ return factor
503
+
504
+
505
+ def _multivariate_factor_from_projected_1d(
506
+ method: str,
507
+ *,
508
+ samples_2d,
509
+ weights_1d,
510
+ data_cov,
511
+ n_eff: float,
512
+ rule_kind: str,
513
+ weighted_r_selector_strategy: str,
514
+ ) -> tuple[float, Dict[str, Any]]:
515
+ proj, principal_vec, explained_ratio = _project_to_principal_axis(samples_2d, weights_1d)
516
+ w_norm = _normalized_weights_numpy(np.asarray(_to_numpy(weights_1d), dtype=np.float64))
517
+
518
+ proj_2d = np.asarray(proj, dtype=np.float64).reshape(-1, 1)
519
+
520
+ var_proj = _weighted_var_unbiased_1d(proj, w_norm)
521
+ if (not np.isfinite(var_proj)) or var_proj <= 0.0:
522
+ var_proj = float(np.var(proj, ddof=1)) if proj.size >= 2 else float("nan")
523
+ if (not np.isfinite(var_proj)) or var_proj <= 0.0:
524
+ var_proj = float(np.finfo(np.float64).tiny)
525
+
526
+ proj_cov = np.asarray([[var_proj]], dtype=np.float64)
527
+
528
+ if rule_kind == "nrd":
529
+ factor = _bandwidth_factor_1d_nrd(
530
+ method,
531
+ n_eff=n_eff,
532
+ samples_2d=proj_2d,
533
+ data_cov=proj_cov,
534
+ xp=np,
535
+ )
536
+ elif rule_kind == "r_selector":
537
+ factor = _bandwidth_factor_1d_r_selectors(
538
+ method,
539
+ samples_2d=proj_2d,
540
+ weights_1d=w_norm,
541
+ data_cov=proj_cov,
542
+ weighted_strategy=weighted_r_selector_strategy,
543
+ )
544
+ else:
545
+ raise ValueError("rule_kind must be one of: 'nrd', 'r_selector'")
546
+
547
+ details = {
548
+ "projection_explained_ratio": float(explained_ratio),
549
+ "projection_vector": np.asarray(principal_vec, dtype=np.float64),
550
+ "projection_variance": float(var_proj),
551
+ }
552
+ return float(factor), details
553
+
554
+
555
+ def _as_targets_numpy_2d(targets, n_samples: int) -> np.ndarray:
556
+ y = np.asarray(_to_numpy(targets), dtype=np.float64)
557
+ if y.ndim == 1:
558
+ if y.shape[0] != n_samples:
559
+ raise ValueError("targets length must match samples")
560
+ y = y.reshape(-1, 1)
561
+ elif y.ndim == 2:
562
+ if y.shape[0] != n_samples:
563
+ raise ValueError("targets rows must match samples")
564
+ else:
565
+ raise ValueError("targets must be 1D or 2D")
566
+ return y
567
+
568
+
569
+ def _stable_inverse_cov(cov, xp=np, ref_arr=None):
570
+ d = int(cov.shape[0])
571
+ cov_work = xp_asarray(cov, dtype=xp.float64, xp=xp, ref_arr=ref_arr)
572
+ cov_work = 0.5 * (cov_work + cov_work.T)
573
+
574
+ trace = _to_float_scalar(xp.trace(cov_work))
575
+ base = trace / float(max(1, d)) if np.isfinite(trace) else 1.0
576
+ jitter = max(base * 1e-12, 1e-12)
577
+
578
+ for _ in range(8):
579
+ try:
580
+ return xp.linalg.inv(cov_work)
581
+ except Exception:
582
+ cov_work = cov_work + jitter * xp_eye(d, xp.float64, xp, cov_work)
583
+ jitter *= 10.0
584
+
585
+ return xp.linalg.pinv(cov_work)
586
+
587
+
588
+ def _fill_diagonal_zero(arr, xp=np):
589
+ """Set diagonal entries to zero across NumPy/CuPy/Torch backends."""
590
+ if hasattr(xp, "fill_diagonal"):
591
+ xp.fill_diagonal(arr, 0.0)
592
+ return
593
+ if hasattr(arr, "fill_diagonal_"):
594
+ arr.fill_diagonal_(0.0)
595
+ return
596
+ diag_idx = xp_arange(arr.shape[0], xp=xp, ref_arr=arr)
597
+ arr[diag_idx, diag_idx] = 0.0
598
+
599
+
600
+ def _kernel_regression_cv_score(
601
+ *,
602
+ samples_2d,
603
+ targets_2d,
604
+ weights_norm,
605
+ data_cov,
606
+ kernel_name: str,
607
+ factor: float,
608
+ regression_mode: str,
609
+ xp=np,
610
+ ) -> float:
611
+ f = float(factor)
612
+ if (not np.isfinite(f)) or f <= 0.0:
613
+ return float("inf")
614
+
615
+ n, d = int(samples_2d.shape[0]), int(samples_2d.shape[1])
616
+ if n < 3:
617
+ return float("inf")
618
+
619
+ scaled_cov = xp_asarray(data_cov, dtype=xp.float64, xp=xp, ref_arr=samples_2d) * (f ** 2)
620
+ inv_cov = _stable_inverse_cov(scaled_cov, xp=xp, ref_arr=samples_2d)
621
+
622
+ if d == 1:
623
+ x = samples_2d[:, 0]
624
+ diff = x[:, None] - x[None, :]
625
+ quad = (diff * diff) * _to_float_scalar(inv_cov[0, 0])
626
+ else:
627
+ s_proj = samples_2d @ inv_cov
628
+ s_quad = xp.sum(s_proj * samples_2d, axis=1)
629
+ cross = s_proj @ samples_2d.T
630
+ quad = s_quad[:, None] + s_quad[None, :] - 2.0 * cross
631
+ quad = xp_maximum(quad, 0.0, xp)
632
+
633
+ kernels = _kernel_values_from_quad(quad, kernel_name, xp)
634
+ _fill_diagonal_zero(kernels, xp)
635
+
636
+ weighted = kernels * weights_norm[None, :]
637
+ denom = xp.sum(weighted, axis=1)
638
+ tiny = float(np.finfo(np.float64).tiny)
639
+
640
+ valid = denom > tiny
641
+ if not _to_float_scalar(xp.any(valid)):
642
+ return float("inf")
643
+
644
+ numer_nw = weighted @ targets_2d
645
+ pred_nw = xp.where(
646
+ denom[:, None] > tiny,
647
+ numer_nw / xp.where(denom[:, None] > tiny, denom[:, None], 1.0),
648
+ xp.zeros_like(numer_nw),
649
+ )
650
+ pred = pred_nw
651
+
652
+ if regression_mode == "local_linear" and d == 1:
653
+ x = samples_2d[:, 0]
654
+ diff = x[:, None] - x[None, :]
655
+
656
+ s0 = denom
657
+ s1 = xp.sum(weighted * diff, axis=1)
658
+ s2 = xp.sum(weighted * diff * diff, axis=1)
659
+
660
+ t0 = numer_nw
661
+ t1 = (weighted * diff) @ targets_2d
662
+
663
+ det = s0 * s2 - s1 * s1
664
+ det_thresh = tiny * tiny
665
+ use_ll = (s0 > tiny) & (xp.abs(det) > det_thresh)
666
+
667
+ safe_det = xp.where(xp.abs(det) > det_thresh, det, 1.0)
668
+ pred_ll = xp.where(
669
+ use_ll[:, None],
670
+ (xp.where(s2 > 0, s2, 0.0)[:, None] * t0 - s1[:, None] * t1) / safe_det[:, None],
671
+ pred_nw,
672
+ )
673
+ pred = xp.where(use_ll[:, None], pred_ll, pred_nw)
674
+
675
+ err = targets_2d - pred
676
+ mse_i = xp.mean(err * err, axis=1)
677
+
678
+ w_valid = weights_norm * xp_astype(valid, xp.float64, xp)
679
+ wsum = _to_float_scalar(xp.sum(w_valid))
680
+ if (not np.isfinite(wsum)) or wsum <= 0.0:
681
+ return float("inf")
682
+
683
+ score = _to_float_scalar(xp.sum(w_valid * mse_i) / wsum)
684
+ if not np.isfinite(score):
685
+ return float("inf")
686
+ return score
687
+
688
+
689
+ def _as_targets_2d(targets, n_samples: int, xp=np, ref_arr=None):
690
+ y = xp_asarray(targets, dtype=xp.float64, xp=xp, ref_arr=ref_arr)
691
+ if y.ndim == 1:
692
+ if y.shape[0] != n_samples:
693
+ raise ValueError("targets length must match samples")
694
+ y = y.reshape(-1, 1)
695
+ elif y.ndim == 2:
696
+ if y.shape[0] != n_samples:
697
+ raise ValueError("targets rows must match samples")
698
+ else:
699
+ raise ValueError("targets must be 1D or 2D")
700
+ return y
701
+
702
+
703
def _normalized_weights(w, xp=np, ref_arr=None):
    """Flatten *w* to a 1-D float64 array on backend *xp* and rescale it to sum to one.

    Parameters
    ----------
    w : array-like
        Sample weights (any shape; flattened with ``reshape(-1)``).
    xp : module
        Array backend namespace.
    ref_arr : array, optional
        Reference array forwarded to ``xp_asarray``.

    Returns
    -------
    array
        ``w / sum(w)`` as a 1-D float64 array.

    Raises
    ------
    ValueError
        If the weights do not sum to a finite, strictly positive value.
    """
    w_arr = xp_asarray(w, dtype=xp.float64, xp=xp, ref_arr=ref_arr).reshape(-1)
    w_sum = _to_float_scalar(xp.sum(w_arr))
    # ``w_sum <= 0.0`` alone is False for NaN, so a NaN/inf total would slip
    # through and turn every normalised weight into NaN or 0.
    if (not np.isfinite(w_sum)) or w_sum <= 0.0:
        raise ValueError("weights must sum to a finite positive value")
    return w_arr / w_sum
709
+
710
+
711
def _kernel_regression_cv_factor(
    *,
    samples_2d,
    targets_2d,
    weights_1d,
    data_cov,
    kernel_name: str,
    regression_mode: str,
    n_eff: float,
    n_features: int,
    xp=np,
) -> tuple[float, Dict[str, Any]]:
    """Choose a kernel-regression bandwidth factor by leave-one-out CV.

    Candidate factors are searched with ``_golden_section_minimize`` on the
    interval ``[lower, upper]`` with ``lower = max(0.2 * f0, 1e-4)`` and
    ``upper = max(5 * f0, 1.01 * lower)``, where ``f0`` is the Scott's-rule
    factor for ``n_eff`` and ``n_features``.  Each candidate is scored by
    ``_kernel_regression_cv_score``.

    Parameters
    ----------
    samples_2d : array-like, shape (n_samples, n_features)
        Regressor values.
    targets_2d : array-like, shape (n_samples,) or (n_samples, n_outputs)
        Responses; converted with ``_as_targets_2d``.
    weights_1d : array-like
        Sample weights; normalised to sum to one.
    data_cov : array-like
        Data covariance passed to the CV score.
    kernel_name : str
        Kernel name forwarded to the CV score.
    regression_mode : str
        Regression mode forwarded to the CV score.
    n_eff : float
        Effective sample size used for the Scott's-rule anchor.
    n_features : int
        Must equal the number of columns of ``samples_2d``.
    xp : module
        Array backend namespace.

    Returns
    -------
    factor : float
        Factor returned by the golden-section search.
    details : dict
        ``cv_objective``, ``cv_score`` (score at ``factor``),
        ``cv_score_best_seen`` (lowest score evaluated during the search),
        ``cv_search_lower``, ``cv_search_upper`` and ``cv_regression_mode``.

    Raises
    ------
    ValueError
        If the samples are not 2-D, ``n_features`` disagrees with their
        column count, or the search yields a non-finite or non-positive factor.
    """
    x = xp_asarray(samples_2d, dtype=xp.float64, xp=xp)
    # Check rank before reading ``x.shape[1]``; a 1-D input would otherwise
    # fail with an opaque IndexError below.
    if x.ndim != 2:
        raise ValueError("samples must be a 2D array")
    y = _as_targets_2d(targets_2d, int(x.shape[0]), xp=xp, ref_arr=x)
    w = _normalized_weights(weights_1d, xp=xp, ref_arr=x)
    cov = xp_asarray(data_cov, dtype=xp.float64, xp=xp, ref_arr=x)

    d = int(n_features)
    if d != int(x.shape[1]):
        raise ValueError("n_features is inconsistent with samples")

    # Search window around Scott's rule; the 1e-4 floor keeps every candidate
    # factor strictly positive.
    f0 = float(_bandwidth_factor("scott", n_eff=float(n_eff), n_features=d))
    lower = max(f0 * 0.2, 1e-4)
    upper = max(f0 * 5.0, lower * 1.01)
    tol = max(1e-4, lower * 0.02)

    best_seen = float("inf")

    def _objective(f: float) -> float:
        # Record the lowest score seen so it can be reported alongside the
        # score at the factor the minimiser finally returns.
        nonlocal best_seen
        score = _kernel_regression_cv_score(
            samples_2d=x,
            targets_2d=y,
            weights_norm=w,
            data_cov=cov,
            kernel_name=kernel_name,
            factor=f,
            regression_mode=regression_mode,
            xp=xp,
        )
        if score < best_seen:
            best_seen = score
        return score

    factor = float(_golden_section_minimize(_objective, lower, upper, tol))
    # Reject a bad search result before spending another O(n^2) score
    # evaluation (the CV score builds an n-by-n kernel matrix) on it.
    if (not np.isfinite(factor)) or factor <= 0.0:
        raise ValueError(
            "regression CV bandwidth rule produced a non-finite or non-positive value"
        )
    # The minimiser's returned point may not coincide with an evaluated
    # point, so score it explicitly for the reported ``cv_score``.
    score = _objective(factor)

    details = {
        "cv_objective": "leave_one_out_mse",
        "cv_score": float(score),
        "cv_score_best_seen": float(best_seen),
        "cv_search_lower": float(lower),
        "cv_search_upper": float(upper),
        "cv_regression_mode": str(regression_mode),
    }
    return factor, details
769
+
770
+
771
class _BaseBandwidthSelector:
    """Shared driver that turns a bandwidth specification into a factor.

    ``select`` accepts either a numeric value or a rule name.  Rule names are
    normalised with ``strip().lower()`` and dispatched in this order:

    1. ``_select_special`` -- a hook subclasses override for estimator-specific
       rules (the base implementation handles nothing);
    2. ``"nrd0"`` / ``"nrd"`` -- normal-reference rules;
    3. ``"ucv"``, ``"bcv"``, ``"sj"``, ``"sj-ste"``, ``"sj-dpi"`` -- R-style
       selectors, allowed only when ``enable_r_selectors`` is true;
    4. any other string, and every non-string input -- ``_bandwidth_factor``.

    For rule families 2 and 3, inputs with more than one feature are reduced
    to a single projected dimension via ``_multivariate_factor_from_projected_1d``.
    """

    def __init__(
        self,
        *,
        estimator: str,
        n_eff: float,
        n_features: int,
        samples_2d,
        weights_1d,
        data_cov,
        xp,
        enable_r_selectors: bool,
        weighted_r_selector_strategy: str,
        multivariate_selector_strategy: str,
    ):
        """Store the problem description and normalise option strings.

        Raises
        ------
        ValueError
            If ``n_eff`` is not finite and positive, or ``n_features`` <= 0
            (the option-normalisation helpers may raise first).
        """
        self.estimator = _normalize_estimator_name(estimator)
        self.n_eff = float(n_eff)
        self.n_features = int(n_features)
        self.samples_2d = samples_2d
        self.weights_1d = weights_1d
        self.data_cov = data_cov
        self.xp = xp
        self.enable_r_selectors = bool(enable_r_selectors)
        self.weighted_r_selector_strategy = _normalize_weighted_strategy(
            weighted_r_selector_strategy
        )
        self.multivariate_selector_strategy = _normalize_multivariate_strategy(
            multivariate_selector_strategy
        )

        # Host-side (NumPy) copy of the normalised weights, used for metadata
        # and to decide whether the sample is genuinely weighted.
        flat_weights = np.asarray(_to_numpy(weights_1d), dtype=np.float64).reshape(-1)
        self.weights_np = _normalized_weights_numpy(flat_weights)
        weight_spread = float(np.max(self.weights_np) - np.min(self.weights_np))
        self.is_weighted = weight_spread > 1e-12

        if not (np.isfinite(self.n_eff) and self.n_eff > 0.0):
            raise ValueError("n_eff must be a finite positive scalar")
        if not self.n_features > 0:
            raise ValueError("n_features must be a positive integer")

    def _base_details(self, bandwidth: Union[str, float, int]) -> Dict[str, Any]:
        """Start the metadata dict shared by every selection result."""
        return dict(
            input_bandwidth=bandwidth,
            n_samples=int(self.weights_np.size),
            estimator=self.estimator,
        )

    def _select_special(self, bw_name: str, details: Dict[str, Any]) -> Optional[BandwidthSelectionResult]:
        """Hook for estimator-specific rule names.

        A non-``None`` return short-circuits ``select``; the base class
        recognises no special rules.
        """
        return None

    def _make_result(
        self,
        factor,
        *,
        method: str,
        used_r_selector: bool,
        weighted_strategy: str,
        multivariate_strategy: str,
        selector_dimension: int,
        details: Dict[str, Any],
    ) -> BandwidthSelectionResult:
        """Package a factor and its metadata into a ``BandwidthSelectionResult``."""
        return BandwidthSelectionResult(
            factor=float(factor),
            method=method,
            n_features=self.n_features,
            n_eff=self.n_eff,
            used_r_selector=used_r_selector,
            weighted=self.is_weighted,
            weighted_strategy=weighted_strategy,
            multivariate_strategy=multivariate_strategy,
            selector_dimension=selector_dimension,
            details=details,
        )

    def _projected_factor(self, rule: str, rule_kind: str, details: Dict[str, Any]):
        """Apply *rule* to a 1-D projection of the data; merge its metadata into *details*."""
        factor, projection_details = _multivariate_factor_from_projected_1d(
            rule,
            samples_2d=self.samples_2d,
            weights_1d=self.weights_1d,
            data_cov=self.data_cov,
            n_eff=self.n_eff,
            rule_kind=rule_kind,
            weighted_r_selector_strategy=self.weighted_r_selector_strategy,
        )
        details.update(projection_details)
        return factor

    def _select_nrd(self, rule: str, details: Dict[str, Any], weighted_used: str) -> BandwidthSelectionResult:
        """Resolve the ``nrd0`` / ``nrd`` normal-reference rules."""
        if self.n_features == 1:
            factor = _bandwidth_factor_1d_nrd(
                rule,
                n_eff=self.n_eff,
                samples_2d=self.samples_2d,
                data_cov=self.data_cov,
                xp=self.xp,
            )
            multi_used = "none"
        else:
            multi_used = self.multivariate_selector_strategy
            factor = self._projected_factor(rule, "nrd", details)

        # Set after any projection metadata is merged, so "nrd" always wins.
        details["rule"] = "nrd"
        return self._make_result(
            factor,
            method=rule,
            used_r_selector=False,
            weighted_strategy=weighted_used,
            multivariate_strategy=multi_used,
            selector_dimension=1,
            details=details,
        )

    def _select_r_style(self, rule: str, details: Dict[str, Any], weighted_used: str) -> BandwidthSelectionResult:
        """Resolve the R-style ``ucv``/``bcv``/``sj*`` selectors."""
        if not self.enable_r_selectors:
            raise ValueError("R-style bandwidth selectors are disabled for this estimator")

        # Set before projection metadata is merged (original ordering kept).
        details["rule"] = "r_selector"
        if self.n_features == 1:
            factor = _bandwidth_factor_1d_r_selectors(
                rule,
                samples_2d=self.samples_2d,
                weights_1d=self.weights_1d,
                data_cov=self.data_cov,
                weighted_strategy=self.weighted_r_selector_strategy,
            )
            multi_used = "none"
        else:
            multi_used = self.multivariate_selector_strategy
            factor = self._projected_factor(rule, "r_selector", details)

        return self._make_result(
            factor,
            method=rule,
            used_r_selector=True,
            weighted_strategy=weighted_used,
            multivariate_strategy=multi_used,
            selector_dimension=1,
            details=details,
        )

    def select(self, bandwidth: Union[str, float, int]) -> BandwidthSelectionResult:
        """Resolve *bandwidth* (rule name or numeric value) to a selection result.

        Raises
        ------
        ValueError
            If an R-style selector is requested while ``enable_r_selectors``
            is false; helper functions may raise their own errors.
        """
        details = self._base_details(bandwidth)
        weighted_used = self.weighted_r_selector_strategy if self.is_weighted else "uniform"

        if not isinstance(bandwidth, str):
            method_label = "scalar"
        else:
            rule = bandwidth.strip().lower()

            special = self._select_special(rule, details)
            if special is not None:
                return special

            if rule in ("nrd0", "nrd"):
                return self._select_nrd(rule, details, weighted_used)
            if rule in ("ucv", "bcv", "sj", "sj-ste", "sj-dpi"):
                return self._select_r_style(rule, details, weighted_used)
            method_label = rule

        # Generic path: the raw (un-normalised) input is handed to _bandwidth_factor.
        factor = _bandwidth_factor(
            bandwidth,
            n_eff=self.n_eff,
            n_features=self.n_features,
        )
        return self._make_result(
            factor,
            method=method_label,
            used_r_selector=False,
            weighted_strategy=weighted_used,
            multivariate_strategy="none",
            selector_dimension=self.n_features,
            details=details,
        )
931
+
932
+
933
class _KDEBandwidthSelector(_BaseBandwidthSelector):
    """Bandwidth selector for kernel density estimation.

    Adds no rules of its own: every specification is handled by
    ``_BaseBandwidthSelector.select``.
    """
935
+
936
+
937
class _KernelRegressionBandwidthSelector(_BaseBandwidthSelector):
    """Bandwidth selector for kernel regression.

    On top of the base rules it recognises leave-one-out cross-validation
    rules: ``"cv"`` and ``"cv_ls"`` use the configured regression mode, while
    ``"cv-nw"`` and ``"cv-ll"`` force the ``"nw"`` and ``"local_linear"``
    modes respectively.
    """

    def __init__(self, *, targets, regression: str, kernel: str, **kwargs):
        # Base-class construction (and its validation) happens first.
        super().__init__(**kwargs)
        self.targets = targets
        self.regression = _normalize_regression_mode(regression)
        self.kernel = str(kernel).strip().lower()

    def _select_special(self, bw_name: str, details: Dict[str, Any]) -> Optional[BandwidthSelectionResult]:
        """Run regression CV for CV rule names; return ``None`` for anything else."""
        forced_modes = {"cv-ll": "local_linear", "cv-nw": "nw"}
        if bw_name in forced_modes:
            cv_mode = forced_modes[bw_name]
        elif bw_name in ("cv", "cv_ls"):
            cv_mode = self.regression
        else:
            return None

        factor, cv_details = _kernel_regression_cv_factor(
            samples_2d=self.samples_2d,
            targets_2d=self.targets,
            weights_1d=self.weights_1d,
            data_cov=self.data_cov,
            kernel_name=self.kernel,
            regression_mode=cv_mode,
            n_eff=self.n_eff,
            n_features=self.n_features,
            xp=self.xp,
        )

        details["rule"] = "regression_cv"
        details.update(cv_details)

        strategy = self.weighted_r_selector_strategy if self.is_weighted else "uniform"
        return BandwidthSelectionResult(
            factor=float(factor),
            method=bw_name,
            n_features=self.n_features,
            n_eff=self.n_eff,
            used_r_selector=False,
            weighted=self.is_weighted,
            weighted_strategy=strategy,
            multivariate_strategy="none",
            selector_dimension=self.n_features,
            details=details,
        )
984
+
985
+
986
def select_bandwidth(
    bandwidth: Union[str, float, int],
    *,
    n_eff: float,
    n_features: int,
    samples_2d,
    weights_1d,
    data_cov,
    xp,
    enable_r_selectors: bool = True,
    weighted_r_selector_strategy: str = "quantile_resample",
    multivariate_selector_strategy: str = "projection_pca_1d",
    estimator: str = "kde",
    targets=None,
    regression: str = "nw",
    kernel: str = "gaussian",
) -> BandwidthSelectionResult:
    """Select bandwidth factor and return diagnostic metadata.

    The normalised ``estimator`` name picks the selector: ``"kde"`` uses the
    density selector; any other name uses the kernel-regression selector,
    which additionally receives ``targets``, ``regression`` and ``kernel``
    (these three are ignored for KDE).
    """
    estimator_name = _normalize_estimator_name(estimator)
    shared_kwargs = dict(
        estimator=estimator_name,
        n_eff=n_eff,
        n_features=n_features,
        samples_2d=samples_2d,
        weights_1d=weights_1d,
        data_cov=data_cov,
        xp=xp,
        enable_r_selectors=enable_r_selectors,
        weighted_r_selector_strategy=weighted_r_selector_strategy,
        multivariate_selector_strategy=multivariate_selector_strategy,
    )

    if estimator_name != "kde":
        selector = _KernelRegressionBandwidthSelector(
            targets=targets,
            regression=regression,
            kernel=kernel,
            **shared_kwargs,
        )
    else:
        selector = _KDEBandwidthSelector(**shared_kwargs)

    return selector.select(bandwidth)
1037
+
1038
+
1039
def select_bandwidth_factor(
    bandwidth: Union[str, float, int],
    *,
    n_eff: float,
    n_features: int,
    samples_2d,
    weights_1d,
    data_cov,
    xp,
    enable_r_selectors: bool = True,
    weighted_r_selector_strategy: str = "quantile_resample",
    multivariate_selector_strategy: str = "projection_pca_1d",
    estimator: str = "kde",
    targets=None,
    regression: str = "nw",
    kernel: str = "gaussian",
) -> float:
    """Select bandwidth factor for kernel estimators.

    Thin wrapper over ``select_bandwidth`` that discards the metadata and
    returns only the factor as a Python ``float``.
    """
    options = dict(
        n_eff=n_eff,
        n_features=n_features,
        samples_2d=samples_2d,
        weights_1d=weights_1d,
        data_cov=data_cov,
        xp=xp,
        enable_r_selectors=enable_r_selectors,
        weighted_r_selector_strategy=weighted_r_selector_strategy,
        multivariate_selector_strategy=multivariate_selector_strategy,
        estimator=estimator,
        targets=targets,
        regression=regression,
        kernel=kernel,
    )
    selection = select_bandwidth(bandwidth, **options)
    return float(selection.factor)
1074
+
1075
+
1076
# Explicit export list.  The underscore-prefixed factor helpers are listed so
# that ``from <module> import *`` still exposes them (star-imports skip
# leading-underscore names unless they appear in ``__all__``).
__all__ = [
    "BandwidthSelectionResult",
    "_bandwidth_factor", "_bandwidth_factor_1d_nrd", "_bandwidth_factor_1d_r_selectors",
    "select_bandwidth", "select_bandwidth_factor",
]