statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1003 @@
1
+ """Utility functions for knockoff feature selection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import OrderedDict
6
+ from contextlib import contextmanager
7
+ import hashlib
8
+ import os
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+ import warnings
11
+
12
+ import numpy as np
13
+
14
+ from statgpu.backends import (
15
+ _get_torch_device_str,
16
+ _torch_dev,
17
+ _get_xp,
18
+ _resolve_backend,
19
+ _to_float_scalar,
20
+ _to_numpy,
21
+ )
22
+
23
# Re-export for backward compatibility with modules that import from here
__all__ = [
    "_get_xp",
    "_resolve_backend",
    "_to_numpy",
    "_to_float_scalar",
]


# Maximum number of entries kept in ``_LASSO_DIFF_CACHE``.  Read once at import
# time from the STATGPU_KNOCKOFF_LASSO_CACHE_SIZE environment variable
# (default "32").  ``_lasso_diff_cache_get`` / ``_lasso_diff_cache_put`` treat
# a value <= 0 as "caching disabled".
# NOTE(review): a non-integer value in the environment variable makes ``int``
# raise ValueError while this module is being imported.
_LASSO_DIFF_CACHE_MAXSIZE = int(os.getenv("STATGPU_KNOCKOFF_LASSO_CACHE_SIZE", "32"))
# Insertion-ordered mapping from cache key to a float64 array; the cache
# helpers move an entry to the end on access and evict from the front.
_LASSO_DIFF_CACHE: "OrderedDict[Tuple[Any, ...], np.ndarray]" = OrderedDict()
34
+
35
+
36
+ def _array_identity_token(x: Any) -> Tuple[Any, ...]:
37
+ if x is None:
38
+ return ("none",)
39
+
40
+ # Try CuPy array
41
+ try:
42
+ import cupy as cp
43
+
44
+ if isinstance(x, cp.ndarray):
45
+ return ("cupy", int(x.data.ptr), tuple(int(v) for v in x.shape), str(x.dtype))
46
+ except Exception:
47
+ pass
48
+
49
+ # Try Torch tensor
50
+ try:
51
+ import torch
52
+
53
+ if isinstance(x, torch.Tensor):
54
+ if x.is_cuda:
55
+ return ("torch_cuda", int(x.data_ptr()), tuple(int(v) for v in x.shape), str(x.dtype))
56
+ else:
57
+ return ("torch_cpu", int(x.data_ptr()), tuple(int(v) for v in x.shape), str(x.dtype))
58
+ except Exception:
59
+ pass
60
+
61
+ # Default to NumPy
62
+ arr = np.asarray(x)
63
+ ptr = int(arr.__array_interface__["data"][0]) if int(arr.size) > 0 else 0
64
+ return ("numpy", ptr, tuple(int(v) for v in arr.shape), str(arr.dtype))
65
+
66
+
67
+ def _int_array_signature(x: Any) -> str:
68
+ arr = np.ascontiguousarray(np.asarray(x, dtype=np.int64).reshape(-1))
69
+ return hashlib.blake2b(arr.tobytes(), digest_size=16).hexdigest()
70
+
71
+
72
def _lasso_diff_cache_get(cache_key: Optional[Tuple[Any, ...]]) -> Optional[np.ndarray]:
    """Look up ``cache_key`` in the module-level lasso-difference cache.

    Returns ``None`` when the key is ``None``, when caching is disabled
    (max size <= 0) or when nothing is stored under the key.  On a hit the
    entry is marked most-recently-used and a float64 copy is returned, so the
    caller cannot mutate the cached array.
    """
    if cache_key is None:
        return None
    if _LASSO_DIFF_CACHE_MAXSIZE <= 0:
        return None

    hit = _LASSO_DIFF_CACHE.get(cache_key)
    if hit is not None:
        _LASSO_DIFF_CACHE.move_to_end(cache_key)
        return np.asarray(hit, dtype=np.float64).copy()
    return None
82
+
83
+
84
def _lasso_diff_cache_put(cache_key: Optional[Tuple[Any, ...]], value: np.ndarray) -> None:
    """Store a float64 copy of ``value`` under ``cache_key``.

    Does nothing when the key is ``None`` or caching is disabled
    (max size <= 0).  After insertion the oldest entries are evicted until
    the cache holds at most ``_LASSO_DIFF_CACHE_MAXSIZE`` items.
    """
    if cache_key is None or _LASSO_DIFF_CACHE_MAXSIZE <= 0:
        return

    snapshot = np.asarray(value, dtype=np.float64).copy()
    _LASSO_DIFF_CACHE[cache_key] = snapshot
    _LASSO_DIFF_CACHE.move_to_end(cache_key)

    limit = int(_LASSO_DIFF_CACHE_MAXSIZE)
    for _ in range(max(0, len(_LASSO_DIFF_CACHE) - limit)):
        # last=False pops from the front, i.e. the least recently used entry.
        _LASSO_DIFF_CACHE.popitem(last=False)
93
+
94
+
95
def _make_lasso_coef_diff_cache_key(
    *,
    X_std,
    X_knock,
    y,
    random_state: Optional[int],
    backend_name: str,
    max_iter_eff: int,
    tol_eff: float,
    cv_folds_eff: int,
    n_alphas_eff: int,
    lasso_cv_impl: str,
    fast_profile_eff: str,
    knockpy_style: bool,
) -> Optional[Tuple[Any, ...]]:
    """Assemble the cache key for a lasso coefficient-difference computation.

    The key combines a version tag, the identity tokens of the three input
    arrays (buffer address, shape, dtype -- not contents) and the normalized
    solver settings.  Returns ``None`` when ``random_state`` is ``None``,
    because an unseeded run is not reproducible and must not be reused.
    """
    if random_state is None:
        return None

    array_tokens = tuple(_array_identity_token(arr) for arr in (X_std, X_knock, y))
    return (
        "knockoff_lasso_diff_v1",
        *array_tokens,
        int(random_state),
        str(backend_name).lower(),
        int(max_iter_eff),
        float(tol_eff),
        int(cv_folds_eff),
        int(n_alphas_eff),
        str(lasso_cv_impl).lower(),
        str(fast_profile_eff).lower(),
        bool(knockpy_style),
    )
129
+
130
+
131
+ def _normalize_compat_mode(compat_mode: str) -> str:
132
+ key = str(compat_mode).strip().lower()
133
+ if key in ("statgpu", "default"):
134
+ return "statgpu"
135
+ if key in ("knockpy", "compat", "knockpy_compat"):
136
+ return "knockpy"
137
+ raise ValueError("compat_mode must be one of: 'statgpu', 'knockpy'")
138
+
139
+
140
+ def _normalize_lasso_cv_impl(lasso_cv_impl: str) -> str:
141
+ key = str(lasso_cv_impl).strip().lower()
142
+ if key in ("auto", "default"):
143
+ return "auto"
144
+ if key in ("statgpu", "internal"):
145
+ return "statgpu"
146
+ if key in ("sklearn", "knockpy", "knockpy_sklearn"):
147
+ return "sklearn"
148
+ raise ValueError("lasso_cv_impl must be one of: 'auto', 'statgpu', 'sklearn'")
149
+
150
+
151
+ def _normalize_lasso_fast_profile(lasso_fast_profile: str) -> str:
152
+ key = str(lasso_fast_profile).strip().lower()
153
+ if key in ("off", "none", "default"):
154
+ return "off"
155
+ if key in ("auto",):
156
+ return "auto"
157
+ if key in ("moderate", "balanced"):
158
+ return "moderate"
159
+ if key in ("aggressive", "fast"):
160
+ return "aggressive"
161
+ raise ValueError(
162
+ "lasso_fast_profile must be one of: 'off', 'auto', 'moderate', 'aggressive'"
163
+ )
164
+
165
+
166
def _resolve_lasso_fast_profile_for_problem(lasso_fast_profile: str, problem_size: int) -> str:
    """Resolve the ``"auto"`` fast profile to a concrete one.

    Explicit profiles are returned unchanged after normalization.  ``"auto"``
    becomes ``"moderate"`` when ``problem_size`` is at least 2,000,000 and
    ``"off"`` otherwise.
    """
    resolved = _normalize_lasso_fast_profile(lasso_fast_profile)
    if resolved == "auto":
        resolved = "moderate" if int(problem_size) >= 2_000_000 else "off"
    return resolved
174
+
175
+
176
+ @contextmanager
177
+ def _temporary_numpy_seed(seed: Optional[int]):
178
+ if seed is None:
179
+ yield
180
+ return
181
+
182
+ state = np.random.get_state()
183
+ np.random.seed(int(seed))
184
+ try:
185
+ yield
186
+ finally:
187
+ np.random.set_state(state)
188
+
189
+
190
+ def _calc_mineig_np(M: np.ndarray) -> float:
191
+ eigvals = np.linalg.eigvalsh(0.5 * (M + M.T))
192
+ return float(np.min(eigvals))
193
+
194
+
195
def _shift_until_psd_np(M: np.ndarray, tol: float) -> np.ndarray:
    """Shift the diagonal of ``M`` so its smallest eigenvalue is at least ``tol``.

    When the smallest eigenvalue is already >= ``tol`` no shift is applied.
    The symmetrized matrix ``0.5 * (M + M.T)`` is returned in both cases.
    """
    floor = float(tol)
    smallest = _calc_mineig_np(M)
    if smallest < floor:
        deficit = floor - smallest
        M = M + deficit * np.eye(M.shape[0], dtype=np.float64)
    return 0.5 * (M + M.T)
200
+
201
+
202
def _scale_until_psd_np(
    Sigma: np.ndarray,
    S: np.ndarray,
    tol: float = 1e-5,
    num_iter: int = 25,
):
    """Bisect for a scale ``gamma`` in [0, 1] keeping ``2*Sigma - gamma*S`` positive definite.

    ``S`` is first passed through ``_shift_until_psd_np``.  Each of the
    ``num_iter`` bisection steps tests whether
    ``2*Sigma - gamma*S - tol*I`` has a Cholesky factorization; the largest
    ``gamma`` that passed is kept.

    Returns ``(gamma * S_shifted, gamma)``.
    """
    S_psd = _shift_until_psd_np(S, tol)

    feasible = 0.0    # largest gamma known to pass the Cholesky test
    infeasible = 1.0  # current upper end of the search interval
    for _ in range(int(num_iter)):
        trial = 0.5 * (feasible + infeasible)
        V = 2.0 * Sigma - trial * S_psd
        margin_shifted = V - float(tol) * np.eye(V.shape[0], dtype=np.float64)
        try:
            np.linalg.cholesky(margin_shifted)
        except np.linalg.LinAlgError:
            infeasible = trial
        else:
            feasible = trial

    scale = float(feasible)
    return scale * S_psd, scale
223
+
224
+
225
+ def _estimate_covariance_knockpy_style(
226
+ X: np.ndarray,
227
+ *,
228
+ shrinkage: str = "ledoitwolf",
229
+ tol: float = 1e-4,
230
+ ):
231
+ X_np = np.asarray(X, dtype=np.float64)
232
+
233
+ shrink_key = str(shrinkage).strip().lower()
234
+ if shrink_key in ("none", "mle"):
235
+ shrink_key = "none"
236
+
237
+ Sigma = None
238
+ inv_sigma = None
239
+ estimator_name = shrink_key
240
+
241
+ if shrink_key == "none":
242
+ Sigma = np.cov(X_np.T)
243
+ if _calc_mineig_np(Sigma) < float(tol):
244
+ shrink_key = "ledoitwolf"
245
+ estimator_name = "ledoitwolf_auto"
246
+
247
+ if shrink_key != "none":
248
+ try:
249
+ from sklearn import covariance as sk_cov
250
+ except Exception:
251
+ # Fallback keeps compatibility even when sklearn is unavailable.
252
+ Sigma = np.cov(X_np.T)
253
+ estimator_name = "mle_fallback_no_sklearn"
254
+ else:
255
+ if shrink_key == "ledoitwolf":
256
+ estimator = sk_cov.LedoitWolf()
257
+ elif shrink_key in ("graphicallasso", "glasso"):
258
+ estimator = sk_cov.GraphicalLasso(alpha=0.1)
259
+ else:
260
+ raise ValueError(
261
+ "modelx_shrinkage must be one of: 'ledoitwolf', 'none', 'mle', 'graphicallasso'"
262
+ )
263
+ with warnings.catch_warnings():
264
+ warnings.simplefilter("ignore")
265
+ estimator.fit(X_np)
266
+ Sigma = np.asarray(estimator.covariance_, dtype=np.float64)
267
+ inv_sigma = np.asarray(estimator.precision_, dtype=np.float64)
268
+ estimator_name = shrink_key
269
+
270
+ Sigma = 0.5 * (np.asarray(Sigma, dtype=np.float64) + np.asarray(Sigma, dtype=np.float64).T)
271
+ if inv_sigma is None:
272
+ try:
273
+ inv_sigma = np.linalg.inv(Sigma)
274
+ except np.linalg.LinAlgError:
275
+ ridge = max(1e-8, -_calc_mineig_np(Sigma) + 1e-8)
276
+ Sigma = Sigma + ridge * np.eye(Sigma.shape[0], dtype=np.float64)
277
+ Sigma = 0.5 * (Sigma + Sigma.T)
278
+ inv_sigma = np.linalg.inv(Sigma)
279
+
280
+ return Sigma, np.asarray(inv_sigma, dtype=np.float64), estimator_name
281
+
282
+
283
def _compute_smatrix_knockpy_style(
    Sigma: np.ndarray,
    *,
    method: str = "mvr",
    tol: float = 1e-5,
):
    """Compute the knockoff S-matrix for ``Sigma``, preferring knockpy.

    ``knockpy.smatrix.compute_smatrix`` is tried first with every feature in
    its own group.  If that import or call fails for any reason, the fallback
    is ``s * I`` with ``s = min(2 * lambda_min(Sigma), 1)``.  The result is
    then shifted and scaled by ``_shift_until_psd_np`` and
    ``_scale_until_psd_np``.

    Returns ``(S, source, gamma)`` where ``source`` is ``"knockpy"`` or
    ``"equicorrelated_fallback"`` and ``gamma`` is the applied scale factor.

    Raises ``ValueError`` when the fallback diagonal value is <= 1e-12.
    """
    cov = np.asarray(Sigma, dtype=np.float64)
    n_features = int(cov.shape[0])
    singleton_groups = np.arange(1, n_features + 1, dtype=np.int64)

    try:
        from knockpy import smatrix as kp_smatrix

        S = kp_smatrix.compute_smatrix(
            Sigma=cov,
            groups=singleton_groups,
            method=str(method).strip().lower(),
        )
        source = "knockpy"
    except Exception:
        # Robust fallback if knockpy is not installed.
        smallest = _calc_mineig_np(cov)
        diag_value = min(2.0 * smallest, 1.0)
        if diag_value <= 1e-12:
            raise ValueError("Failed to construct model-X knockoff S-matrix")
        S = diag_value * np.eye(n_features, dtype=np.float64)
        source = "equicorrelated_fallback"

    tol_f = float(tol)
    S = _shift_until_psd_np(np.asarray(S, dtype=np.float64), tol=tol_f)
    S, gamma = _scale_until_psd_np(cov, S, tol=tol_f, num_iter=25)
    return S, source, float(gamma)
315
+
316
+
317
+ def _random_permutation_inds(length: int, random_state: Optional[int]):
318
+ rng = np.random.default_rng(random_state)
319
+ inds = rng.permutation(int(length)).astype(np.int64, copy=False)
320
+ rev_inds = np.empty(int(length), dtype=np.int64)
321
+ rev_inds[inds] = np.arange(int(length), dtype=np.int64)
322
+ return inds, rev_inds
323
+
324
+
325
+ def _validate_q(q: float) -> float:
326
+ q_f = float(q)
327
+ if q_f <= 0.0 or q_f >= 1.0:
328
+ raise ValueError("q must be in (0, 1)")
329
+ return q_f
330
+
331
+
332
+ def _normalize_fdr_control(fdr_control: str) -> int:
333
+ key = str(fdr_control).strip().lower()
334
+ if key in ("knockoff_plus", "plus", "knockoff+"):
335
+ return 1
336
+ if key in ("knockoff", "standard"):
337
+ return 0
338
+ raise ValueError("fdr_control must be one of: 'knockoff_plus', 'knockoff'")
339
+
340
+
341
+ def _normalize_knockoff_type(knockoff_type: str) -> str:
342
+ key = str(knockoff_type).strip().lower()
343
+ if key in ("fixed_x", "fixed-x", "fixedx"):
344
+ return "fixed_x"
345
+ if key in ("model_x", "model-x", "modelx"):
346
+ return "model_x"
347
+ raise ValueError("knockoff_type must be one of: 'fixed_x', 'model_x'")
348
+
349
+
350
+ def _standardize_design(X, xp):
351
+ """Standardize design matrix to unit norm (L2 norm = 1 per column).
352
+
353
+ This centers each column to zero mean and scales to unit L2 norm,
354
+ which is the standard normalization for Fixed-X knockoff construction.
355
+
356
+ Note: This differs from R glmnet's internal standardization (unit variance),
357
+ but is the conventional scaling for knockoff methods as it ensures the
358
+ knockoff construction is invariant to feature scaling.
359
+ """
360
+ X = xp.asarray(X, dtype=xp.float64)
361
+ if X.ndim != 2:
362
+ raise ValueError("X must be a 2D array")
363
+
364
+ X_centered = X - xp.mean(X, axis=0, keepdims=True)
365
+ scale = xp.sqrt(xp.sum(X_centered * X_centered, axis=0))
366
+ if bool(xp.any(scale <= 1e-12)):
367
+ raise ValueError("X contains near-constant columns; knockoff construction is unstable")
368
+
369
+ return X_centered / scale
370
+
371
+
372
def _standardize_features_unit_variance(X, xp):
    """Center each column of ``X`` and scale it to unit sample variance.

    ``X`` is converted to float64 with the array namespace ``xp``.  The scale
    is the standard deviation with one degree of freedom removed.

    Raises ``ValueError`` if ``X`` is not 2-D, has fewer than two rows, or
    has a column whose standard deviation is <= 1e-12.
    """
    X_arr = xp.asarray(X, dtype=xp.float64)
    if X_arr.ndim != 2:
        raise ValueError("X must be a 2D array")

    n_rows = int(X_arr.shape[0])
    if n_rows < 2:
        raise ValueError("model-X knockoff requires at least 2 samples")

    centered = X_arr - xp.mean(X_arr, axis=0, keepdims=True)

    # The two code paths differ only in the keyword used for the
    # degrees-of-freedom correction.
    # NOTE(review): presumably _torch_dev returns a device for torch tensors
    # and None otherwise -- confirm against statgpu.backends.
    if _torch_dev(centered) is None:
        dof_kwargs = {"ddof": 1}
    else:
        dof_kwargs = {"correction": 1}
    column_std = xp.std(centered, axis=0, **dof_kwargs)

    if bool(xp.any(column_std <= 1e-12)):
        raise ValueError("X contains near-constant columns; model-X knockoff is unstable")

    return centered / column_std
390
+
391
+
392
def _build_fixed_x_knockoffs(X_std, random_state: Optional[int], xp):
    """Construct fixed-X knockoffs for a standardized design matrix.

    Steps, all carried out in the array namespace ``xp``:

    1. ``Sigma = X_std.T @ X_std`` (symmetrized) and its smallest eigenvalue.
    2. ``S = s * I`` with ``s = min(2 * lambda_min, 1)``.
    3. ``C`` is the symmetric square root of ``2S - S Sigma^{-1} S`` with
       negative eigenvalues clipped to zero.
    4. ``U`` is taken from columns ``p:2p`` of the reduced QR factor of
       ``[X_std, A]`` where ``A`` is a random Gaussian matrix.
    5. ``X_knock = X_std (I - Sigma^{-1} S) + U C``.

    ``random_state`` seeds the Gaussian draw.  On non-NumPy backends
    ``random_state=None`` is mapped to seed 0.

    Raises ``ValueError`` when ``n < 2p``, when the smallest eigenvalue of
    ``Sigma`` is <= 1e-10, or when ``s`` is <= 1e-12.
    """
    n, p = int(X_std.shape[0]), int(X_std.shape[1])
    if n < 2 * p:
        raise ValueError("fixed-X knockoff requires n_samples >= 2 * n_features")

    # NumPy and CuPy allocate without a ``device`` keyword (CuPy uses the
    # current device context); every other backend goes through the
    # torch-style API that takes the device of X_std explicitly.
    plain_alloc = xp is np or getattr(xp, '__name__', '') == 'cupy'

    def _identity():
        if plain_alloc:
            return xp.eye(p, dtype=xp.float64)
        return xp.eye(p, dtype=xp.float64, device=getattr(X_std, 'device', None))

    gram = X_std.T @ X_std
    Sigma = 0.5 * (gram + gram.T)

    min_eig = _to_float_scalar(xp.min(xp.linalg.eigvalsh(Sigma)))
    if min_eig <= 1e-10:
        raise ValueError("X'X is near-singular; fixed-X knockoff requires full-rank design")

    s_val = min(2.0 * min_eig, 1.0)
    if s_val <= 1e-12:
        raise ValueError("Failed to construct a valid knockoff S-matrix")

    S = s_val * _identity()

    if plain_alloc:
        Sigma_inv_S = xp.linalg.solve(Sigma, S)
    else:
        # Torch: move both operands (and the result) to the device of X_std
        # before/after solving so that later products do not mix devices.
        import torch

        torch_device = getattr(X_std, 'device', None)
        lhs = Sigma.to(torch_device) if hasattr(Sigma, 'to') else Sigma
        rhs = S.to(torch_device) if hasattr(S, 'to') else S
        Sigma_inv_S = torch.linalg.solve(lhs, rhs)
        if torch_device is not None and hasattr(Sigma_inv_S, 'to'):
            Sigma_inv_S = Sigma_inv_S.to(torch_device)

    c_arg = 2.0 * S - S @ Sigma_inv_S
    c_arg = 0.5 * (c_arg + c_arg.T)

    # Symmetric square root via eigendecomposition; tiny negative eigenvalues
    # are clipped to zero before taking the root.
    c_eigvals, c_eigvecs = xp.linalg.eigh(c_arg)
    c_eigvals = xp.clip(c_eigvals, 0.0, None)
    C = c_eigvecs @ xp.diag(xp.sqrt(c_eigvals)) @ c_eigvecs.T

    # Random Gaussian matrix A, generated with whichever RNG the backend has.
    if xp is np:
        A = np.random.default_rng(random_state).standard_normal(size=(n, p))
    else:
        seed = 0 if random_state is None else int(random_state)
        try:
            # CuPy API
            rng = xp.random.RandomState(seed)
            A = rng.standard_normal(size=(n, p), dtype=xp.float64)
        except (AttributeError, TypeError):
            # Torch API: seeded generator on the same device as X_std.
            import torch

            if xp is torch:
                if hasattr(X_std, "device"):
                    torch_device = X_std.device
                else:
                    torch_device = torch.device(_get_torch_device_str())
                gen = torch.Generator(device=torch_device)
                gen.manual_seed(seed)
                A = torch.randn(
                    n,
                    p,
                    dtype=torch.float64,
                    device=torch_device,
                    generator=gen,
                )
            else:
                # Fallback for other NumPy-like namespaces.
                rng = xp.random.Generator(xp.random.PCG64(seed))
                A = rng.standard_normal(size=(n, p), dtype=xp.float64)

    # Q[:, :p] spans col(X), Q[:, p:2p] spans an orthonormal complement basis.
    Q, _ = xp.linalg.qr(xp.concatenate([X_std, A], axis=1), mode="reduced")
    U = Q[:, p : 2 * p]

    X_knock = X_std @ (_identity() - Sigma_inv_S) + U @ C
    return X_knock
496
+
497
+
498
def _build_model_x_knockoffs(
    X_std,
    random_state: Optional[int],
    xp,
    covariance_shrinkage: float = 0.20,
    s_scale: float = 0.999,
):
    """Construct Gaussian model-X knockoffs for a standardized design matrix.

    Steps, all carried out in the array namespace ``xp``:

    1. ``Sigma = X_std.T @ X_std / (n - 1)``, shrunk towards
       ``mean(diag(Sigma)) * I`` by ``covariance_shrinkage`` (clamped to
       [0, 1]).
    2. If the smallest eigenvalue is below 1e-6, a ridge is added to lift it
       just above that threshold.
    3. ``S = s * I`` with ``s = min(2 * lambda_min * s_scale, 1)``.
    4. ``C`` is the symmetric square root of ``2S - S Sigma^{-1} S`` with
       negative eigenvalues clipped to zero.
    5. ``X_knock = X_std - X_std Sigma^{-1} S + Z C`` with Gaussian ``Z``.

    ``random_state`` seeds the draw of ``Z``.  On non-NumPy backends
    ``random_state=None`` is mapped to seed 0.  On the torch path the seeded
    generator is passed to ``torch.randn`` and ``Z`` is allocated on the
    device of ``X_std``, so results are reproducible and device-consistent.

    Returns ``(X_knock, info)`` where ``info`` holds ``s_value``, ``ridge``,
    ``min_eigenvalue``, ``covariance_shrinkage`` and ``s_scale`` as floats.

    Raises ``ValueError`` when the (regularized) covariance has smallest
    eigenvalue <= 1e-12 or when ``s`` is <= 1e-12.
    """
    n, p = int(X_std.shape[0]), int(X_std.shape[1])

    # NumPy and CuPy allocate without a ``device`` keyword (CuPy uses the
    # current device context); every other backend goes through the
    # torch-style API that takes the device of X_std explicitly.
    plain_alloc = xp is np or getattr(xp, '__name__', '') == 'cupy'

    def _identity():
        if plain_alloc:
            return xp.eye(p, dtype=xp.float64)
        return xp.eye(p, dtype=xp.float64, device=getattr(X_std, 'device', None))

    Sigma = (X_std.T @ X_std) / float(max(1, n - 1))
    Sigma = 0.5 * (Sigma + Sigma.T)

    shrinkage = float(min(1.0, max(0.0, covariance_shrinkage)))
    if shrinkage > 0.0:
        trace_mean = xp.trace(Sigma) / float(max(1, p))
        Sigma = (1.0 - shrinkage) * Sigma + shrinkage * trace_mean * _identity()
        Sigma = 0.5 * (Sigma + Sigma.T)

    eigvals = xp.linalg.eigvalsh(Sigma)
    min_eig = _to_float_scalar(xp.min(eigvals))

    ridge = 0.0
    if min_eig < 1e-6:
        ridge = float((1e-6 - min_eig) + 1e-8)
        Sigma = Sigma + ridge * _identity()
        Sigma = 0.5 * (Sigma + Sigma.T)
        eigvals = xp.linalg.eigvalsh(Sigma)
        min_eig = _to_float_scalar(xp.min(eigvals))

    if min_eig <= 1e-12:
        raise ValueError("Estimated covariance is near-singular; model-X knockoff failed")

    s_val = min(2.0 * min_eig * float(s_scale), 1.0)
    if s_val <= 1e-12:
        raise ValueError("Failed to construct a valid model-X knockoff S-matrix")

    S = s_val * _identity()

    if plain_alloc:
        Sigma_inv_S = xp.linalg.solve(Sigma, S)
    else:
        # Torch: move both operands (and the result) to the device of X_std
        # before/after solving so that later products do not mix devices.
        import torch

        torch_device = getattr(X_std, 'device', None)
        Sigma_on_device = Sigma.to(torch_device) if hasattr(Sigma, 'to') else Sigma
        S_on_device = S.to(torch_device) if hasattr(S, 'to') else S
        Sigma_inv_S = torch.linalg.solve(Sigma_on_device, S_on_device)
        if torch_device is not None and hasattr(Sigma_inv_S, 'to'):
            Sigma_inv_S = Sigma_inv_S.to(torch_device)

    c_arg = 2.0 * S - S @ Sigma_inv_S
    c_arg = 0.5 * (c_arg + c_arg.T)
    # Symmetric square root via eigendecomposition; tiny negative eigenvalues
    # are clipped to zero before taking the root.
    c_eigvals, c_eigvecs = xp.linalg.eigh(c_arg)
    c_eigvals = xp.clip(c_eigvals, 0.0, None)
    C = c_eigvecs @ xp.diag(xp.sqrt(c_eigvals)) @ c_eigvecs.T

    # Generate random matrix Z with appropriate backend
    if xp is np:
        rng = np.random.default_rng(random_state)
        Z = rng.standard_normal(size=(n, p))
    else:
        # CuPy or Torch
        seed = 0 if random_state is None else int(random_state)
        try:
            # Try CuPy API
            rng = xp.random.RandomState(seed)
            Z = rng.standard_normal(size=(n, p), dtype=xp.float64)
        except (AttributeError, TypeError):
            import torch

            # type(torch) is the module type, so this test is true for any
            # module-like ``xp``; it is kept as-is so backend routing does
            # not change.
            if isinstance(xp, type(torch)):
                # Draw on the device that holds X_std (falling back to the
                # configured default) so ``Z @ C`` never mixes devices.
                if hasattr(X_std, "device"):
                    torch_device = X_std.device
                else:
                    torch_device = torch.device(_get_torch_device_str())
                gen = torch.Generator(device=torch_device)
                gen.manual_seed(seed)
                # The generator must be passed explicitly; otherwise the
                # global torch RNG is used and ``random_state`` is ignored.
                Z = torch.randn(
                    n,
                    p,
                    dtype=torch.float64,
                    device=torch_device,
                    generator=gen,
                )
            else:
                # Fallback
                rng = xp.random.Generator(xp.random.PCG64(seed))
                Z = rng.standard_normal(size=(n, p), dtype=xp.float64)

    X_knock = X_std - X_std @ Sigma_inv_S + Z @ C
    return X_knock, {
        "s_value": float(s_val),
        "ridge": float(ridge),
        "min_eigenvalue": float(min_eig),
        "covariance_shrinkage": float(shrinkage),
        "s_scale": float(s_scale),
    }
618
+
619
+
620
def _build_model_x_knockoffs_knockpy_compat(
    X,
    random_state: Optional[int],
    *,
    modelx_shrinkage: str = "ledoitwolf",
    modelx_smatrix_method: str = "mvr",
    sample_tol: float = 1e-5,
):
    """Sample Gaussian model-X knockoffs on the CPU with NumPy.

    The covariance and S-matrix come from
    ``_estimate_covariance_knockpy_style`` and
    ``_compute_smatrix_knockpy_style``.  Knockoffs are drawn as
    ``Z @ L.T + mu_k`` where ``L`` is the Cholesky factor of the conditional
    covariance ``2S - S Sigma^{-1} S`` (shifted to be positive definite) and
    ``Z`` comes from NumPy's global RNG, temporarily seeded with
    ``random_state``.

    Returns ``(X_knock, info)`` with ``X_knock`` as a float64 array and
    ``info`` a dict of diagnostics.

    Raises ``ValueError`` if ``X`` is not 2-D or has fewer than two rows.
    """
    X_np = np.asarray(X, dtype=np.float64)
    if X_np.ndim != 2:
        raise ValueError("X must be a 2D array")

    n, p = int(X_np.shape[0]), int(X_np.shape[1])
    if n < 2:
        raise ValueError("model-X knockoff requires at least 2 samples")

    column_means = np.mean(X_np, axis=0)
    Sigma, inv_sigma, cov_estimator = _estimate_covariance_knockpy_style(
        X_np,
        shrinkage=modelx_shrinkage,
        tol=1e-4,
    )
    tol_f = float(sample_tol)
    S, smatrix_source, smatrix_gamma = _compute_smatrix_knockpy_style(
        Sigma,
        method=modelx_smatrix_method,
        tol=tol_f,
    )

    # Conditional mean and covariance of the knockoffs given X.
    inv_sigma_S = inv_sigma @ S
    centered = X_np - column_means.reshape(1, -1)
    cond_mean = X_np - centered @ inv_sigma_S
    cond_cov = _shift_until_psd_np(2.0 * S - S @ inv_sigma_S, tol=tol_f)

    chol = np.linalg.cholesky(cond_cov)
    with _temporary_numpy_seed(random_state):
        noise = np.random.randn(n, p)
    X_knock = noise @ chol.T + cond_mean

    info = {
        "s_value": float(np.mean(np.diag(S))),
        "ridge": 0.0,
        "min_eigenvalue": float(_calc_mineig_np(Sigma)),
        "covariance_shrinkage": None,
        "s_scale": float(smatrix_gamma),
        "modelx_shrinkage": str(modelx_shrinkage),
        "modelx_smatrix_method": str(modelx_smatrix_method),
        "modelx_covariance_estimator": str(cov_estimator),
        "modelx_smatrix_source": str(smatrix_source),
    }
    return np.asarray(X_knock, dtype=np.float64), info
669
+
670
+
671
+ def _model_x_draw_seed(random_state: Optional[int], draw_index: int) -> Optional[int]:
672
+ if random_state is None:
673
+ return None
674
+ return int(random_state) + 104729 * int(draw_index)
675
+
676
+
677
+ def _corr_diff_statistics(X_std, X_knock, y, xp):
678
+ y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
679
+ if y_arr.shape[0] != X_std.shape[0]:
680
+ raise ValueError("y must have the same number of rows as X")
681
+
682
+ y_centered = y_arr - xp.mean(y_arr)
683
+ score_orig = xp.abs(X_std.T @ y_centered)
684
+ score_knock = xp.abs(X_knock.T @ y_centered)
685
+ return score_orig - score_knock
686
+
687
+
688
+ def _ols_coef_diff_statistics(X_std, X_knock, y, xp, ridge: float = 1e-8):
689
+ y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
690
+ if y_arr.shape[0] != X_std.shape[0]:
691
+ raise ValueError("y must have the same number of rows as X")
692
+
693
+ y_centered = y_arr - xp.mean(y_arr)
694
+ p = int(X_std.shape[1])
695
+
696
+ Z = xp.concatenate([X_std, X_knock], axis=1)
697
+ ridge_f = float(max(0.0, ridge))
698
+
699
+ if ridge_f > 0.0:
700
+ # Create identity matrix - numpy/cupy don't need device, torch does
701
+ if xp is np:
702
+ eye_matrix = xp.eye(2 * p, dtype=xp.float64)
703
+ elif getattr(xp, '__name__', '') == 'cupy':
704
+ # CuPy: create eye on current device context (same as Z)
705
+ eye_matrix = xp.eye(2 * p, dtype=xp.float64)
706
+ else:
707
+ # Torch: use device keyword
708
+ device = getattr(Z, 'device', None)
709
+ eye_matrix = xp.eye(2 * p, dtype=xp.float64, device=device)
710
+ gram = Z.T @ Z + ridge_f * eye_matrix
711
+ rhs = Z.T @ y_centered
712
+ try:
713
+ coef = xp.linalg.solve(gram, rhs)
714
+ except Exception:
715
+ coef = xp.linalg.lstsq(Z, y_centered, rcond=None)[0]
716
+ else:
717
+ coef = xp.linalg.lstsq(Z, y_centered, rcond=None)[0]
718
+
719
+ coef_orig = coef[:p]
720
+ coef_knock = coef[p:]
721
+ return xp.abs(coef_orig) - xp.abs(coef_knock)
722
+
723
+
724
+ def _lasso_coef_diff_statistics(
725
+ X_std,
726
+ X_knock,
727
+ y,
728
+ xp,
729
+ random_state: Optional[int] = None,
730
+ backend_name: str = "numpy",
731
+ max_iter: int = 3000,
732
+ tol: float = 1e-4,
733
+ cv_folds: int = 5,
734
+ n_alphas: int = 12,
735
+ lasso_cv_impl: str = "statgpu",
736
+ lasso_fast_profile: str = "off",
737
+ knockpy_style: bool = False,
738
+ ):
739
+ y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
740
+ if y_arr.shape[0] != X_std.shape[0]:
741
+ raise ValueError("y must have the same number of rows as X")
742
+
743
+ if bool(knockpy_style):
744
+ y_model = y_arr
745
+ else:
746
+ y_model = y_arr - xp.mean(y_arr)
747
+ p = int(X_std.shape[1])
748
+ problem_size_full = int(X_std.shape[0]) * int(2 * p)
749
+ fast_profile_eff = _resolve_lasso_fast_profile_for_problem(
750
+ lasso_fast_profile,
751
+ problem_size_full,
752
+ )
753
+
754
+ cv_folds_eff = max(2, int(cv_folds))
755
+ n_alphas_eff = max(2, int(n_alphas))
756
+ max_iter_eff = max(500, int(max_iter))
757
+ tol_base = float(tol)
758
+
759
+ if fast_profile_eff == "moderate":
760
+ if problem_size_full >= 1_000_000:
761
+ cv_folds_eff = min(cv_folds_eff, 4)
762
+ n_alphas_eff = min(n_alphas_eff, 14 if bool(knockpy_style) else 12)
763
+ max_iter_eff = min(max_iter_eff, 2800)
764
+ elif fast_profile_eff == "aggressive":
765
+ if problem_size_full >= 2_000_000:
766
+ cv_folds_eff = min(cv_folds_eff, 2)
767
+ n_alphas_eff = min(n_alphas_eff, 6 if bool(knockpy_style) else 5)
768
+ max_iter_eff = min(max_iter_eff, 1600)
769
+ else:
770
+ cv_folds_eff = min(cv_folds_eff, 3)
771
+ n_alphas_eff = min(n_alphas_eff, 8 if bool(knockpy_style) else 7)
772
+ max_iter_eff = min(max_iter_eff, 2200)
773
+
774
+ tol_eff = max(1e-3, tol_base) if bool(knockpy_style) else tol_base
775
+ if fast_profile_eff == "aggressive":
776
+ tol_eff = max(tol_eff, 4e-3 if problem_size_full >= 2_000_000 else 2e-3)
777
+
778
+ lasso_diff_cache_key = _make_lasso_coef_diff_cache_key(
779
+ X_std=X_std,
780
+ X_knock=X_knock,
781
+ y=y_arr,
782
+ random_state=random_state,
783
+ backend_name=backend_name,
784
+ max_iter_eff=int(max_iter_eff),
785
+ tol_eff=float(tol_eff),
786
+ cv_folds_eff=int(cv_folds_eff),
787
+ n_alphas_eff=int(n_alphas_eff),
788
+ lasso_cv_impl=lasso_cv_impl,
789
+ fast_profile_eff=fast_profile_eff,
790
+ knockpy_style=bool(knockpy_style),
791
+ )
792
+ cached_w = _lasso_diff_cache_get(lasso_diff_cache_key)
793
+ if cached_w is not None:
794
+ return xp.asarray(cached_w, dtype=xp.float64)
795
+
796
+ Z = xp.concatenate([X_std, X_knock], axis=1)
797
+
798
+ # Knockpy-style symmetry preservation: permute [X, Xk] jointly, then undo at the end.
799
+ inds, rev_inds = _random_permutation_inds(2 * p, random_state=random_state)
800
+ alphas = np.logspace(-4.0, 4.0, base=10.0, num=int(n_alphas_eff))
801
+
802
+ cv_impl = _normalize_lasso_cv_impl(lasso_cv_impl)
803
+
804
+ # Force statgpu for torch backend since sklearn doesn't support torch tensors
805
+ backend_is_torch = str(backend_name).lower() == "torch"
806
+ if backend_is_torch and cv_impl == "sklearn":
807
+ cv_impl = "statgpu"
808
+
809
+ if cv_impl == "sklearn":
810
+ try:
811
+ from sklearn import linear_model
812
+ except Exception:
813
+ cv_impl = "statgpu"
814
+
815
+ if cv_impl == "sklearn":
816
+ Z_np = _to_numpy(Z).astype(np.float64, copy=False)
817
+ y_np = _to_numpy(y_model).astype(np.float64, copy=False).reshape(-1)
818
+ Z_perm = Z_np[:, inds]
819
+ with warnings.catch_warnings():
820
+ warnings.simplefilter("ignore")
821
+ model = linear_model.LassoCV(
822
+ alphas=alphas,
823
+ cv=int(cv_folds_eff),
824
+ verbose=False,
825
+ max_iter=int(max_iter_eff),
826
+ tol=float(tol_eff),
827
+ ).fit(Z_perm, y_np)
828
+ coef_perm = np.asarray(model.coef_, dtype=np.float64).reshape(-1)
829
+ else:
830
+ from statgpu.linear_model.wrappers._lasso import _fit_lasso_single_alpha_fast, _select_lasso_alpha_cv
831
+
832
+ use_cupy_native = str(backend_name).lower() == "cupy" and _is_cupy_array(Z)
833
+ use_torch_native = str(backend_name).lower() == "torch" and hasattr(Z, 'shape')
834
+ if use_cupy_native:
835
+ import cupy as cp
836
+
837
+ inds_device = cp.asarray(inds, dtype=cp.int64)
838
+ Z_perm = xp.asarray(Z, dtype=xp.float64)[:, inds_device]
839
+ y_fit = xp.asarray(y_model, dtype=xp.float64).reshape(-1)
840
+ elif use_torch_native:
841
+ import torch
842
+ inds_tensor = torch.tensor(inds, dtype=torch.int64, device=Z.device)
843
+ Z_perm = Z[:, inds_tensor]
844
+ y_fit = y_model.reshape(-1)
845
+ else:
846
+ Z_np = _to_numpy(Z).astype(np.float64, copy=False)
847
+ Z_perm = Z_np[:, inds]
848
+ y_fit = _to_numpy(y_model).astype(np.float64, copy=False).reshape(-1)
849
+
850
+ problem_size = int(Z_perm.shape[0]) * int(Z_perm.shape[1])
851
+
852
+ fit_intercept_eff = bool(knockpy_style)
853
+ if random_state is None:
854
+ alpha_cache_key = None
855
+ else:
856
+ alpha_cache_key = (
857
+ "knockoff_lasso_cv_v1",
858
+ _array_identity_token(X_std),
859
+ _array_identity_token(X_knock),
860
+ _array_identity_token(y_arr),
861
+ int(random_state),
862
+ str(backend_name).lower(),
863
+ bool(knockpy_style),
864
+ str(fast_profile_eff).lower(),
865
+ int(cv_folds_eff),
866
+ int(n_alphas_eff),
867
+ int(max_iter_eff),
868
+ float(tol_eff),
869
+ _int_array_signature(inds),
870
+ )
871
+ alpha_select_kwargs = {
872
+ "cv_folds": int(cv_folds_eff),
873
+ "random_state": random_state,
874
+ "fit_intercept": fit_intercept_eff,
875
+ "device": "cuda" if str(backend_name).lower() in ("cupy", "torch") else "cpu",
876
+ "max_iter": int(max_iter_eff),
877
+ "tol": tol_eff,
878
+ "cpu_solver": "coordinate_descent",
879
+ "cache_key": alpha_cache_key,
880
+ }
881
+ if bool(knockpy_style):
882
+ # Match knockpy-oriented branch settings used by the sklearn path as closely as possible.
883
+ alpha_select_kwargs["alphas"] = alphas
884
+ alpha_select_kwargs["method"] = "glmnet"
885
+ # For large designs, reduce full KKT scan frequency to lower CV overhead.
886
+ cd_kkt_check_every_eff = 4 if problem_size >= 1_000_000 else 2
887
+ if fast_profile_eff == "moderate":
888
+ cd_kkt_check_every_eff = max(cd_kkt_check_every_eff, 6)
889
+ elif fast_profile_eff == "aggressive":
890
+ cd_kkt_check_every_eff = max(
891
+ cd_kkt_check_every_eff,
892
+ 12 if problem_size >= 2_000_000 else 8,
893
+ )
894
+ alpha_select_kwargs["cd_kkt_check_every"] = cd_kkt_check_every_eff
895
+ else:
896
+ alpha_select_kwargs["n_alphas"] = int(n_alphas_eff)
897
+
898
+ alpha = _select_lasso_alpha_cv(
899
+ Z_perm,
900
+ y_fit,
901
+ **alpha_select_kwargs,
902
+ )
903
+
904
+ fit_out = _fit_lasso_single_alpha_fast(
905
+ Z_perm,
906
+ y_fit,
907
+ alpha=float(alpha),
908
+ fit_intercept=fit_intercept_eff,
909
+ max_iter=int(max_iter_eff),
910
+ tol=tol_eff,
911
+ device="cuda" if str(backend_name).lower() in ("cupy", "torch") else "cpu",
912
+ stopping="coef_delta",
913
+ cpu_solver="coordinate_descent",
914
+ cd_kkt_check_every=int(alpha_select_kwargs.get("cd_kkt_check_every", 1)),
915
+ )
916
+
917
+ coef_perm = np.asarray(fit_out["coef"], dtype=np.float64).reshape(-1)
918
+ if coef_perm.shape[0] != 2 * p:
919
+ raise RuntimeError("lasso_coef_diff produced unexpected coefficient shape")
920
+
921
+ coef = coef_perm[rev_inds]
922
+
923
+ W_np = np.abs(coef[:p]) - np.abs(coef[p:])
924
+ _lasso_diff_cache_put(lasso_diff_cache_key, W_np)
925
+ return xp.asarray(W_np, dtype=xp.float64)
926
+
927
+
928
+ def _compute_w_statistics(
929
+ X_std,
930
+ X_knock,
931
+ y,
932
+ method: str,
933
+ xp,
934
+ random_state: Optional[int] = None,
935
+ backend_name: str = "numpy",
936
+ lasso_cv_impl: str = "statgpu",
937
+ lasso_fast_profile: str = "off",
938
+ lasso_knockpy_style: bool = False,
939
+ ):
940
+ key = str(method).strip().lower()
941
+ if key == "corr_diff":
942
+ return _corr_diff_statistics(X_std, X_knock, y, xp), "corr_diff"
943
+ if key in ("ols_coef_diff", "ols", "coef_diff"):
944
+ return _ols_coef_diff_statistics(X_std, X_knock, y, xp), "ols_coef_diff"
945
+ if key in ("lasso_coef_diff", "lasso", "lasso_diff"):
946
+ return (
947
+ _lasso_coef_diff_statistics(
948
+ X_std,
949
+ X_knock,
950
+ y,
951
+ xp,
952
+ random_state=random_state,
953
+ backend_name=backend_name,
954
+ lasso_cv_impl=lasso_cv_impl,
955
+ lasso_fast_profile=lasso_fast_profile,
956
+ knockpy_style=lasso_knockpy_style,
957
+ n_alphas=20 if bool(lasso_knockpy_style) else 12,
958
+ ),
959
+ "lasso_coef_diff",
960
+ )
961
+ raise ValueError("method must be one of: 'corr_diff', 'ols_coef_diff', 'lasso_coef_diff'")
962
+
963
+
964
def _knockoff_threshold_and_path(W, q: float, offset: int):
    """Compute the data-dependent knockoff threshold and the full FDR trajectory.

    Statistics are ranked by decreasing ``|W|`` (stable sort). At each rank the
    estimated FDR is ``(offset + #{W <= 0 so far}) / max(1, #{W > 0 so far})``.
    The threshold is ``|W|`` at the last rank whose estimate is ``<= q``.

    Parameters
    ----------
    W : array-like
        Knockoff statistics; converted to a flat float64 NumPy array.
    q : float
        Target FDR level.
    offset : int
        Added to the negative count in the numerator of the FDR estimate.

    Returns
    -------
    tuple
        ``(threshold, fdr_hat, trajectory)``. ``threshold`` is ``inf`` and
        ``fdr_hat`` is ``0.0`` when no rank meets the target. ``trajectory`` is
        a list with one dict per rank holding ``rank`` (1-based),
        ``threshold``, ``fdr_hat`` (capped at 1.0) and ``n_selected``. It is
        empty when ``W`` is empty or all zeros.
    """
    W_np = np.asarray(_to_numpy(W), dtype=np.float64).reshape(-1)
    if W_np.size == 0:
        return float(np.inf), 0.0, []

    abs_w = np.abs(W_np)
    if not np.any(abs_w > 0):
        return float(np.inf), 0.0, []

    inds = np.argsort(-abs_w, kind="stable")
    sorted_w = W_np[inds]
    negatives = np.cumsum(sorted_w <= 0)
    positives = np.cumsum(sorted_w > 0)
    # Clamp only the denominator to avoid dividing by zero. The raw count is
    # kept so "n_selected" reports 0 (not 1) while no positive has been seen.
    hat_fdrs = (negatives + int(offset)) / np.maximum(positives, 1)

    trajectory: List[Dict[str, float]] = [
        {
            "rank": int(rank + 1),
            "threshold": float(abs_w[idx]),
            "fdr_hat": float(min(1.0, hat_fdrs[rank])),
            "n_selected": int(positives[rank]),
        }
        for rank, idx in enumerate(inds)
    ]

    valid = np.where(hat_fdrs <= float(q))[0]
    if valid.size == 0:
        return float(np.inf), 0.0, trajectory

    chosen_rank = int(valid.max())
    chosen_threshold = float(abs_w[inds[chosen_rank]])
    if chosen_threshold == 0.0:
        # A zero threshold would select nothing meaningful; use the smallest
        # strictly positive statistic, or inf when there is none.
        positive_w = W_np[W_np > 0.0]
        if positive_w.size > 0:
            chosen_threshold = float(np.min(positive_w))
        else:
            chosen_threshold = float(np.inf)
    chosen_fdr = float(min(1.0, hat_fdrs[chosen_rank]))
    return chosen_threshold, chosen_fdr, trajectory