statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,870 @@
1
+ """Fixed-X knockoff feature selection skeleton (CPU/GPU)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ import numpy as np
9
+
10
+ from statgpu.feature_selection import _knockoff_utils as _kutils
11
+ from statgpu.feature_selection._knockoff_utils import (
12
+ _build_fixed_x_knockoffs,
13
+ _build_model_x_knockoffs,
14
+ _build_model_x_knockoffs_knockpy_compat,
15
+ _compute_w_statistics,
16
+ _get_xp,
17
+ _knockoff_threshold_and_path,
18
+ _model_x_draw_seed,
19
+ _normalize_compat_mode,
20
+ _normalize_fdr_control,
21
+ _normalize_knockoff_type,
22
+ _normalize_lasso_fast_profile,
23
+ _resolve_backend,
24
+ _standardize_design,
25
+ _standardize_features_unit_variance,
26
+ _to_numpy,
27
+ _validate_q,
28
+ )
29
+
30
+ # Backward compatibility for existing internal profiling scripts.
31
+ _random_permutation_inds = _kutils._random_permutation_inds
32
+
33
+
34
+ @dataclass
35
+ class KnockoffResult:
36
+ """Structured output for knockoff selection."""
37
+
38
+ knockoff_type: str
39
+ selected_features: np.ndarray
40
+ W: np.ndarray
41
+ threshold: float
42
+ q: float
43
+ estimated_fdr: float
44
+ q_trajectory: List[Dict[str, float]]
45
+ method: str
46
+ fdr_control: str
47
+ random_state: Optional[int]
48
+ backend: str
49
+ metadata: Dict[str, Any] = field(default_factory=dict)
50
+
51
+ def to_dict(self) -> Dict[str, Any]:
52
+ return {
53
+ "knockoff_type": self.knockoff_type,
54
+ "selected_features": self.selected_features.tolist(),
55
+ "W": self.W.tolist(),
56
+ "threshold": float(self.threshold),
57
+ "q": float(self.q),
58
+ "estimated_fdr": float(self.estimated_fdr),
59
+ "q_trajectory": list(self.q_trajectory),
60
+ "method": self.method,
61
+ "fdr_control": self.fdr_control,
62
+ "random_state": self.random_state,
63
+ "backend": self.backend,
64
+ "metadata": self.metadata,
65
+ }
66
+
67
+
68
+ # -----------------------------------------------------------------------------
69
+ # Knockpy-compatible interface placeholders
70
+ # -----------------------------------------------------------------------------
71
+ def knockpy_gaussian_mvr_sampler(
72
+ X,
73
+ *,
74
+ mu=None,
75
+ Sigma=None,
76
+ groups=None,
77
+ random_state: Optional[int] = None,
78
+ ):
79
+ """Interface placeholder for knockpy GaussianSampler(method='mvr')."""
80
+ raise NotImplementedError("knockpy_gaussian_mvr_sampler is not yet implemented")
81
+
82
+
83
+ def knockpy_gaussian_sdp_sampler(
84
+ X,
85
+ *,
86
+ mu=None,
87
+ Sigma=None,
88
+ groups=None,
89
+ random_state: Optional[int] = None,
90
+ ):
91
+ """Interface placeholder for knockpy GaussianSampler(method='sdp')."""
92
+ raise NotImplementedError("knockpy_gaussian_sdp_sampler is not yet implemented")
93
+
94
+
95
+ def knockpy_gaussian_maxent_sampler(
96
+ X,
97
+ *,
98
+ mu=None,
99
+ Sigma=None,
100
+ groups=None,
101
+ random_state: Optional[int] = None,
102
+ ):
103
+ """Interface placeholder for knockpy GaussianSampler(method='maxent')."""
104
+ raise NotImplementedError("knockpy_gaussian_maxent_sampler is not yet implemented")
105
+
106
+
107
+ def knockpy_gaussian_equi_sampler(
108
+ X,
109
+ *,
110
+ mu=None,
111
+ Sigma=None,
112
+ groups=None,
113
+ random_state: Optional[int] = None,
114
+ ):
115
+ """Interface placeholder for knockpy GaussianSampler(method='equi')."""
116
+ raise NotImplementedError("knockpy_gaussian_equi_sampler is not yet implemented")
117
+
118
+
119
+ def knockpy_gaussian_ci_sampler(
120
+ X,
121
+ *,
122
+ mu=None,
123
+ Sigma=None,
124
+ groups=None,
125
+ random_state: Optional[int] = None,
126
+ ):
127
+ """Interface placeholder for knockpy GaussianSampler(method='ci')."""
128
+ raise NotImplementedError("knockpy_gaussian_ci_sampler is not yet implemented")
129
+
130
+
131
+ def knockpy_fx_sampler(
132
+ X,
133
+ *,
134
+ y=None,
135
+ groups=None,
136
+ random_state: Optional[int] = None,
137
+ ):
138
+ """Interface placeholder for knockpy FXSampler."""
139
+ raise NotImplementedError("knockpy_fx_sampler is not yet implemented")
140
+
141
+
142
+ def knockpy_metro_sampler(
143
+ X,
144
+ *,
145
+ y=None,
146
+ groups=None,
147
+ random_state: Optional[int] = None,
148
+ ):
149
+ """Interface placeholder for knockpy MetroSampler."""
150
+ raise NotImplementedError("knockpy_metro_sampler is not yet implemented")
151
+
152
+
153
+ def knockpy_artk_sampler(
154
+ X,
155
+ *,
156
+ y=None,
157
+ groups=None,
158
+ random_state: Optional[int] = None,
159
+ ):
160
+ """Interface placeholder for knockpy ARTKSampler."""
161
+ raise NotImplementedError("knockpy_artk_sampler is not yet implemented")
162
+
163
+
164
+ def knockpy_sampler_dispatch(
165
+ sampler: str,
166
+ X,
167
+ *,
168
+ method: Optional[str] = None,
169
+ mu=None,
170
+ Sigma=None,
171
+ y=None,
172
+ groups=None,
173
+ random_state: Optional[int] = None,
174
+ ):
175
+ """Unified dispatcher for knockpy-compatible sampler placeholders."""
176
+ sampler_key = str(sampler).strip().lower().replace("-", "_")
177
+ method_key = None if method is None else str(method).strip().lower().replace("-", "_")
178
+
179
+ gaussian_dispatch = {
180
+ "mvr": knockpy_gaussian_mvr_sampler,
181
+ "sdp": knockpy_gaussian_sdp_sampler,
182
+ "maxent": knockpy_gaussian_maxent_sampler,
183
+ "equi": knockpy_gaussian_equi_sampler,
184
+ "ci": knockpy_gaussian_ci_sampler,
185
+ }
186
+
187
+ if sampler_key == "gaussian":
188
+ method_key = "mvr" if method_key is None else method_key
189
+ fn = gaussian_dispatch.get(method_key)
190
+ if fn is None:
191
+ raise ValueError("For sampler='gaussian', method must be one of: 'mvr', 'sdp', 'maxent', 'equi', 'ci'")
192
+ return fn(
193
+ X,
194
+ mu=mu,
195
+ Sigma=Sigma,
196
+ groups=groups,
197
+ random_state=random_state,
198
+ )
199
+
200
+ gaussian_aliases = {
201
+ "gaussian_mvr": "mvr",
202
+ "gaussian_sdp": "sdp",
203
+ "gaussian_maxent": "maxent",
204
+ "gaussian_equi": "equi",
205
+ "gaussian_ci": "ci",
206
+ }
207
+ if sampler_key in gaussian_aliases:
208
+ fn = gaussian_dispatch[gaussian_aliases[sampler_key]]
209
+ return fn(
210
+ X,
211
+ mu=mu,
212
+ Sigma=Sigma,
213
+ groups=groups,
214
+ random_state=random_state,
215
+ )
216
+
217
+ if sampler_key == "fx":
218
+ return knockpy_fx_sampler(
219
+ X,
220
+ y=y,
221
+ groups=groups,
222
+ random_state=random_state,
223
+ )
224
+
225
+ if sampler_key == "metro":
226
+ return knockpy_metro_sampler(
227
+ X,
228
+ y=y,
229
+ groups=groups,
230
+ random_state=random_state,
231
+ )
232
+
233
+ if sampler_key == "artk":
234
+ return knockpy_artk_sampler(
235
+ X,
236
+ y=y,
237
+ groups=groups,
238
+ random_state=random_state,
239
+ )
240
+
241
+ raise ValueError(
242
+ "sampler must be one of: 'gaussian', 'gaussian_mvr', 'gaussian_sdp', "
243
+ "'gaussian_maxent', 'gaussian_equi', 'gaussian_ci', 'fx', 'metro', 'artk'"
244
+ )
245
+
246
+
247
+ def fixed_x_knockoff_filter(
248
+ X,
249
+ y,
250
+ q: float = 0.1,
251
+ method: str = "corr_diff",
252
+ fdr_control: str = "knockoff_plus",
253
+ random_state: Optional[int] = None,
254
+ backend: str = "auto",
255
+ Xk=None,
256
+ compat_mode: str = "statgpu",
257
+ lasso_cv_impl: str = "auto",
258
+ lasso_fast_profile: str = "off",
259
+ ) -> KnockoffResult:
260
+ """
261
+ Fixed-X knockoff selection skeleton.
262
+
263
+ Parameters
264
+ ----------
265
+ X : array-like of shape (n_samples, n_features)
266
+ Design matrix.
267
+ y : array-like of shape (n_samples,)
268
+ Response vector.
269
+ q : float, default=0.1
270
+ Target FDR level in (0, 1).
271
+ method : {'corr_diff', 'ols_coef_diff', 'lasso_coef_diff'}, default='corr_diff'
272
+ Feature-importance statistic for W construction.
273
+ fdr_control : {'knockoff_plus', 'knockoff'}, default='knockoff_plus'
274
+ Knockoff threshold variant.
275
+ random_state : int, optional
276
+ Random seed for knockoff construction.
277
+ backend : {'auto', 'numpy', 'cupy', 'torch'}, default='auto'
278
+ Compute backend. ``'auto'`` infers from input arrays.
279
+ Use ``'torch'`` for PyTorch GPU acceleration.
280
+
281
+ Returns
282
+ -------
283
+ KnockoffResult
284
+ Selected feature indices and full knockoff diagnostics.
285
+ """
286
+ q_f = _validate_q(q)
287
+ compat = _normalize_compat_mode(compat_mode)
288
+ lasso_impl = str(lasso_cv_impl).strip().lower()
289
+ if lasso_impl == "auto":
290
+ lasso_impl = "sklearn" if compat == "knockpy" else "statgpu"
291
+ lasso_profile = _normalize_lasso_fast_profile(lasso_fast_profile)
292
+
293
+ offset = _normalize_fdr_control(fdr_control)
294
+ backend_name = _resolve_backend(backend, X, y, Xk)
295
+ xp = _get_xp(backend_name)
296
+
297
+ X_arr = xp.asarray(X, dtype=xp.float64)
298
+ y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
299
+ if X_arr.ndim != 2:
300
+ raise ValueError("X must be a 2D array")
301
+ if y_arr.shape[0] != X_arr.shape[0]:
302
+ raise ValueError("y must have the same number of rows as X")
303
+
304
+ if Xk is None:
305
+ X_work = _standardize_design(X_arr, xp)
306
+ X_knock = _build_fixed_x_knockoffs(X_work, random_state=random_state, xp=xp)
307
+ xk_source = "generated_fixed_x"
308
+ else:
309
+ X_work = X_arr
310
+ X_knock = xp.asarray(Xk, dtype=xp.float64)
311
+ if X_knock.shape != X_work.shape:
312
+ raise ValueError("Xk must have the same shape as X")
313
+ xk_source = "provided"
314
+
315
+ W_xp, method_n = _compute_w_statistics(
316
+ X_work,
317
+ X_knock,
318
+ y_arr,
319
+ method=method,
320
+ xp=xp,
321
+ random_state=random_state,
322
+ backend_name=backend_name,
323
+ lasso_cv_impl=lasso_impl,
324
+ lasso_fast_profile=lasso_profile,
325
+ lasso_knockpy_style=(compat == "knockpy"),
326
+ )
327
+ threshold, fdr_hat, trajectory = _knockoff_threshold_and_path(W_xp, q=q_f, offset=offset)
328
+
329
+ if np.isfinite(threshold):
330
+ selected_xp = xp.where(W_xp >= threshold)[0]
331
+ else:
332
+ selected_xp = xp.asarray([], dtype=xp.int64)
333
+
334
+ W = _to_numpy(W_xp).astype(np.float64, copy=False)
335
+ selected = _to_numpy(selected_xp).astype(np.int64, copy=False)
336
+
337
+ return KnockoffResult(
338
+ knockoff_type="fixed_x",
339
+ selected_features=selected,
340
+ W=W,
341
+ threshold=float(threshold),
342
+ q=float(q_f),
343
+ estimated_fdr=float(fdr_hat),
344
+ q_trajectory=trajectory,
345
+ method=method_n,
346
+ fdr_control="knockoff_plus" if offset == 1 else "knockoff",
347
+ random_state=random_state,
348
+ backend=backend_name,
349
+ metadata={
350
+ "n_samples": int(X_arr.shape[0]),
351
+ "n_features": int(X_arr.shape[1]),
352
+ "offset": int(offset),
353
+ "compat_mode": compat,
354
+ "lasso_cv_impl": lasso_impl,
355
+ "lasso_fast_profile": lasso_profile,
356
+ "xk_source": xk_source,
357
+ },
358
+ )
359
+
360
+ def model_x_knockoff_filter(
361
+ X,
362
+ y,
363
+ q: float = 0.1,
364
+ method: str = "corr_diff",
365
+ fdr_control: str = "knockoff_plus",
366
+ random_state: Optional[int] = None,
367
+ backend: str = "auto",
368
+ Xk=None,
369
+ compat_mode: str = "statgpu",
370
+ lasso_cv_impl: str = "auto",
371
+ lasso_fast_profile: str = "off",
372
+ modelx_covariance_shrinkage: float = 0.20,
373
+ modelx_s_scale: float = 0.999,
374
+ modelx_draws: Optional[int] = None,
375
+ modelx_shrinkage: str = "ledoitwolf",
376
+ modelx_smatrix_method: str = "mvr",
377
+ knockpy_sampler: Optional[str] = None,
378
+ knockpy_sampler_method: Optional[str] = None,
379
+ ) -> KnockoffResult:
380
+ """
381
+ Model-X knockoff selection (Gaussian second-order approximation).
382
+
383
+ This implementation estimates a Gaussian feature model and builds
384
+ equi-correlated knockoffs from the estimated covariance.
385
+ """
386
+ q_f = _validate_q(q)
387
+ compat = _normalize_compat_mode(compat_mode)
388
+ lasso_impl = str(lasso_cv_impl).strip().lower()
389
+ if lasso_impl == "auto":
390
+ lasso_impl = "sklearn" if compat == "knockpy" else "statgpu"
391
+ lasso_profile = _normalize_lasso_fast_profile(lasso_fast_profile)
392
+
393
+ offset = _normalize_fdr_control(fdr_control)
394
+ backend_name = _resolve_backend(backend, X, y, Xk)
395
+ xp = _get_xp(backend_name)
396
+
397
+ X_arr = xp.asarray(X, dtype=xp.float64)
398
+ y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
399
+ if X_arr.ndim != 2:
400
+ raise ValueError("X must be a 2D array")
401
+ if y_arr.shape[0] != X_arr.shape[0]:
402
+ raise ValueError("y must have the same number of rows as X")
403
+
404
+ method_key = str(method).strip().lower()
405
+ default_draws = (
406
+ 5
407
+ if method_key in ("ols_coef_diff", "ols", "coef_diff", "lasso_coef_diff", "lasso", "lasso_diff")
408
+ else 3
409
+ )
410
+
411
+ if compat == "knockpy":
412
+ # Preserve backend-native execution when caller supplies Xk and explicitly
413
+ # asks for statgpu CV implementation under knockpy-compatible lasso settings.
414
+ if Xk is not None and lasso_impl == "statgpu":
415
+ X_knock = xp.asarray(Xk, dtype=xp.float64)
416
+ if X_knock.shape != X_arr.shape:
417
+ raise ValueError("Xk must have the same shape as X")
418
+
419
+ W_xp, method_n = _compute_w_statistics(
420
+ X_arr,
421
+ X_knock,
422
+ y_arr,
423
+ method=method,
424
+ xp=xp,
425
+ random_state=random_state,
426
+ backend_name=backend_name,
427
+ lasso_cv_impl=lasso_impl,
428
+ lasso_fast_profile=lasso_profile,
429
+ lasso_knockpy_style=True,
430
+ )
431
+ threshold, fdr_hat, trajectory = _knockoff_threshold_and_path(W_xp, q=q_f, offset=offset)
432
+
433
+ if np.isfinite(threshold):
434
+ selected_xp = xp.where(W_xp >= threshold)[0]
435
+ else:
436
+ selected_xp = xp.asarray([], dtype=xp.int64)
437
+
438
+ W_np = _to_numpy(W_xp).astype(np.float64, copy=False)
439
+ selected = _to_numpy(selected_xp).astype(np.int64, copy=False)
440
+
441
+ return KnockoffResult(
442
+ knockoff_type="model_x",
443
+ selected_features=selected,
444
+ W=W_np,
445
+ threshold=float(threshold),
446
+ q=float(q_f),
447
+ estimated_fdr=float(fdr_hat),
448
+ q_trajectory=trajectory,
449
+ method=method_n,
450
+ fdr_control="knockoff_plus" if offset == 1 else "knockoff",
451
+ random_state=random_state,
452
+ backend=backend_name,
453
+ metadata={
454
+ "n_samples": int(X_arr.shape[0]),
455
+ "n_features": int(X_arr.shape[1]),
456
+ "offset": int(offset),
457
+ "n_modelx_draws": 1,
458
+ "compat_mode": compat,
459
+ "lasso_cv_impl": lasso_impl,
460
+ "lasso_fast_profile": lasso_profile,
461
+ "modelx_shrinkage": None,
462
+ "modelx_smatrix_method": None,
463
+ "knockpy_sampler": knockpy_sampler,
464
+ "knockpy_sampler_method": knockpy_sampler_method,
465
+ "xk_source": "provided",
466
+ },
467
+ )
468
+
469
+ X_np = np.asarray(_to_numpy(X_arr), dtype=np.float64)
470
+ y_np = np.asarray(_to_numpy(y_arr), dtype=np.float64).reshape(-1)
471
+
472
+ draw_specs = []
473
+ if Xk is not None:
474
+ Xk_np = np.asarray(_to_numpy(Xk), dtype=np.float64)
475
+ if Xk_np.shape != X_np.shape:
476
+ raise ValueError("Xk must have the same shape as X")
477
+ draw_specs.append((Xk_np, random_state, {"xk_source": "provided"}))
478
+ else:
479
+ if knockpy_sampler is not None:
480
+ Xk_draw = knockpy_sampler_dispatch(
481
+ knockpy_sampler,
482
+ X_np,
483
+ method=knockpy_sampler_method,
484
+ y=y_np,
485
+ random_state=random_state,
486
+ )
487
+ if Xk_draw is None:
488
+ raise NotImplementedError(
489
+ "Selected knockpy sampler interface is a placeholder (pass). "
490
+ "Implement the sampler before enabling this route."
491
+ )
492
+ Xk_draw = np.asarray(Xk_draw, dtype=np.float64)
493
+ if Xk_draw.shape != X_np.shape:
494
+ raise ValueError("Dispatched knockoff matrix must have the same shape as X")
495
+ draw_specs.append(
496
+ (
497
+ Xk_draw,
498
+ random_state,
499
+ {
500
+ "xk_source": "generated_model_x_dispatch",
501
+ "knockpy_sampler": str(knockpy_sampler),
502
+ "knockpy_sampler_method": None
503
+ if knockpy_sampler_method is None
504
+ else str(knockpy_sampler_method),
505
+ },
506
+ )
507
+ )
508
+ else:
509
+ n_modelx_draws = max(1, int(default_draws if modelx_draws is None else modelx_draws))
510
+ for draw_idx in range(n_modelx_draws):
511
+ draw_seed = _model_x_draw_seed(random_state, draw_idx)
512
+ Xk_draw, draw_meta = _build_model_x_knockoffs_knockpy_compat(
513
+ X_np,
514
+ random_state=draw_seed,
515
+ modelx_shrinkage=modelx_shrinkage,
516
+ modelx_smatrix_method=modelx_smatrix_method,
517
+ )
518
+ draw_specs.append((Xk_draw, draw_seed, draw_meta))
519
+
520
+ W_acc = None
521
+ method_n = "corr_diff"
522
+ model_meta: Dict[str, Any] = {"xk_source": "generated_model_x"}
523
+ for Xk_draw, draw_seed, draw_meta in draw_specs:
524
+ W_draw, method_n = _compute_w_statistics(
525
+ X_np,
526
+ Xk_draw,
527
+ y_np,
528
+ method=method,
529
+ xp=np,
530
+ random_state=draw_seed,
531
+ backend_name="numpy",
532
+ lasso_cv_impl=lasso_impl,
533
+ lasso_fast_profile=lasso_profile,
534
+ lasso_knockpy_style=True,
535
+ )
536
+ W_acc = W_draw if W_acc is None else (W_acc + W_draw)
537
+ model_meta.update(draw_meta)
538
+
539
+ n_modelx_draws = int(len(draw_specs))
540
+ W_np = np.asarray(W_acc / float(max(1, n_modelx_draws)), dtype=np.float64)
541
+ threshold, fdr_hat, trajectory = _knockoff_threshold_and_path(W_np, q=q_f, offset=offset)
542
+
543
+ if np.isfinite(threshold):
544
+ selected = np.where(W_np >= threshold)[0].astype(np.int64, copy=False)
545
+ else:
546
+ selected = np.asarray([], dtype=np.int64)
547
+
548
+ return KnockoffResult(
549
+ knockoff_type="model_x",
550
+ selected_features=selected,
551
+ W=W_np,
552
+ threshold=float(threshold),
553
+ q=float(q_f),
554
+ estimated_fdr=float(fdr_hat),
555
+ q_trajectory=trajectory,
556
+ method=method_n,
557
+ fdr_control="knockoff_plus" if offset == 1 else "knockoff",
558
+ random_state=random_state,
559
+ backend="numpy",
560
+ metadata={
561
+ "n_samples": int(X_np.shape[0]),
562
+ "n_features": int(X_np.shape[1]),
563
+ "offset": int(offset),
564
+ "n_modelx_draws": int(n_modelx_draws),
565
+ "compat_mode": compat,
566
+ "lasso_cv_impl": lasso_impl,
567
+ "lasso_fast_profile": lasso_profile,
568
+ "modelx_shrinkage": str(modelx_shrinkage),
569
+ "modelx_smatrix_method": str(modelx_smatrix_method),
570
+ "knockpy_sampler": knockpy_sampler,
571
+ "knockpy_sampler_method": knockpy_sampler_method,
572
+ **model_meta,
573
+ },
574
+ )
575
+
576
+ # statgpu default path
577
+ if Xk is not None:
578
+ X_work = X_arr
579
+ X_knock = xp.asarray(Xk, dtype=xp.float64)
580
+ if X_knock.shape != X_work.shape:
581
+ raise ValueError("Xk must have the same shape as X")
582
+
583
+ W_xp, method_n = _compute_w_statistics(
584
+ X_work,
585
+ X_knock,
586
+ y_arr,
587
+ method=method,
588
+ xp=xp,
589
+ random_state=random_state,
590
+ backend_name=backend_name,
591
+ lasso_cv_impl=lasso_impl,
592
+ lasso_fast_profile=lasso_profile,
593
+ lasso_knockpy_style=False,
594
+ )
595
+ n_modelx_draws = 1
596
+ model_meta = {
597
+ "xk_source": "provided",
598
+ "covariance_shrinkage": None,
599
+ "s_scale": None,
600
+ "modelx_shrinkage": None,
601
+ "modelx_smatrix_method": None,
602
+ }
603
+ else:
604
+ X_std = _standardize_features_unit_variance(X_arr, xp)
605
+ n_modelx_draws = max(1, int(default_draws if modelx_draws is None else modelx_draws))
606
+
607
+ W_acc = None
608
+ method_n = "corr_diff"
609
+ model_meta: Dict[str, Any] = {"xk_source": "generated_model_x"}
610
+ for draw_idx in range(n_modelx_draws):
611
+ draw_seed = _model_x_draw_seed(random_state, draw_idx)
612
+ X_knock, _draw_meta = _build_model_x_knockoffs(
613
+ X_std,
614
+ random_state=draw_seed,
615
+ xp=xp,
616
+ covariance_shrinkage=float(modelx_covariance_shrinkage),
617
+ s_scale=float(modelx_s_scale),
618
+ )
619
+ model_meta.update(_draw_meta)
620
+ W_draw, method_n = _compute_w_statistics(
621
+ X_std,
622
+ X_knock,
623
+ y_arr,
624
+ method=method,
625
+ xp=xp,
626
+ random_state=draw_seed,
627
+ backend_name=backend_name,
628
+ lasso_cv_impl=lasso_impl,
629
+ lasso_fast_profile=lasso_profile,
630
+ lasso_knockpy_style=False,
631
+ )
632
+ W_acc = W_draw if W_acc is None else (W_acc + W_draw)
633
+
634
+ W_xp = W_acc / float(n_modelx_draws)
635
+
636
+ threshold, fdr_hat, trajectory = _knockoff_threshold_and_path(W_xp, q=q_f, offset=offset)
637
+
638
+ if np.isfinite(threshold):
639
+ selected_xp = xp.where(W_xp >= threshold)[0]
640
+ else:
641
+ selected_xp = xp.asarray([], dtype=xp.int64)
642
+
643
+ W = _to_numpy(W_xp).astype(np.float64, copy=False)
644
+ selected = _to_numpy(selected_xp).astype(np.int64, copy=False)
645
+
646
+ return KnockoffResult(
647
+ knockoff_type="model_x",
648
+ selected_features=selected,
649
+ W=W,
650
+ threshold=float(threshold),
651
+ q=float(q_f),
652
+ estimated_fdr=float(fdr_hat),
653
+ q_trajectory=trajectory,
654
+ method=method_n,
655
+ fdr_control="knockoff_plus" if offset == 1 else "knockoff",
656
+ random_state=random_state,
657
+ backend=backend_name,
658
+ metadata={
659
+ "n_samples": int(X_arr.shape[0]),
660
+ "n_features": int(X_arr.shape[1]),
661
+ "offset": int(offset),
662
+ "n_modelx_draws": int(n_modelx_draws),
663
+ "compat_mode": compat,
664
+ "lasso_cv_impl": lasso_impl,
665
+ "lasso_fast_profile": lasso_profile,
666
+ "modelx_covariance_shrinkage": float(modelx_covariance_shrinkage),
667
+ "modelx_s_scale": float(modelx_s_scale),
668
+ **model_meta,
669
+ },
670
+ )
671
+
672
+
673
+ def knockoff_filter(
674
+ X,
675
+ y,
676
+ knockoff_type: str = "fixed_x",
677
+ q: float = 0.1,
678
+ method: str = "corr_diff",
679
+ fdr_control: str = "knockoff_plus",
680
+ random_state: Optional[int] = None,
681
+ backend: str = "auto",
682
+ Xk=None,
683
+ compat_mode: str = "statgpu",
684
+ lasso_cv_impl: str = "auto",
685
+ lasso_fast_profile: str = "off",
686
+ modelx_covariance_shrinkage: float = 0.20,
687
+ modelx_s_scale: float = 0.999,
688
+ modelx_draws: Optional[int] = None,
689
+ modelx_shrinkage: str = "ledoitwolf",
690
+ modelx_smatrix_method: str = "mvr",
691
+ knockpy_sampler: Optional[str] = None,
692
+ knockpy_sampler_method: Optional[str] = None,
693
+ ) -> KnockoffResult:
694
+ """Unified knockoff entrypoint for fixed-X and model-X variants."""
695
+ kind = _normalize_knockoff_type(knockoff_type)
696
+ if kind == "fixed_x":
697
+ return fixed_x_knockoff_filter(
698
+ X,
699
+ y,
700
+ q=q,
701
+ method=method,
702
+ fdr_control=fdr_control,
703
+ random_state=random_state,
704
+ backend=backend,
705
+ Xk=Xk,
706
+ compat_mode=compat_mode,
707
+ lasso_cv_impl=lasso_cv_impl,
708
+ lasso_fast_profile=lasso_fast_profile,
709
+ )
710
+
711
+ return model_x_knockoff_filter(
712
+ X,
713
+ y,
714
+ q=q,
715
+ method=method,
716
+ fdr_control=fdr_control,
717
+ random_state=random_state,
718
+ backend=backend,
719
+ Xk=Xk,
720
+ compat_mode=compat_mode,
721
+ lasso_cv_impl=lasso_cv_impl,
722
+ lasso_fast_profile=lasso_fast_profile,
723
+ modelx_covariance_shrinkage=modelx_covariance_shrinkage,
724
+ modelx_s_scale=modelx_s_scale,
725
+ modelx_draws=modelx_draws,
726
+ modelx_shrinkage=modelx_shrinkage,
727
+ modelx_smatrix_method=modelx_smatrix_method,
728
+ knockpy_sampler=knockpy_sampler,
729
+ knockpy_sampler_method=knockpy_sampler_method,
730
+ )
731
+
732
+
733
class KnockoffSelector:
    """Sklearn-like wrapper for unified knockoff feature selection.

    The constructor only records its arguments as same-named attributes.
    ``fit`` forwards them, together with the data, to ``knockoff_filter`` and
    stores the outcome:

    * ``result_`` -- the object returned by ``knockoff_filter``.
    * ``selected_features_`` -- that object's ``selected_features`` attribute
      (presumably integer column indices; confirm against ``KnockoffResult``).

    Both fitted attributes are ``None`` until ``fit`` has been called.
    """

    def __init__(
        self,
        knockoff_type: str = "fixed_x",
        q: float = 0.1,
        method: str = "corr_diff",
        fdr_control: str = "knockoff_plus",
        random_state: Optional[int] = None,
        backend: str = "auto",
        compat_mode: str = "statgpu",
        lasso_cv_impl: str = "auto",
        lasso_fast_profile: str = "off",
        modelx_covariance_shrinkage: float = 0.20,
        modelx_s_scale: float = 0.999,
        modelx_draws: Optional[int] = None,
        modelx_shrinkage: str = "ledoitwolf",
        modelx_smatrix_method: str = "mvr",
        knockpy_sampler: Optional[str] = None,
        knockpy_sampler_method: Optional[str] = None,
    ):
        # Options shared by both knockoff variants.
        self.knockoff_type = knockoff_type
        self.q = q
        self.method = method
        self.fdr_control = fdr_control
        self.random_state = random_state
        self.backend = backend
        self.compat_mode = compat_mode
        self.lasso_cv_impl = lasso_cv_impl
        self.lasso_fast_profile = lasso_fast_profile

        # Options that knockoff_filter forwards only on the model-X path.
        self.modelx_covariance_shrinkage = modelx_covariance_shrinkage
        self.modelx_s_scale = modelx_s_scale
        self.modelx_draws = modelx_draws
        self.modelx_shrinkage = modelx_shrinkage
        self.modelx_smatrix_method = modelx_smatrix_method
        self.knockpy_sampler = knockpy_sampler
        self.knockpy_sampler_method = knockpy_sampler_method

        # Fitted state, populated by fit().
        self.result_: Optional[KnockoffResult] = None
        self.selected_features_: Optional[np.ndarray] = None

    def _require_fitted(self):
        """Return ``selected_features_`` or raise if ``fit`` has not run."""
        chosen = self.selected_features_
        if chosen is None:
            raise RuntimeError("Selector has not been fitted yet")
        return chosen

    def fit(self, X, y, Xk=None):
        """Run ``knockoff_filter`` with the stored options and keep its result.

        ``Xk`` is passed through unchanged.  Returns ``self``.
        """
        options = {
            "knockoff_type": self.knockoff_type,
            "q": self.q,
            "method": self.method,
            "fdr_control": self.fdr_control,
            "random_state": self.random_state,
            "backend": self.backend,
            "Xk": Xk,
            "compat_mode": self.compat_mode,
            "lasso_cv_impl": self.lasso_cv_impl,
            "lasso_fast_profile": self.lasso_fast_profile,
            "modelx_covariance_shrinkage": self.modelx_covariance_shrinkage,
            "modelx_s_scale": self.modelx_s_scale,
            "modelx_draws": self.modelx_draws,
            "modelx_shrinkage": self.modelx_shrinkage,
            "modelx_smatrix_method": self.modelx_smatrix_method,
            "knockpy_sampler": self.knockpy_sampler,
            "knockpy_sampler_method": self.knockpy_sampler_method,
        }
        outcome = knockoff_filter(X, y, **options)
        self.result_ = outcome
        self.selected_features_ = outcome.selected_features
        return self

    def get_support(self) -> np.ndarray:
        """Return a boolean mask marking the selected features.

        The mask has ``result_.W.shape[0]`` entries and is ``True`` at the
        positions listed in ``selected_features_``.

        Raises:
            RuntimeError: if the selector has not been fitted.
        """
        chosen = self._require_fitted()
        support = np.zeros(int(self.result_.W.shape[0]), dtype=bool)
        support[chosen] = True
        return support

    def transform(self, X):
        """Return the selected columns of ``X`` as a NumPy array.

        ``X`` is converted with ``np.asarray`` and indexed along its second
        axis by ``selected_features_``.

        Raises:
            RuntimeError: if the selector has not been fitted.
        """
        chosen = self._require_fitted()
        return np.asarray(X)[:, chosen]

    def fit_transform(self, X, y, Xk=None):
        """Fit on ``(X, y)`` and return ``transform(X)``."""
        fitted = self.fit(X, y, Xk=Xk)
        return fitted.transform(X)
816
+
817
+
818
class FixedXKnockoffSelector:
    """Sklearn-like wrapper for fixed-X knockoff feature selection.

    A thin facade over ``KnockoffSelector`` with ``knockoff_type`` pinned to
    ``"fixed_x"``.  The constructor arguments are exposed as same-named public
    attributes and mirrored onto the inner selector.  ``fit`` re-reads the
    public attributes, so a value reassigned after construction (for example
    ``selector.q = 0.2``) is honoured by the next fit instead of being
    silently ignored.

    After ``fit``:

    * ``result_`` -- the inner selector's ``result_``.
    * ``selected_features_`` -- the inner selector's ``selected_features_``.

    Both are ``None`` until ``fit`` has been called.
    """

    # Hyper-parameters mirrored between this wrapper and the inner selector.
    _PARAM_NAMES = (
        "q",
        "method",
        "fdr_control",
        "random_state",
        "backend",
        "compat_mode",
        "lasso_cv_impl",
        "lasso_fast_profile",
    )

    def __init__(
        self,
        q: float = 0.1,
        method: str = "corr_diff",
        fdr_control: str = "knockoff_plus",
        random_state: Optional[int] = None,
        backend: str = "auto",
        compat_mode: str = "statgpu",
        lasso_cv_impl: str = "auto",
        lasso_fast_profile: str = "off",
    ):
        self._selector = KnockoffSelector(
            knockoff_type="fixed_x",
            q=q,
            method=method,
            fdr_control=fdr_control,
            random_state=random_state,
            backend=backend,
            compat_mode=compat_mode,
            lasso_cv_impl=lasso_cv_impl,
            lasso_fast_profile=lasso_fast_profile,
        )

        self.q = q
        self.method = method
        self.fdr_control = fdr_control
        self.random_state = random_state
        self.backend = backend
        self.compat_mode = compat_mode
        self.lasso_cv_impl = lasso_cv_impl
        self.lasso_fast_profile = lasso_fast_profile
        self.result_: Optional[KnockoffResult] = None
        self.selected_features_: Optional[np.ndarray] = None

    def fit(self, X, y, Xk=None):
        """Fit the inner selector and copy its fitted state onto this object.

        ``Xk`` is passed through unchanged.  Returns ``self``.
        """
        # The public attributes are the source of truth: push their current
        # values to the inner selector so post-construction edits take effect.
        for name in self._PARAM_NAMES:
            setattr(self._selector, name, getattr(self, name))

        self._selector.fit(X, y, Xk=Xk)
        self.result_ = self._selector.result_
        self.selected_features_ = self._selector.selected_features_
        return self

    def get_support(self) -> np.ndarray:
        """Return the inner selector's boolean support mask.

        Raises:
            RuntimeError: propagated from the inner selector when unfitted.
        """
        return self._selector.get_support()

    def transform(self, X):
        """Return the inner selector's column-subset of ``X``.

        Raises:
            RuntimeError: propagated from the inner selector when unfitted.
        """
        return self._selector.transform(X)

    def fit_transform(self, X, y, Xk=None):
        """Fit on ``(X, y)`` and return ``transform(X)``."""
        return self.fit(X, y, Xk=Xk).transform(X)
869
+
870
+