statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1003 @@
|
|
|
1
|
+
"""Utility functions for knockoff feature selection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
import hashlib
|
|
8
|
+
import os
|
|
9
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from statgpu.backends import (
|
|
15
|
+
_get_torch_device_str,
|
|
16
|
+
_torch_dev,
|
|
17
|
+
_get_xp,
|
|
18
|
+
_resolve_backend,
|
|
19
|
+
_to_float_scalar,
|
|
20
|
+
_to_numpy,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Re-export for backward compatibility with modules that import from here
|
|
24
|
+
__all__ = [
|
|
25
|
+
"_get_xp",
|
|
26
|
+
"_resolve_backend",
|
|
27
|
+
"_to_numpy",
|
|
28
|
+
"_to_float_scalar",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
_LASSO_DIFF_CACHE_DEFAULT_SIZE = 32


def _read_lasso_cache_maxsize(default: int = _LASSO_DIFF_CACHE_DEFAULT_SIZE) -> int:
    """Return the lasso-diff cache capacity from the environment.

    Reads ``STATGPU_KNOCKOFF_LASSO_CACHE_SIZE``. When it is unset the
    ``default`` is used. When it is set but is not an integer, a
    ``RuntimeWarning`` is emitted and ``default`` is used. A typo in the
    variable therefore no longer makes the whole module fail to import.
    A value <= 0 disables the cache (see the cache get/put helpers).
    """
    raw = os.getenv("STATGPU_KNOCKOFF_LASSO_CACHE_SIZE")
    if raw is None:
        return int(default)
    try:
        # int() already tolerates surrounding whitespace.
        return int(raw)
    except ValueError:
        warnings.warn(
            f"Ignoring invalid STATGPU_KNOCKOFF_LASSO_CACHE_SIZE={raw!r}; "
            f"using {int(default)}",
            RuntimeWarning,
            stacklevel=2,
        )
        return int(default)


# Maximum number of entries kept in the LRU cache below (<= 0 disables it).
_LASSO_DIFF_CACHE_MAXSIZE = _read_lasso_cache_maxsize()
# LRU cache: key tuple -> float64 coefficient-difference vector; the oldest
# entries sit at the front of the OrderedDict.
_LASSO_DIFF_CACHE: "OrderedDict[Tuple[Any, ...], np.ndarray]" = OrderedDict()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _array_identity_token(x: Any) -> Tuple[Any, ...]:
    """Return a hashable token identifying ``x`` for use in cache keys.

    * ``None`` maps to ``("none",)``.
    * CuPy arrays and Torch tensors are identified by data pointer, shape,
      strides and dtype; Torch tokens also include the device. Hashing
      their contents would require a device-to-host copy. Strides are
      included so views that share a buffer, such as a square matrix and
      its transpose, do not collide.
    * Anything else is converted with ``np.asarray`` and identified by a
      128-bit BLAKE2b digest of its C-order bytes plus shape and dtype.
      A data pointer is not used here. ``np.asarray`` of a list (or any
      non-ndarray input) creates a temporary whose memory is freed and
      reused, so a pointer-based token would let unrelated inputs share a
      cache entry. Hashing costs one pass over the data.

    NOTE(review): pointer-based GPU tokens still cannot detect in-place
    mutation, or a buffer reused after deallocation. Callers that mutate
    GPU inputs between calls should not rely on the cache.
    """
    import sys

    if x is None:
        return ("none",)

    # Only look at libraries that are already loaded. An object can only be
    # a CuPy array / Torch tensor if its library has been imported, and this
    # avoids importing (or repeatedly failing to import) GPU stacks here.
    cp = sys.modules.get("cupy")
    if cp is not None and isinstance(x, cp.ndarray):
        return (
            "cupy",
            int(x.data.ptr),
            tuple(int(v) for v in x.shape),
            tuple(int(v) for v in x.strides),
            str(x.dtype),
        )

    torch = sys.modules.get("torch")
    if torch is not None and isinstance(x, torch.Tensor):
        return (
            "torch_cuda" if x.is_cuda else "torch_cpu",
            str(x.device),
            int(x.data_ptr()),
            tuple(int(v) for v in x.shape),
            tuple(int(v) for v in x.stride()),
            str(x.dtype),
        )

    arr = np.asarray(x)
    # ndarray.tobytes() emits C-order bytes for any memory layout, so the
    # digest depends only on values, shape (recorded below) and dtype.
    digest = hashlib.blake2b(arr.tobytes(), digest_size=16).hexdigest()
    return ("numpy", digest, tuple(int(v) for v in arr.shape), str(arr.dtype))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _int_array_signature(x: Any) -> str:
    """Return a 32-character BLAKE2b hex digest of ``x`` cast to int64.

    ``x`` is flattened before hashing, so only the sequence of values enters
    the digest; its original shape does not.
    """
    flat = np.asarray(x, dtype=np.int64).ravel()
    hasher = hashlib.blake2b(digest_size=16)
    hasher.update(np.ascontiguousarray(flat).tobytes())
    return hasher.hexdigest()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _lasso_diff_cache_get(cache_key: Optional[Tuple[Any, ...]]) -> Optional[np.ndarray]:
    """Fetch a cached lasso coefficient-difference vector, or ``None``.

    A lookup happens only when ``cache_key`` is not ``None`` and the
    module-level size limit is positive. On a hit the entry is marked as
    most recently used and a float64 *copy* is returned, so callers cannot
    mutate the cached value.
    """
    if cache_key is not None and _LASSO_DIFF_CACHE_MAXSIZE > 0:
        hit = _LASSO_DIFF_CACHE.get(cache_key)
        if hit is not None:
            _LASSO_DIFF_CACHE.move_to_end(cache_key)
            return np.asarray(hit, dtype=np.float64).copy()
    return None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _lasso_diff_cache_put(cache_key: Optional[Tuple[Any, ...]], value: np.ndarray) -> None:
    """Store a float64 copy of ``value`` under ``cache_key`` (LRU eviction).

    Does nothing when ``cache_key`` is ``None`` or the module-level size
    limit is <= 0. After insertion the entry is the most recently used one,
    and the oldest entries are dropped until the cache fits its limit.
    """
    if cache_key is not None and _LASSO_DIFF_CACHE_MAXSIZE > 0:
        _LASSO_DIFF_CACHE[cache_key] = np.asarray(value, dtype=np.float64).copy()
        _LASSO_DIFF_CACHE.move_to_end(cache_key)
        overflow = len(_LASSO_DIFF_CACHE) - int(_LASSO_DIFF_CACHE_MAXSIZE)
        for _ in range(overflow):
            _LASSO_DIFF_CACHE.popitem(last=False)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _make_lasso_coef_diff_cache_key(
    *,
    X_std,
    X_knock,
    y,
    random_state: Optional[int],
    backend_name: str,
    max_iter_eff: int,
    tol_eff: float,
    cv_folds_eff: int,
    n_alphas_eff: int,
    lasso_cv_impl: str,
    fast_profile_eff: str,
    knockpy_style: bool,
) -> Optional[Tuple[Any, ...]]:
    """Build the cache key for a lasso coefficient-difference computation.

    The key combines a version tag, identity tokens for ``X_std``,
    ``X_knock`` and ``y``, and the normalized solver settings (strings are
    lower-cased). ``None`` is returned when ``random_state`` is ``None``:
    without a fixed seed each call uses a fresh random permutation, so the
    result is not reusable.
    """
    if random_state is None:
        return None

    data_tokens = tuple(_array_identity_token(arr) for arr in (X_std, X_knock, y))
    settings = (
        int(random_state),
        str(backend_name).lower(),
        int(max_iter_eff),
        float(tol_eff),
        int(cv_folds_eff),
        int(n_alphas_eff),
        str(lasso_cv_impl).lower(),
        str(fast_profile_eff).lower(),
        bool(knockpy_style),
    )
    return ("knockoff_lasso_diff_v1",) + data_tokens + settings
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _normalize_compat_mode(compat_mode: str) -> str:
    """Map a user-supplied compatibility mode to ``"statgpu"`` or ``"knockpy"``.

    Matching ignores case and surrounding whitespace; unknown values raise
    ``ValueError``.
    """
    aliases = {
        "statgpu": "statgpu",
        "default": "statgpu",
        "knockpy": "knockpy",
        "compat": "knockpy",
        "knockpy_compat": "knockpy",
    }
    canonical = aliases.get(str(compat_mode).strip().lower())
    if canonical is None:
        raise ValueError("compat_mode must be one of: 'statgpu', 'knockpy'")
    return canonical
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _normalize_lasso_cv_impl(lasso_cv_impl: str) -> str:
    """Map a lasso-CV implementation alias to ``"auto"``, ``"statgpu"`` or ``"sklearn"``.

    Matching ignores case and surrounding whitespace; unknown values raise
    ``ValueError``.
    """
    aliases = {
        "auto": "auto",
        "default": "auto",
        "statgpu": "statgpu",
        "internal": "statgpu",
        "sklearn": "sklearn",
        "knockpy": "sklearn",
        "knockpy_sklearn": "sklearn",
    }
    canonical = aliases.get(str(lasso_cv_impl).strip().lower())
    if canonical is None:
        raise ValueError("lasso_cv_impl must be one of: 'auto', 'statgpu', 'sklearn'")
    return canonical
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _normalize_lasso_fast_profile(lasso_fast_profile: str) -> str:
    """Map a fast-profile alias to ``"off"``, ``"auto"``, ``"moderate"`` or ``"aggressive"``.

    Matching ignores case and surrounding whitespace; unknown values raise
    ``ValueError``.
    """
    aliases = {
        "off": "off",
        "none": "off",
        "default": "off",
        "auto": "auto",
        "moderate": "moderate",
        "balanced": "moderate",
        "aggressive": "aggressive",
        "fast": "aggressive",
    }
    canonical = aliases.get(str(lasso_fast_profile).strip().lower())
    if canonical is None:
        raise ValueError(
            "lasso_fast_profile must be one of: 'off', 'auto', 'moderate', 'aggressive'"
        )
    return canonical
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _resolve_lasso_fast_profile_for_problem(lasso_fast_profile: str, problem_size: int) -> str:
    """Resolve a fast-profile setting for a problem of ``problem_size``.

    Explicit profiles are returned in normalized form. ``"auto"`` becomes
    ``"moderate"`` when ``problem_size`` is at least 2,000,000 and
    ``"off"`` otherwise.
    """
    resolved = _normalize_lasso_fast_profile(lasso_fast_profile)
    if resolved == "auto":
        resolved = "moderate" if int(problem_size) >= 2_000_000 else "off"
    return resolved
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@contextmanager
def _temporary_numpy_seed(seed: Optional[int]):
    """Context manager that seeds NumPy's legacy global RNG for its body.

    With ``seed=None`` the global RNG is left untouched. Otherwise the
    current global state is saved, ``np.random.seed(int(seed))`` is applied,
    and the saved state is restored on exit, even if the body raises.
    """
    if seed is not None:
        saved_state = np.random.get_state()
        np.random.seed(int(seed))
        try:
            yield
        finally:
            np.random.set_state(saved_state)
    else:
        yield
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _calc_mineig_np(M: np.ndarray) -> float:
    """Return the smallest eigenvalue of the symmetric part ``(M + M.T) / 2``."""
    symmetric_part = 0.5 * (M + M.T)
    spectrum = np.linalg.eigvalsh(symmetric_part)
    return float(spectrum.min())
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _shift_until_psd_np(M: np.ndarray, tol: float) -> np.ndarray:
    """Return a symmetrized copy of ``M`` whose smallest eigenvalue is >= ``tol``.

    If the symmetric part of ``M`` already has minimum eigenvalue >= ``tol``,
    ``M`` is only symmetrized. Otherwise it is first shifted by
    ``(tol - lambda_min) * I``.
    """
    threshold = float(tol)
    smallest = _calc_mineig_np(M)
    if smallest < threshold:
        identity = np.eye(M.shape[0], dtype=np.float64)
        M = M + (threshold - smallest) * identity
    return 0.5 * (M + M.T)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _scale_until_psd_np(
    Sigma: np.ndarray,
    S: np.ndarray,
    tol: float = 1e-5,
    num_iter: int = 25,
):
    """Shrink ``S`` so that ``2 * Sigma - S`` stays positive definite.

    ``S`` is first made PSD with ``_shift_until_psd_np``. Then ``num_iter``
    bisection steps on ``gamma`` in [0, 1] search for the largest scale
    for which ``2 * Sigma - gamma * S - tol * I`` admits a Cholesky
    factorization.

    Returns ``(gamma * S_shifted, gamma)``, where ``gamma`` is the last
    accepted lower bound. It is 0.0 if no step succeeded.
    """
    S_psd = _shift_until_psd_np(S, tol)
    margin = float(tol)

    lo, hi = 0.0, 1.0
    for _ in range(int(num_iter)):
        mid = (lo + hi) / 2.0
        candidate = 2.0 * Sigma - mid * S_psd
        try:
            np.linalg.cholesky(candidate - margin * np.eye(candidate.shape[0], dtype=np.float64))
        except np.linalg.LinAlgError:
            hi = mid
        else:
            lo = mid

    accepted = float(lo)
    return accepted * S_psd, accepted
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _estimate_covariance_knockpy_style(
    X: np.ndarray,
    *,
    shrinkage: str = "ledoitwolf",
    tol: float = 1e-4,
):
    """Estimate the feature covariance and its inverse the way knockpy does.

    Parameters
    ----------
    X : array-like
        Data matrix. Rows are samples and columns are features
        (the sample covariance is ``np.cov(X.T)``).
    shrinkage : str
        ``"ledoitwolf"`` (default), ``"none"``/``"mle"`` for the sample
        covariance, or ``"graphicallasso"``/``"glasso"`` (sklearn
        ``GraphicalLasso(alpha=0.1)``). Case and surrounding whitespace are
        ignored.
    tol : float
        If ``"none"`` is requested and the sample covariance's smallest
        eigenvalue is below ``tol``, Ledoit-Wolf shrinkage is used instead.

    Returns
    -------
    Sigma : np.ndarray
        Symmetrized, 2-D float64 covariance estimate.
    inv_sigma : np.ndarray
        float64 inverse of ``Sigma``. When an sklearn estimator was fitted,
        this is the estimator's ``precision_``.
    estimator_name : str
        The normalized shrinkage key (``"none"``, ``"ledoitwolf"``,
        ``"graphicallasso"`` or ``"glasso"``). It is ``"ledoitwolf_auto"``
        when ``"none"`` was requested but replaced by Ledoit-Wolf, and
        ``"mle_fallback_no_sklearn"`` when sklearn could not be imported.

    Raises
    ------
    ValueError
        If ``shrinkage`` is not a recognized option. The check runs before
        sklearn is imported, so the error does not depend on whether
        sklearn is installed.
    """
    X_np = np.asarray(X, dtype=np.float64)

    shrink_key = str(shrinkage).strip().lower()
    if shrink_key in ("none", "mle"):
        shrink_key = "none"
    if shrink_key not in ("none", "ledoitwolf", "graphicallasso", "glasso"):
        raise ValueError(
            "modelx_shrinkage must be one of: 'ledoitwolf', 'none', 'mle', 'graphicallasso'"
        )

    def _sample_covariance() -> np.ndarray:
        # np.cov returns a 0-d array for a single feature; keep it 2-D so the
        # eigenvalue / inverse code below also works when p == 1.
        return np.atleast_2d(np.cov(X_np.T))

    def _regularized(mat: np.ndarray) -> np.ndarray:
        # Lift the spectrum so the smallest eigenvalue is at least ~1e-8.
        ridge = max(1e-8, -_calc_mineig_np(mat) + 1e-8)
        mat = mat + ridge * np.eye(mat.shape[0], dtype=np.float64)
        return 0.5 * (mat + mat.T)

    Sigma = None
    inv_sigma = None
    estimator_name = shrink_key

    if shrink_key == "none":
        Sigma = _sample_covariance()
        if _calc_mineig_np(Sigma) < float(tol):
            # Near-singular sample covariance: switch to Ledoit-Wolf. The
            # "_auto" label is kept (not overwritten below) so callers can
            # tell that the switch happened.
            shrink_key = "ledoitwolf"
            estimator_name = "ledoitwolf_auto"

    if shrink_key != "none":
        try:
            from sklearn import covariance as sk_cov
        except Exception:
            # Fallback keeps compatibility even when sklearn is unavailable.
            if Sigma is None:
                Sigma = _sample_covariance()
            estimator_name = "mle_fallback_no_sklearn"
        else:
            if shrink_key == "ledoitwolf":
                estimator = sk_cov.LedoitWolf()
            else:  # "graphicallasso" / "glasso" (validated above)
                estimator = sk_cov.GraphicalLasso(alpha=0.1)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                estimator.fit(X_np)
            Sigma = np.asarray(estimator.covariance_, dtype=np.float64)
            inv_sigma = np.asarray(estimator.precision_, dtype=np.float64)

    Sigma = np.asarray(Sigma, dtype=np.float64)
    Sigma = 0.5 * (Sigma + Sigma.T)

    if inv_sigma is None:
        # Only the unshrunk / no-sklearn paths get here. np.linalg.inv raises
        # only when LU factorization meets an exactly zero pivot. A
        # numerically singular matrix is otherwise "inverted" into huge,
        # meaningless values. So regularize proactively when the spectrum
        # is (near) degenerate, and also when inversion fails.
        if _calc_mineig_np(Sigma) < 1e-8:
            Sigma = _regularized(Sigma)
        try:
            inv_sigma = np.linalg.inv(Sigma)
        except np.linalg.LinAlgError:
            Sigma = _regularized(Sigma)
            inv_sigma = np.linalg.inv(Sigma)

    return Sigma, np.asarray(inv_sigma, dtype=np.float64), estimator_name
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _compute_smatrix_knockpy_style(
    Sigma: np.ndarray,
    *,
    method: str = "mvr",
    tol: float = 1e-5,
):
    """Compute a model-X knockoff S-matrix, preferring knockpy's solvers.

    Every feature is its own group (groups ``1..p``).

    * If knockpy is importable, ``knockpy.smatrix.compute_smatrix`` is called
      with ``method`` (stripped and lower-cased).
    * If knockpy is not importable, an equicorrelated ``S = s * I`` with
      ``s = min(2 * lambda_min(Sigma), 1)`` is used instead.
    * If the knockpy call itself fails, a ``RuntimeWarning`` is emitted and
      the same equicorrelated fallback is used.

    The result is then made PSD with ``_shift_until_psd_np`` and scaled by
    ``_scale_until_psd_np`` (25 bisection steps), so that
    ``2 * Sigma - S - tol * I`` stays Cholesky-factorizable.

    NOTE(review): the cap at 1 in the fallback suits a correlation-scaled
    ``Sigma``. For an unscaled covariance it only makes ``S`` more
    conservative, since ``s <= 2 * lambda_min`` still holds.

    Returns
    -------
    S : np.ndarray
        Final float64 S-matrix.
    source : str
        ``"knockpy"`` or ``"equicorrelated_fallback"``.
    gamma : float
        Scale factor applied by ``_scale_until_psd_np``.

    Raises
    ------
    ValueError
        If the fallback ``s`` is <= 1e-12 (``Sigma`` is near-singular).
    """
    Sigma_np = np.asarray(Sigma, dtype=np.float64)
    p = int(Sigma_np.shape[0])
    method_key = str(method).strip().lower()

    source = "equicorrelated_fallback"
    S = None

    try:
        from knockpy import smatrix as kp_smatrix
    except Exception:
        # knockpy is optional; without it the equicorrelated fallback is used.
        kp_smatrix = None

    if kp_smatrix is not None:
        groups = np.arange(1, p + 1, dtype=np.int64)
        try:
            S = kp_smatrix.compute_smatrix(
                Sigma=Sigma_np,
                groups=groups,
                method=method_key,
            )
            source = "knockpy"
        except Exception as exc:
            # Don't hide real knockpy failures (e.g. an invalid method name)
            # behind the fallback: keep the best-effort result, but say so.
            warnings.warn(
                f"knockpy.smatrix.compute_smatrix(method={method_key!r}) failed "
                f"({exc!r}); falling back to the equicorrelated S-matrix",
                RuntimeWarning,
                stacklevel=2,
            )
            S = None

    if S is None:
        source = "equicorrelated_fallback"
        min_eig = _calc_mineig_np(Sigma_np)
        s_val = min(2.0 * min_eig, 1.0)
        if s_val <= 1e-12:
            raise ValueError("Failed to construct model-X knockoff S-matrix")
        S = s_val * np.eye(p, dtype=np.float64)

    S = _shift_until_psd_np(np.asarray(S, dtype=np.float64), tol=float(tol))
    S, gamma = _scale_until_psd_np(Sigma_np, S, tol=float(tol), num_iter=25)
    return S, source, float(gamma)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _random_permutation_inds(length: int, random_state: Optional[int]):
    """Return ``(perm, inverse)``: a random permutation of ``range(length)`` and its inverse.

    Both are int64 arrays with ``inverse[perm] == arange(length)``. The
    permutation is drawn from ``np.random.default_rng(random_state)``.
    """
    size = int(length)
    generator = np.random.default_rng(random_state)
    perm = generator.permutation(size).astype(np.int64, copy=False)
    inverse = np.empty(size, dtype=np.int64)
    inverse[perm] = np.arange(size, dtype=np.int64)
    return perm, inverse
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _validate_q(q: float) -> float:
    """Validate a target FDR level and return it as a float.

    Raises
    ------
    ValueError
        Unless ``0 < q < 1``. The check is written as a positive range test
        so that NaN is rejected too: every comparison with NaN is False, so
        a ``q <= 0 or q >= 1`` test would let NaN through.
    """
    q_f = float(q)
    if not (0.0 < q_f < 1.0):
        raise ValueError("q must be in (0, 1)")
    return q_f
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _normalize_fdr_control(fdr_control: str) -> int:
    """Map an FDR-control alias to its integer code.

    Knockoff+ aliases map to ``1`` and the plain knockoff threshold maps to
    ``0``. Matching ignores case and surrounding whitespace; unknown values
    raise ``ValueError``.
    """
    codes = {
        "knockoff_plus": 1,
        "plus": 1,
        "knockoff+": 1,
        "knockoff": 0,
        "standard": 0,
    }
    code = codes.get(str(fdr_control).strip().lower())
    if code is None:
        raise ValueError("fdr_control must be one of: 'knockoff_plus', 'knockoff'")
    return code
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def _normalize_knockoff_type(knockoff_type: str) -> str:
    """Map a knockoff-type alias to ``"fixed_x"`` or ``"model_x"``.

    Matching ignores case and surrounding whitespace; unknown values raise
    ``ValueError``.
    """
    aliases = {
        "fixed_x": "fixed_x",
        "fixed-x": "fixed_x",
        "fixedx": "fixed_x",
        "model_x": "model_x",
        "model-x": "model_x",
        "modelx": "model_x",
    }
    canonical = aliases.get(str(knockoff_type).strip().lower())
    if canonical is None:
        raise ValueError("knockoff_type must be one of: 'fixed_x', 'model_x'")
    return canonical
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _standardize_design(X, xp):
    """Center each column and scale it to unit Euclidean (L2) norm.

    This is the conventional normalization for fixed-X knockoffs and makes
    the construction invariant to per-feature scaling. It differs from R
    glmnet's internal standardization, which uses unit variance.

    Raises ``ValueError`` if ``X`` is not 2-D or any centered column has
    norm <= 1e-12 (near-constant feature).
    """
    X = xp.asarray(X, dtype=xp.float64)
    if X.ndim != 2:
        raise ValueError("X must be a 2D array")

    centered = X - xp.mean(X, axis=0, keepdims=True)
    col_norms = xp.sqrt(xp.sum(centered * centered, axis=0))
    has_degenerate_column = bool(xp.any(col_norms <= 1e-12))
    if has_degenerate_column:
        raise ValueError("X contains near-constant columns; knockoff construction is unstable")
    return centered / col_norms
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def _standardize_features_unit_variance(X, xp):
    """Center each column and scale it to unit sample standard deviation (ddof=1).

    Raises ``ValueError`` if ``X`` is not 2-D, has fewer than 2 rows, or any
    column's standard deviation is <= 1e-12.
    """
    arr = xp.asarray(X, dtype=xp.float64)
    if arr.ndim != 2:
        raise ValueError("X must be a 2D array")
    if int(arr.shape[0]) < 2:
        raise ValueError("model-X knockoff requires at least 2 samples")

    centered = arr - xp.mean(arr, axis=0, keepdims=True)
    # Torch names Bessel's correction `correction=`, NumPy/CuPy use `ddof=`.
    # Assumes `_torch_dev` returns None for non-Torch arrays -- TODO confirm.
    std_kwargs = {"correction": 1} if _torch_dev(centered) is not None else {"ddof": 1}
    col_std = xp.std(centered, axis=0, **std_kwargs)
    if bool(xp.any(col_std <= 1e-12)):
        raise ValueError("X contains near-constant columns; model-X knockoff is unstable")
    return centered / col_std
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _knockoff_eye_like(p: int, like, xp):
    """Return a float64 ``p x p`` identity on the backend/device of ``like``."""
    # NumPy and CuPy allocate on the default/current device; every other
    # backend (Torch) is given ``like``'s device explicitly.
    if xp is np or getattr(xp, "__name__", "") == "cupy":
        return xp.eye(p, dtype=xp.float64)
    return xp.eye(p, dtype=xp.float64, device=getattr(like, "device", None))


def _knockoff_solve_like(Sigma, S, like, xp):
    """Solve ``Sigma @ Z = S`` for ``Z`` on the backend/device of ``like``."""
    if xp is np or getattr(xp, "__name__", "") == "cupy":
        return xp.linalg.solve(Sigma, S)

    # Torch: call torch.linalg.solve with both operands on ``like``'s device
    # so the result stays on that device.
    import torch

    device = getattr(like, "device", None)
    Sigma_dev = Sigma.to(device) if hasattr(Sigma, "to") else Sigma
    S_dev = S.to(device) if hasattr(S, "to") else S
    out = torch.linalg.solve(Sigma_dev, S_dev)
    if device is not None and hasattr(out, "to"):
        out = out.to(device)
    return out


def _knockoff_psd_sqrt(M, xp):
    """Return the symmetric square root of ``M`` with negative eigenvalues clipped to 0."""
    eigvals, eigvecs = xp.linalg.eigh(M)
    eigvals = xp.clip(eigvals, 0.0, None)
    return eigvecs @ xp.diag(xp.sqrt(eigvals)) @ eigvecs.T


def _knockoff_standard_normal_like(n: int, p: int, random_state: Optional[int], like, xp):
    """Draw an ``(n, p)`` float64 standard-normal matrix on ``like``'s backend/device.

    * NumPy uses ``np.random.default_rng(random_state)``, which draws fresh
      entropy when ``random_state`` is None.
    * Other backends are seeded with ``random_state``, or with 0 when it is
      None, so their draws are deterministic even without an explicit seed.
    """
    if xp is np:
        return np.random.default_rng(random_state).standard_normal(size=(n, p))

    seed = 0 if random_state is None else int(random_state)
    try:
        # CuPy (and NumPy-like modules) expose the legacy RandomState API.
        return xp.random.RandomState(seed).standard_normal(size=(n, p), dtype=xp.float64)
    except (AttributeError, TypeError):
        pass

    if getattr(xp, "__name__", "") == "torch":
        torch = xp
        if hasattr(like, "device"):
            device = like.device
        else:
            device = torch.device(_get_torch_device_str())
        # Pass the seeded generator to randn. This makes the draw
        # reproducible without touching Torch's global RNG, and keeps the
        # sample on the design matrix's device.
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        return torch.randn(n, p, dtype=torch.float64, device=device, generator=gen)

    # Fallback for other NumPy-compatible modules.
    rng = xp.random.Generator(xp.random.PCG64(seed))
    return rng.standard_normal(size=(n, p), dtype=xp.float64)


def _build_fixed_x_knockoffs(X_std, random_state: Optional[int], xp):
    """Construct equicorrelated fixed-X knockoffs for a standardized design.

    With ``Sigma = X'X`` and ``S = s * I``, where
    ``s = min(2 * lambda_min(Sigma), 1)``, the knockoffs are

        X_knock = X (I - Sigma^{-1} S) + U C,
        C = (2S - S Sigma^{-1} S)^{1/2},

    where ``U`` is an orthonormal basis orthogonal to ``col(X)``. ``U`` comes
    from the QR factorization of ``[X, A]`` with ``A`` standard normal.

    Parameters
    ----------
    X_std : array of shape (n, p)
        Design matrix, expected to be centered with unit-norm columns (as
        produced by ``_standardize_design``) -- the code does not check this.
    random_state : int or None
        Seed for ``A`` (see ``_knockoff_standard_normal_like`` for how None
        is treated on each backend).
    xp : module
        Array namespace of ``X_std`` (NumPy, CuPy or Torch).

    Returns
    -------
    X_knock : array of shape (n, p) on the same backend/device as ``X_std``.

    Raises
    ------
    ValueError
        If ``n < 2 * p``, ``X'X`` is near-singular, or ``s`` is degenerate.
    """
    n, p = int(X_std.shape[0]), int(X_std.shape[1])
    if n < 2 * p:
        raise ValueError("fixed-X knockoff requires n_samples >= 2 * n_features")

    Sigma = X_std.T @ X_std
    Sigma = 0.5 * (Sigma + Sigma.T)

    min_eig = _to_float_scalar(xp.min(xp.linalg.eigvalsh(Sigma)))
    if min_eig <= 1e-10:
        raise ValueError("X'X is near-singular; fixed-X knockoff requires full-rank design")

    s_val = min(2.0 * min_eig, 1.0)
    if s_val <= 1e-12:
        raise ValueError("Failed to construct a valid knockoff S-matrix")

    S = s_val * _knockoff_eye_like(p, X_std, xp)
    Sigma_inv_S = _knockoff_solve_like(Sigma, S, X_std, xp)

    c_arg = 2.0 * S - S @ Sigma_inv_S
    c_arg = 0.5 * (c_arg + c_arg.T)
    C = _knockoff_psd_sqrt(c_arg, xp)

    A = _knockoff_standard_normal_like(n, p, random_state, X_std, xp)
    # Q[:, :p] spans col(X); Q[:, p:2p] is an orthonormal complement basis.
    Q, _ = xp.linalg.qr(xp.concatenate([X_std, A], axis=1), mode="reduced")
    U = Q[:, p : 2 * p]

    return X_std @ (_knockoff_eye_like(p, X_std, xp) - Sigma_inv_S) + U @ C


def _build_model_x_knockoffs(
    X_std,
    random_state: Optional[int],
    xp,
    covariance_shrinkage: float = 0.20,
    s_scale: float = 0.999,
):
    """Construct Gaussian model-X knockoffs from an estimated covariance.

    Steps:

    1. Estimate ``Sigma = X'X / (n - 1)``.
    2. Shrink it toward ``(trace(Sigma) / p) * I`` by ``covariance_shrinkage``
       (clipped to [0, 1]).
    3. If its smallest eigenvalue is below 1e-6, add a ridge that lifts it
       to about 1e-6.
    4. Set ``S = s * I`` with ``s = min(2 * lambda_min * s_scale, 1)``.
    5. Form the knockoffs

           X_knock = X - X Sigma^{-1} S + Z C,
           C = (2S - S Sigma^{-1} S)^{1/2},

       with ``Z`` standard normal.

    ``X_std`` is used directly as the conditional-mean term, so its columns
    are presumably centered (e.g. by ``_standardize_features_unit_variance``)
    -- TODO confirm with callers.

    Returns
    -------
    X_knock : array of shape (n, p) on the same backend/device as ``X_std``.
    info : dict
        ``s_value``, ``ridge``, ``min_eigenvalue`` (after any ridge),
        ``covariance_shrinkage`` (clipped), ``s_scale``.

    Raises
    ------
    ValueError
        If the (regularized) covariance is near-singular or ``s`` is
        degenerate.
    """
    n, p = int(X_std.shape[0]), int(X_std.shape[1])

    Sigma = (X_std.T @ X_std) / float(max(1, n - 1))
    Sigma = 0.5 * (Sigma + Sigma.T)

    shrinkage = float(min(1.0, max(0.0, covariance_shrinkage)))
    if shrinkage > 0.0:
        trace_mean = xp.trace(Sigma) / float(max(1, p))
        Sigma = (1.0 - shrinkage) * Sigma + shrinkage * trace_mean * _knockoff_eye_like(p, X_std, xp)
        Sigma = 0.5 * (Sigma + Sigma.T)

    min_eig = _to_float_scalar(xp.min(xp.linalg.eigvalsh(Sigma)))

    ridge = 0.0
    if min_eig < 1e-6:
        ridge = float((1e-6 - min_eig) + 1e-8)
        Sigma = Sigma + ridge * _knockoff_eye_like(p, X_std, xp)
        Sigma = 0.5 * (Sigma + Sigma.T)
        min_eig = _to_float_scalar(xp.min(xp.linalg.eigvalsh(Sigma)))

    if min_eig <= 1e-12:
        raise ValueError("Estimated covariance is near-singular; model-X knockoff failed")

    s_val = min(2.0 * min_eig * float(s_scale), 1.0)
    if s_val <= 1e-12:
        raise ValueError("Failed to construct a valid model-X knockoff S-matrix")

    S = s_val * _knockoff_eye_like(p, X_std, xp)
    Sigma_inv_S = _knockoff_solve_like(Sigma, S, X_std, xp)

    c_arg = 2.0 * S - S @ Sigma_inv_S
    c_arg = 0.5 * (c_arg + c_arg.T)
    C = _knockoff_psd_sqrt(c_arg, xp)

    # Drawn on X_std's device with a seeded generator. For Torch this
    # honours random_state and avoids a device mismatch in Z @ C.
    Z = _knockoff_standard_normal_like(n, p, random_state, X_std, xp)

    X_knock = X_std - X_std @ Sigma_inv_S + Z @ C
    return X_knock, {
        "s_value": float(s_val),
        "ridge": float(ridge),
        "min_eigenvalue": float(min_eig),
        "covariance_shrinkage": float(shrinkage),
        "s_scale": float(s_scale),
    }
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def _build_model_x_knockoffs_knockpy_compat(
    X,
    random_state: Optional[int],
    *,
    modelx_shrinkage: str = "ledoitwolf",
    modelx_smatrix_method: str = "mvr",
    sample_tol: float = 1e-5,
):
    """Sample Gaussian model-X knockoffs using a knockpy-compatible recipe.

    The covariance comes from ``_estimate_covariance_knockpy_style`` and the
    S-matrix from ``_compute_smatrix_knockpy_style``. Knockoffs are drawn as
    ``cond_mean + noise @ chol.T``, where ``chol`` is the Cholesky factor of
    ``2S - S Sigma^{-1} S`` after ``_shift_until_psd_np``.

    Parameters
    ----------
    X : array-like
        Design matrix; converted to a float64 NumPy array and must be 2-D.
    random_state : int or None
        Passed to ``_temporary_numpy_seed``. The Gaussian noise is drawn with
        ``np.random.randn`` inside that context.
    modelx_shrinkage : str
        Forwarded as ``shrinkage`` to the covariance helper.
    modelx_smatrix_method : str
        Forwarded as ``method`` to the S-matrix helper.
    sample_tol : float
        Tolerance for the S-matrix helper and for the PSD shift.

    Returns
    -------
    tuple
        ``(X_knock, info)``: a float64 ``(n, p)`` NumPy array and a dict of
        diagnostics (mean of ``diag(S)``, min eigenvalue of ``Sigma``, etc.).

    Raises
    ------
    ValueError
        If ``X`` is not 2-D or has fewer than two rows.
    """
    X_np = np.asarray(X, dtype=np.float64)
    if X_np.ndim != 2:
        raise ValueError("X must be a 2D array")

    n_rows, n_cols = (int(dim) for dim in X_np.shape)
    if n_rows < 2:
        raise ValueError("model-X knockoff requires at least 2 samples")

    col_means = np.mean(X_np, axis=0)
    Sigma, inv_sigma, cov_estimator = _estimate_covariance_knockpy_style(
        X_np,
        shrinkage=modelx_shrinkage,
        tol=1e-4,
    )
    tol_f = float(sample_tol)
    S, smatrix_source, smatrix_gamma = _compute_smatrix_knockpy_style(
        Sigma,
        method=modelx_smatrix_method,
        tol=tol_f,
    )

    # Conditional mean / covariance of the knockoffs given X.
    sigma_inv_s = inv_sigma @ S
    centered = X_np - col_means.reshape(1, -1)
    cond_mean = X_np - centered @ sigma_inv_s
    cond_cov = _shift_until_psd_np(2.0 * S - S @ sigma_inv_s, tol=tol_f)
    chol = np.linalg.cholesky(cond_cov)

    with _temporary_numpy_seed(random_state):
        noise = np.random.randn(n_rows, n_cols)
    X_knock = noise @ chol.T + cond_mean

    info = {
        "s_value": float(np.mean(np.diag(S))),
        "ridge": 0.0,
        "min_eigenvalue": float(_calc_mineig_np(Sigma)),
        "covariance_shrinkage": None,
        "s_scale": float(smatrix_gamma),
        "modelx_shrinkage": str(modelx_shrinkage),
        "modelx_smatrix_method": str(modelx_smatrix_method),
        "modelx_covariance_estimator": str(cov_estimator),
        "modelx_smatrix_source": str(smatrix_source),
    }
    return np.asarray(X_knock, dtype=np.float64), info
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
def _model_x_draw_seed(random_state: Optional[int], draw_index: int) -> Optional[int]:
|
|
672
|
+
if random_state is None:
|
|
673
|
+
return None
|
|
674
|
+
return int(random_state) + 104729 * int(draw_index)
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
def _corr_diff_statistics(X_std, X_knock, y, xp):
|
|
678
|
+
y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
|
|
679
|
+
if y_arr.shape[0] != X_std.shape[0]:
|
|
680
|
+
raise ValueError("y must have the same number of rows as X")
|
|
681
|
+
|
|
682
|
+
y_centered = y_arr - xp.mean(y_arr)
|
|
683
|
+
score_orig = xp.abs(X_std.T @ y_centered)
|
|
684
|
+
score_knock = xp.abs(X_knock.T @ y_centered)
|
|
685
|
+
return score_orig - score_knock
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
def _ols_coef_diff_statistics(X_std, X_knock, y, xp, ridge: float = 1e-8):
|
|
689
|
+
y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
|
|
690
|
+
if y_arr.shape[0] != X_std.shape[0]:
|
|
691
|
+
raise ValueError("y must have the same number of rows as X")
|
|
692
|
+
|
|
693
|
+
y_centered = y_arr - xp.mean(y_arr)
|
|
694
|
+
p = int(X_std.shape[1])
|
|
695
|
+
|
|
696
|
+
Z = xp.concatenate([X_std, X_knock], axis=1)
|
|
697
|
+
ridge_f = float(max(0.0, ridge))
|
|
698
|
+
|
|
699
|
+
if ridge_f > 0.0:
|
|
700
|
+
# Create identity matrix - numpy/cupy don't need device, torch does
|
|
701
|
+
if xp is np:
|
|
702
|
+
eye_matrix = xp.eye(2 * p, dtype=xp.float64)
|
|
703
|
+
elif getattr(xp, '__name__', '') == 'cupy':
|
|
704
|
+
# CuPy: create eye on current device context (same as Z)
|
|
705
|
+
eye_matrix = xp.eye(2 * p, dtype=xp.float64)
|
|
706
|
+
else:
|
|
707
|
+
# Torch: use device keyword
|
|
708
|
+
device = getattr(Z, 'device', None)
|
|
709
|
+
eye_matrix = xp.eye(2 * p, dtype=xp.float64, device=device)
|
|
710
|
+
gram = Z.T @ Z + ridge_f * eye_matrix
|
|
711
|
+
rhs = Z.T @ y_centered
|
|
712
|
+
try:
|
|
713
|
+
coef = xp.linalg.solve(gram, rhs)
|
|
714
|
+
except Exception:
|
|
715
|
+
coef = xp.linalg.lstsq(Z, y_centered, rcond=None)[0]
|
|
716
|
+
else:
|
|
717
|
+
coef = xp.linalg.lstsq(Z, y_centered, rcond=None)[0]
|
|
718
|
+
|
|
719
|
+
coef_orig = coef[:p]
|
|
720
|
+
coef_knock = coef[p:]
|
|
721
|
+
return xp.abs(coef_orig) - xp.abs(coef_knock)
|
|
722
|
+
|
|
723
|
+
|
|
724
|
+
def _lasso_coef_diff_statistics(
    X_std,
    X_knock,
    y,
    xp,
    random_state: Optional[int] = None,
    backend_name: str = "numpy",
    max_iter: int = 3000,
    tol: float = 1e-4,
    cv_folds: int = 5,
    n_alphas: int = 12,
    lasso_cv_impl: str = "statgpu",
    lasso_fast_profile: str = "off",
    knockpy_style: bool = False,
):
    """Lasso coefficient-difference knockoff statistic.

    Fits a cross-validated Lasso on the augmented design ``[X_std, X_knock]``
    and returns ``W_j = |beta_j| - |beta_{j+p}|`` for each original column j.
    The augmented columns are jointly permuted before fitting and un-permuted
    afterwards (knockpy-style symmetry preservation). Results are memoized
    through ``_lasso_diff_cache_get`` / ``_lasso_diff_cache_put``.

    Parameters
    ----------
    X_std, X_knock : array
        Design and its knockoff copy in the ``xp`` namespace. Both are assumed
        to have the same shape ``(n, p)``; only ``X_std.shape`` is read here.
    y : array-like
        Response; flattened and required to have ``n`` entries.
    xp : module
        Array namespace (NumPy, CuPy or Torch) of the inputs and the result.
    random_state : int or None
        Seeds the column permutation and the statgpu CV alpha selection. It
        is also part of the cache keys; the alpha cache is disabled when it
        is ``None``.
    backend_name : str
        ``"numpy"``, ``"cupy"`` or ``"torch"``. Selects the device-native input
        path and the ``device`` string passed to the statgpu solvers.
    max_iter, tol, cv_folds, n_alphas :
        Base solver / CV budget. Clamped from below, then possibly reduced by
        the resolved fast profile.
    lasso_cv_impl : str
        ``"statgpu"`` or ``"sklearn"``. sklearn is never used with Torch, and
        the code falls back to statgpu when scikit-learn cannot be imported.
    lasso_fast_profile : str
        Requested profile, resolved per problem size by
        ``_resolve_lasso_fast_profile_for_problem``. ``"moderate"`` and
        ``"aggressive"`` shrink folds/alphas/iterations; ``"aggressive"`` also
        loosens ``tol``.
    knockpy_style : bool
        Knockpy-like settings: ``y`` is not centered, the statgpu solver fits
        an intercept, a fixed log-spaced alpha grid with glmnet-style CV is
        used, and ``tol`` is at least 1e-3.

    Returns
    -------
    array
        ``(p,)`` float64 array in the ``xp`` namespace.

    Raises
    ------
    ValueError
        If ``y`` does not have as many entries as ``X_std`` has rows.
    RuntimeError
        If the fitted coefficient vector does not have ``2 * p`` entries.
    """
    y_arr = xp.asarray(y, dtype=xp.float64).reshape(-1)
    if y_arr.shape[0] != X_std.shape[0]:
        raise ValueError("y must have the same number of rows as X")

    knockpy = bool(knockpy_style)
    backend_key = str(backend_name).lower()

    if knockpy:
        y_model = y_arr
    else:
        y_model = y_arr - xp.mean(y_arr)
    p = int(X_std.shape[1])
    problem_size_full = int(X_std.shape[0]) * int(2 * p)
    fast_profile_eff = _resolve_lasso_fast_profile_for_problem(
        lasso_fast_profile,
        problem_size_full,
    )

    # Base CV / solver budget, then shrink it for the fast profiles.
    cv_folds_eff = max(2, int(cv_folds))
    n_alphas_eff = max(2, int(n_alphas))
    max_iter_eff = max(500, int(max_iter))
    tol_base = float(tol)

    if fast_profile_eff == "moderate":
        if problem_size_full >= 1_000_000:
            cv_folds_eff = min(cv_folds_eff, 4)
            n_alphas_eff = min(n_alphas_eff, 14 if knockpy else 12)
            max_iter_eff = min(max_iter_eff, 2800)
    elif fast_profile_eff == "aggressive":
        if problem_size_full >= 2_000_000:
            cv_folds_eff = min(cv_folds_eff, 2)
            n_alphas_eff = min(n_alphas_eff, 6 if knockpy else 5)
            max_iter_eff = min(max_iter_eff, 1600)
        else:
            cv_folds_eff = min(cv_folds_eff, 3)
            n_alphas_eff = min(n_alphas_eff, 8 if knockpy else 7)
            max_iter_eff = min(max_iter_eff, 2200)

    tol_eff = max(1e-3, tol_base) if knockpy else tol_base
    if fast_profile_eff == "aggressive":
        tol_eff = max(tol_eff, 4e-3 if problem_size_full >= 2_000_000 else 2e-3)

    lasso_diff_cache_key = _make_lasso_coef_diff_cache_key(
        X_std=X_std,
        X_knock=X_knock,
        y=y_arr,
        random_state=random_state,
        backend_name=backend_name,
        max_iter_eff=int(max_iter_eff),
        tol_eff=float(tol_eff),
        cv_folds_eff=int(cv_folds_eff),
        n_alphas_eff=int(n_alphas_eff),
        lasso_cv_impl=lasso_cv_impl,
        fast_profile_eff=fast_profile_eff,
        knockpy_style=knockpy,
    )
    cached_w = _lasso_diff_cache_get(lasso_diff_cache_key)
    if cached_w is not None:
        return xp.asarray(cached_w, dtype=xp.float64)

    Z = xp.concatenate([X_std, X_knock], axis=1)

    # Knockpy-style symmetry preservation: permute [X, Xk] jointly, then undo at the end.
    inds, rev_inds = _random_permutation_inds(2 * p, random_state=random_state)
    alphas = np.logspace(-4.0, 4.0, base=10.0, num=int(n_alphas_eff))

    cv_impl = _normalize_lasso_cv_impl(lasso_cv_impl)

    # Force statgpu for torch backend since sklearn doesn't support torch tensors.
    if backend_key == "torch" and cv_impl == "sklearn":
        cv_impl = "statgpu"

    if cv_impl == "sklearn":
        try:
            from sklearn import linear_model
        except Exception:
            cv_impl = "statgpu"

    if cv_impl == "sklearn":
        Z_np = _to_numpy(Z).astype(np.float64, copy=False)
        y_np = _to_numpy(y_model).astype(np.float64, copy=False).reshape(-1)
        Z_perm = Z_np[:, inds]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model = linear_model.LassoCV(
                alphas=alphas,
                cv=int(cv_folds_eff),
                verbose=False,
                max_iter=int(max_iter_eff),
                tol=float(tol_eff),
            ).fit(Z_perm, y_np)
        coef_perm = np.asarray(model.coef_, dtype=np.float64).reshape(-1)
    else:
        from statgpu.linear_model.wrappers._lasso import (
            _fit_lasso_single_alpha_fast,
            _select_lasso_alpha_cv,
        )

        solver_device = "cuda" if backend_key in ("cupy", "torch") else "cpu"
        use_cupy_native = backend_key == "cupy" and _is_cupy_array(Z)
        use_torch_native = backend_key == "torch" and hasattr(Z, 'shape')
        if use_cupy_native:
            import cupy as cp

            inds_device = cp.asarray(inds, dtype=cp.int64)
            Z_perm = xp.asarray(Z, dtype=xp.float64)[:, inds_device]
            y_fit = xp.asarray(y_model, dtype=xp.float64).reshape(-1)
        elif use_torch_native:
            import torch

            inds_tensor = torch.tensor(inds, dtype=torch.int64, device=Z.device)
            Z_perm = Z[:, inds_tensor]
            y_fit = y_model.reshape(-1)
        else:
            Z_np = _to_numpy(Z).astype(np.float64, copy=False)
            Z_perm = Z_np[:, inds]
            y_fit = _to_numpy(y_model).astype(np.float64, copy=False).reshape(-1)

        problem_size = int(Z_perm.shape[0]) * int(Z_perm.shape[1])

        fit_intercept_eff = knockpy
        if random_state is None:
            alpha_cache_key = None
        else:
            alpha_cache_key = (
                "knockoff_lasso_cv_v1",
                _array_identity_token(X_std),
                _array_identity_token(X_knock),
                _array_identity_token(y_arr),
                int(random_state),
                backend_key,
                knockpy,
                str(fast_profile_eff).lower(),
                int(cv_folds_eff),
                int(n_alphas_eff),
                int(max_iter_eff),
                float(tol_eff),
                _int_array_signature(inds),
            )
        alpha_select_kwargs = {
            "cv_folds": int(cv_folds_eff),
            "random_state": random_state,
            "fit_intercept": fit_intercept_eff,
            "device": solver_device,
            "max_iter": int(max_iter_eff),
            "tol": tol_eff,
            "cpu_solver": "coordinate_descent",
            "cache_key": alpha_cache_key,
        }
        if knockpy:
            # Match knockpy-oriented branch settings used by the sklearn path as closely as possible.
            alpha_select_kwargs["alphas"] = alphas
            alpha_select_kwargs["method"] = "glmnet"
            # For large designs, reduce full KKT scan frequency to lower CV overhead.
            cd_kkt_check_every_eff = 4 if problem_size >= 1_000_000 else 2
            if fast_profile_eff == "moderate":
                cd_kkt_check_every_eff = max(cd_kkt_check_every_eff, 6)
            elif fast_profile_eff == "aggressive":
                cd_kkt_check_every_eff = max(
                    cd_kkt_check_every_eff,
                    12 if problem_size >= 2_000_000 else 8,
                )
            alpha_select_kwargs["cd_kkt_check_every"] = cd_kkt_check_every_eff
        else:
            alpha_select_kwargs["n_alphas"] = int(n_alphas_eff)

        alpha = _select_lasso_alpha_cv(
            Z_perm,
            y_fit,
            **alpha_select_kwargs,
        )

        fit_out = _fit_lasso_single_alpha_fast(
            Z_perm,
            y_fit,
            alpha=float(alpha),
            fit_intercept=fit_intercept_eff,
            max_iter=int(max_iter_eff),
            tol=tol_eff,
            device=solver_device,
            stopping="coef_delta",
            cpu_solver="coordinate_descent",
            cd_kkt_check_every=int(alpha_select_kwargs.get("cd_kkt_check_every", 1)),
        )

        # The native CuPy/Torch paths may hand back device arrays. np.asarray
        # refuses implicit GPU->host conversion, so go through _to_numpy first.
        coef_perm = np.asarray(_to_numpy(fit_out["coef"]), dtype=np.float64).reshape(-1)

    if coef_perm.shape[0] != 2 * p:
        raise RuntimeError("lasso_coef_diff produced unexpected coefficient shape")

    # Undo the joint permutation: [0, p) are originals, [p, 2p) are knockoffs.
    coef = coef_perm[rev_inds]

    W_np = np.abs(coef[:p]) - np.abs(coef[p:])
    _lasso_diff_cache_put(lasso_diff_cache_key, W_np)
    return xp.asarray(W_np, dtype=xp.float64)
|
|
926
|
+
|
|
927
|
+
|
|
928
|
+
def _compute_w_statistics(
    X_std,
    X_knock,
    y,
    method: str,
    xp,
    random_state: Optional[int] = None,
    backend_name: str = "numpy",
    lasso_cv_impl: str = "statgpu",
    lasso_fast_profile: str = "off",
    lasso_knockpy_style: bool = False,
):
    """Dispatch to the requested knockoff W-statistic.

    ``method`` is matched case-insensitively after stripping whitespace:

    * ``"lasso_coef_diff"`` / ``"lasso"`` / ``"lasso_diff"``: Lasso
      coefficient difference; uses 20 alphas in knockpy style, else 12.
    * ``"ols_coef_diff"`` / ``"ols"`` / ``"coef_diff"``: OLS coefficient
      difference.
    * ``"corr_diff"``: correlation difference.

    Returns
    -------
    tuple
        ``(W, canonical_method_name)``.

    Raises
    ------
    ValueError
        If ``method`` matches none of the aliases above.
    """
    key = str(method).strip().lower()

    if key in ("lasso_coef_diff", "lasso", "lasso_diff"):
        knockpy = bool(lasso_knockpy_style)
        W = _lasso_coef_diff_statistics(
            X_std,
            X_knock,
            y,
            xp,
            random_state=random_state,
            backend_name=backend_name,
            lasso_cv_impl=lasso_cv_impl,
            lasso_fast_profile=lasso_fast_profile,
            knockpy_style=lasso_knockpy_style,
            n_alphas=20 if knockpy else 12,
        )
        return W, "lasso_coef_diff"

    if key in ("ols_coef_diff", "ols", "coef_diff"):
        W = _ols_coef_diff_statistics(X_std, X_knock, y, xp)
        return W, "ols_coef_diff"

    if key == "corr_diff":
        W = _corr_diff_statistics(X_std, X_knock, y, xp)
        return W, "corr_diff"

    raise ValueError("method must be one of: 'corr_diff', 'ols_coef_diff', 'lasso_coef_diff'")
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
def _knockoff_threshold_and_path(W, q: float, offset: int):
    """Compute the knockoff data-dependent threshold and its FDR-estimate path.

    Features are ranked by ``|W|`` in descending order (stable sort). At rank
    k the estimated FDR is::

        (#{W <= 0 among top k} + offset) / max(1, #{W > 0 among top k})

    The chosen threshold is ``|W|`` at the largest rank whose estimate is
    ``<= q``. If that value is 0, it is raised to the smallest positive W, or
    ``inf`` when no W is positive.

    Parameters
    ----------
    W : array-like
        Knockoff statistics from any backend; converted via ``_to_numpy`` and
        flattened.
    q : float
        Target FDR level.
    offset : int
        Added to the negative count (conventionally 0 for the knockoff filter
        and 1 for knockoff+).

    Returns
    -------
    tuple
        ``(threshold, fdr_hat, trajectory)``.

        * ``threshold`` is ``inf`` and ``fdr_hat`` is 0.0 when no rank meets
          ``q`` or when every W is zero or ``W`` is empty.
        * ``trajectory`` has one dict per ranked feature with ``rank``
          (1-based), ``threshold`` (``|W|``), ``fdr_hat`` (capped at 1.0) and
          ``n_selected`` (number of positive W among the top ``rank``, possibly
          0). It is empty in the all-zero and empty-input cases.
    """
    W_np = np.asarray(_to_numpy(W), dtype=np.float64).reshape(-1)
    if W_np.size == 0:
        return float(np.inf), 0.0, []

    abs_w = np.abs(W_np)
    if not np.any(abs_w > 0):
        return float(np.inf), 0.0, []

    inds = np.argsort(-abs_w, kind="stable")
    w_sorted = W_np[inds]
    negatives = np.cumsum(w_sorted <= 0)
    positives = np.cumsum(w_sorted > 0)
    # Guard only the denominator against division by zero. The true positive
    # counts are kept intact so that "n_selected" can legitimately be 0.
    hat_fdrs = (negatives + int(offset)) / np.maximum(positives, 1)

    trajectory: List[Dict[str, float]] = [
        {
            "rank": int(rank + 1),
            "threshold": float(abs_w[idx]),
            "fdr_hat": float(min(1.0, hat_fdrs[rank])),
            "n_selected": int(positives[rank]),
        }
        for rank, idx in enumerate(inds)
    ]

    valid = np.flatnonzero(hat_fdrs <= float(q))
    if valid.size == 0:
        return float(np.inf), 0.0, trajectory

    chosen_rank = int(valid.max())
    chosen_threshold = float(abs_w[inds[chosen_rank]])
    if chosen_threshold == 0.0:
        # A zero threshold would admit W == 0 entries; fall back to the
        # smallest strictly positive statistic instead.
        positive_w = W_np[W_np > 0.0]
        if positive_w.size > 0:
            chosen_threshold = float(np.min(positive_w))
        else:
            chosen_threshold = float(np.inf)
    chosen_fdr = float(min(1.0, hat_fdrs[chosen_rank]))
    return chosen_threshold, chosen_fdr, trajectory
|