statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,4876 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lasso regression with full statistical inference and GPU support.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
import hashlib
|
|
7
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
8
|
+
import os
|
|
9
|
+
import warnings
|
|
10
|
+
import numpy as np
|
|
11
|
+
from scipy import stats
|
|
12
|
+
from scipy.stats import norm as _norm_dist
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from numba import njit
|
|
16
|
+
|
|
17
|
+
_NUMBA_AVAILABLE = True
|
|
18
|
+
except Exception:
|
|
19
|
+
njit = None
|
|
20
|
+
_NUMBA_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
from statgpu._base import BaseEstimator
|
|
23
|
+
from statgpu._config import Device
|
|
24
|
+
from statgpu.linear_model._cv_base import CVEstimatorBase
|
|
25
|
+
from statgpu.backends import get_backend
|
|
26
|
+
from statgpu.inference._distributions_backend import (
|
|
27
|
+
norm,
|
|
28
|
+
t,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _lasso_env_int(name: str, default: int) -> int:
    """Read a non-negative integer setting from environment variable *name*.

    An unset or blank variable yields *default*. A value that does not parse
    as an integer also yields *default* (with a ``RuntimeWarning``) instead of
    aborting the module import. Negative values are clamped to 0; for the
    debiased-M cache a size of 0 means every inserted entry is evicted
    immediately.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError:
        warnings.warn(
            f"Ignoring non-integer value {raw!r} for {name}; using {default}.",
            RuntimeWarning,
            stacklevel=2,
        )
        return default
    return max(value, 0)


# Truthy spellings of STATGPU_DISABLE_NUMBA_CD turn the Numba coordinate-descent path off.
_NUMBA_CD_DISABLED = os.getenv("STATGPU_DISABLE_NUMBA_CD", "0").strip().lower() in (
    "1",
    "true",
    "yes",
    "on",
)

# Module-level LRU caches (OrderedDict used oldest-first); sizes are
# configurable through the environment.
_LASSO_CV_ALPHA_CACHE_MAXSIZE = _lasso_env_int("STATGPU_LASSO_CV_CACHE_SIZE", 64)
_LASSO_CV_ALPHA_CACHE: "OrderedDict[Tuple[Any, ...], Dict[str, Any]]" = OrderedDict()
_LASSO_DEBIASED_M_CACHE_MAXSIZE = _lasso_env_int("STATGPU_LASSO_DEBIASED_M_CACHE_SIZE", 16)
# The _debiased_m_key_* builders in this module produce hex-digest string keys.
_LASSO_DEBIASED_M_CACHE: "OrderedDict[Any, np.ndarray]" = OrderedDict()
# NOTE(review): consumer not visible in this excerpt; presumably the number of
# rows hashed per chunk when building GPU cache keys — confirm.
_LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK = 1024


# ============================================================================
# CuPy fused kernels for Lasso are implemented as Lasso class methods;
# see Lasso._get_cupy_fused_kernels() for details.
# ============================================================================
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _debiased_m_cache_get(key):
    """Look up *key* in the debiased-M LRU cache.

    Returns the stored value, or None on a miss. A hit marks the entry as most
    recently used (moves it to the end of the OrderedDict), so the oldest-first
    eviction in ``_debiased_m_cache_put`` keeps it longer.
    """
    cached = _LASSO_DEBIASED_M_CACHE.get(key)
    if cached is None:
        return None
    _LASSO_DEBIASED_M_CACHE.move_to_end(key)
    return cached
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _debiased_m_cache_put(key, value):
    """Store *value* under *key* and evict least-recently-used entries.

    The entry becomes the most recently used one. Entries are then dropped from
    the front of the OrderedDict until at most
    ``_LASSO_DEBIASED_M_CACHE_MAXSIZE`` remain. A size of 0 or less keeps
    nothing. Previously a negative size made the loop call ``popitem`` on an
    empty dict, which raised ``KeyError``.
    """
    cache = _LASSO_DEBIASED_M_CACHE
    cache[key] = value
    cache.move_to_end(key)
    # Clamp so a negative configured size cannot drive popitem() past empty.
    limit = max(int(_LASSO_DEBIASED_M_CACHE_MAXSIZE), 0)
    while len(cache) > limit:
        cache.popitem(last=False)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _debiased_m_key_from_numpy_design(
    X: np.ndarray,
    *,
    n: int,
    p: int,
    lam_nw: float,
    tol: float,
) -> str:
    """Build a BLAKE2b (32-byte) hex-digest cache key for the full design *X*.

    The digest covers, in order: ``n`` and ``p`` as int64, the dtype name of
    ``X``, ``lam_nw`` and ``tol`` as float64, and the raw C-order bytes of
    ``X``. Any change to the data or to these settings therefore changes the
    key.

    Parameters
    ----------
    X : array-like
        Design matrix; converted with ``np.asarray`` and made C-contiguous if
        needed, so C- and Fortran-ordered copies of the same data hash alike.
    n, p : int
        Dimensions mixed into the key.
    lam_nw, tol : float
        Settings mixed into the key.

    Returns
    -------
    str
        64-character hexadecimal digest.
    """
    X_cache = np.asarray(X)
    if not X_cache.flags["C_CONTIGUOUS"]:
        X_cache = np.ascontiguousarray(X_cache)
    h = hashlib.blake2b(digest_size=32)
    h.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
    h.update(str(X_cache.dtype).encode("utf-8"))
    h.update(np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes())
    # Hash the contiguous buffer in place (a zero-copy uint8 view). The former
    # ``.tobytes()`` duplicated the whole design matrix in memory first; the
    # digest is unchanged.
    h.update(X_cache.view(np.uint8))
    return h.hexdigest()
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _debiased_m_key_from_sample(
    *,
    n: int,
    p: int,
    dtype_name: str,
    sample_block: np.ndarray,
    lam_nw: float,
    tol: float,
) -> str:
    """Generate cache key for debiased M matrix from a sample block of X.

    This is used for the Torch backend, where we don't want to hash the entire
    matrix. The BLAKE2b (32-byte) digest covers ``n`` and ``p`` as int64,
    ``dtype_name``, ``lam_nw`` and ``tol`` as float64, and the C-order bytes
    of ``sample_block`` only. Two designs with the same shape, dtype name,
    settings and sample block therefore map to the same key.

    Parameters
    ----------
    n, p : int
        Dimensions of the full design, mixed into the key.
    dtype_name : str
        Dtype label of the full design.
    sample_block : array-like
        Sampled rows of the design; converted with ``np.asarray``, so lists
        and other array-likes are accepted.
    lam_nw, tol : float
        Settings mixed into the key.

    Returns
    -------
    str
        64-character hexadecimal digest.
    """
    h = hashlib.blake2b(digest_size=32)
    h.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
    h.update(dtype_name.encode("utf-8"))
    h.update(np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes())
    sample = np.asarray(sample_block)
    if not sample.flags["C_CONTIGUOUS"]:
        sample = np.ascontiguousarray(sample)
    # Zero-copy uint8 view instead of ``.tobytes()``; identical digest.
    h.update(sample.view(np.uint8))
    return h.hexdigest()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class Lasso(BaseEstimator):
    """
    Lasso regression (L1 regularization) with GPU acceleration
    and full statistical inference.

    The CPU solvers minimise ``(1/(2n)) * ||y - Xw||^2 + alpha * ||w||_1``
    (on centred ``X``/``y`` when ``fit_intercept=True``).

    CPU solver supports multiple algorithms (coordinate descent by default, and FISTA when cpu_solver='fista').
    GPU solver supports multiple algorithms via `solver` (e.g. FISTA / ADMM).

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Larger values specify stronger regularization.
        Must be non-negative.
    fit_intercept : bool, default=True
        Whether to calculate the intercept.
    max_iter : int, default=1000
        Maximum number of solver iterations (coordinate-descent sweeps or
        FISTA steps on the CPU path).
    tol : float, default=1e-4
        Tolerance for convergence; the quantity compared against it depends
        on ``stopping``.
    stopping : str, default='coef_delta'
        Convergence test of the CPU solvers. ``'kkt'`` stops when
        ``max_j max(|grad_j| - alpha, 0) < tol``; any other value stops when
        ``sum(|coef - coef_old|) < tol``.
    inference_method : str, default='cpu_ols_inference'
        ``'cpu_ols_inference'``, ``'gpu_ols_inference'`` or ``'debiased'``.
        Legacy aliases ``'naive_ols'`` and ``'gpu_naive_ols'`` are accepted.
        On CUDA/Torch devices ``'cpu_ols_inference'`` is switched to
        ``'gpu_ols_inference'`` at fit time; ``'gpu_ols_inference'`` is
        rejected on CPU.
    n_bootstrap : int, default=200
        Cast to ``int`` and stored; presumably the replicate count for
        bootstrap-based inference.
    bootstrap_random_state : int or None, default=None
        Stored as-is; presumably seeds that bootstrap.
    enable_simultaneous_inference : bool, default=False
        Also compute simultaneous (joint-coverage) intervals. Requires
        ``compute_inference=True`` and ``inference_method='debiased'``.
    simultaneous_method : str, default='maxz_bootstrap'
        Only ``'maxz_bootstrap'`` (max-|Z| bootstrap) is accepted.
    simultaneous_alpha : float, default=0.05
        Joint error level; must lie in (0, 1).
    simultaneous_n_bootstrap : int, default=1000
        Must be a positive integer.
    simultaneous_random_state : int or None, default=None
        Stored as-is; presumably seeds the max-|Z| bootstrap.
    simultaneous_include_intercept : bool, default=False
        Include the intercept in the joint target set (only effective when
        ``fit_intercept=True``).
    device : str or Device, default='auto'
        Computation device: 'cpu', 'cuda', 'torch', or 'auto'.
    n_jobs : int or None, default=None
        Forwarded to ``BaseEstimator``.
    compute_inference : bool, default=True
        Whether ``fit`` also computes inference quantities.
    solver : str, default='fista'
        GPU (CuPy) algorithm: ``'fista'`` or ``'admm'``.
    cpu_solver : str, default='coordinate_descent'
        CPU optimization algorithm: 'coordinate_descent' or 'fista'.
        GPU uses the `solver` parameter instead.
    lipschitz_L : float or None, default=None
        Lipschitz constant giving the FISTA step size ``1/L``. When None, the
        CPU path uses the largest eigenvalue of ``X'X / n``.
    admm_rho : float, default=1.0
        Stored as-is; presumably the ADMM penalty parameter for
        ``solver='admm'``.
    gpu_memory_cleanup : bool, default=False
        Return CuPy's cached memory-pool blocks after fitting (lower VRAM
        pressure at the cost of repeated-fit throughput).

    Attributes
    ----------
    coef_ : ndarray of shape (n_features,)
        Estimated coefficients.
    intercept_ : float
        Independent term.
    n_iter_ : int
        Number of iterations run.
    """

    # Class-level cache for the kernels built by _get_cupy_fused_kernels();
    # shared by all instances and filled on the first call that imports CuPy.
    _cupy_fused_kernels = None
|
|
145
|
+
|
|
146
|
+
def __init__(
    self,
    alpha: float = 1.0,
    fit_intercept: bool = True,
    max_iter: int = 1000,
    tol: float = 1e-4,
    stopping: str = "coef_delta",
    inference_method: str = "cpu_ols_inference",
    n_bootstrap: int = 200,
    bootstrap_random_state: Optional[int] = None,
    enable_simultaneous_inference: bool = False,
    simultaneous_method: str = "maxz_bootstrap",
    simultaneous_alpha: float = 0.05,
    simultaneous_n_bootstrap: int = 1000,
    simultaneous_random_state: Optional[int] = None,
    simultaneous_include_intercept: bool = False,
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    compute_inference: bool = True,
    solver: str = "fista",
    cpu_solver: str = "coordinate_descent",
    lipschitz_L: Optional[float] = None,
    admm_rho: float = 1.0,
    gpu_memory_cleanup: bool = False,
):
    """Record the (normalised) hyperparameters and clear all fitted state.

    ``stopping``, ``inference_method``, ``simultaneous_method``, ``solver``
    and ``cpu_solver`` are lower-cased. Two legacy inference names are
    translated:

    - ``'naive_ols'`` (CPU-side t-distribution inference) becomes
      ``'cpu_ols_inference'``;
    - ``'gpu_naive_ols'`` (GPU-side t-distribution inference with minimal
      residual/design transfers) becomes ``'gpu_ols_inference'``.

    ``device`` and ``n_jobs`` are passed to ``BaseEstimator.__init__``.
    """
    super().__init__(device=device, n_jobs=n_jobs)

    # Solver hyperparameters.
    self.alpha = alpha
    self.fit_intercept = fit_intercept
    self.max_iter = max_iter
    self.tol = tol
    self.stopping = stopping.lower()

    # Inference method, accepting the backwards-compatible legacy names.
    requested_method = inference_method.lower()
    legacy_names = {
        "naive_ols": "cpu_ols_inference",
        "gpu_naive_ols": "gpu_ols_inference",
    }
    self.inference_method = legacy_names.get(requested_method, requested_method)
    self.n_bootstrap = int(n_bootstrap)
    self.bootstrap_random_state = bootstrap_random_state

    # Simultaneous (joint-coverage) inference configuration.
    self.enable_simultaneous_inference = bool(enable_simultaneous_inference)
    self.simultaneous_method = str(simultaneous_method).lower()
    self.simultaneous_alpha = float(simultaneous_alpha)
    self.simultaneous_n_bootstrap = int(simultaneous_n_bootstrap)
    self.simultaneous_random_state = simultaneous_random_state
    self.simultaneous_include_intercept = bool(simultaneous_include_intercept)

    # Execution options.
    self.compute_inference = compute_inference
    self.solver = solver.lower()
    self.cpu_solver = cpu_solver.lower()
    self.lipschitz_L = lipschitz_L
    self.admm_rho = admm_rho
    self.gpu_memory_cleanup = bool(gpu_memory_cleanup)

    # Public fitted attributes.
    self.coef_ = None
    self.intercept_ = None
    self.n_iter_ = 0

    # Internal storage for inference (assigned in the same order as before).
    for slot in (
        "_X_design",
        "_y",
        "_resid",
        "_scale",
        "_nobs",
        "_df_resid",
        "_params",
        "_bse",
        "_tvalues",
        "_pvalues",
        "_conf_int",
        "_conf_int_simultaneous",
    ):
        setattr(self, slot, None)
    self._simultaneous_enabled = False
    for slot in (
        "_simultaneous_method",
        "_simultaneous_alpha",
        "_simultaneous_n_bootstrap",
        "_simultaneous_critical_value",
        "_simultaneous_target_mask",
        "_debiased_M_cpu",
    ):
        setattr(self, slot, None)
    self._inference_cautions = []
|
|
226
|
+
|
|
227
|
+
def fit(self, X, y, sample_weight=None):
    """Fit the Lasso model, then run the configured post-fit inference.

    Dispatch uses the resolved compute device and backend:

    - Torch backend: ``_fit_torch``;
    - ``Device.CUDA``: ``_fit_gpu`` (FISTA or ADMM, chosen by ``solver``);
    - otherwise: ``_fit_cpu`` (coordinate descent or FISTA, chosen by
      ``cpu_solver``).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Design matrix without an intercept column.
    y : array-like of shape (n_samples,)
        Response vector.
    sample_weight : array-like of shape (n_samples,), optional
        Observation weights, forwarded to the selected fit routine.

    Returns
    -------
    self
        The fitted estimator.

    Raises
    ------
    ValueError
        If ``inference_method='gpu_ols_inference'`` is used on CPU, or the
        simultaneous-inference settings are inconsistent.
    NotImplementedError
        If inference is requested on a CUDA/Torch device with a method other
        than ``'gpu_ols_inference'`` or ``'debiased'``.

    Warns
    -----
    UserWarning
        One warning per caveat returned by ``_build_inference_cautions``.
    """
    self._validate_simultaneous_config()
    self._reset_simultaneous_outputs()
    device = self._get_compute_device()

    # Get backend - support explicit torch backend selection
    backend = self._get_backend(backend="auto")
    backend_name = backend.name
    on_gpu = device in (Device.CUDA, Device.TORCH)
    gpu_side_methods = ("gpu_ols_inference", "debiased")

    if device == Device.CPU and self.inference_method == "gpu_ols_inference":
        raise ValueError(
            "inference_method='gpu_ols_inference' requires device='cuda' or "
            "device='torch'. Use inference_method='cpu_ols_inference' on CPU."
        )
    if on_gpu and self.inference_method == "cpu_ols_inference":
        # NOTE(review): this overwrites the user-supplied hyperparameter, so
        # later inspection or refits see 'gpu_ols_inference'. The GPU fit
        # routines presumably read self.inference_method; confirm before
        # switching to a local "effective method" variable.
        self.inference_method = "gpu_ols_inference"

    # Fail fast, before any device->host copy of y: GPU devices only support
    # GPU-side inference methods (there is no CPU fallback).
    if (
        self.compute_inference
        and on_gpu
        and self.inference_method not in gpu_side_methods
    ):
        raise NotImplementedError(
            f"Lasso inference_method='{self.inference_method}' is not implemented "
            f"for device='{device.value}' without CPU fallback."
        )

    if device == Device.CPU:
        self._y = np.asarray(y)
    elif (not self.compute_inference) or self.inference_method in gpu_side_methods:
        # GPU path: inference (if any) stays on the device, so skip the host copy.
        self._y = None
    else:
        # Only reachable for devices other than CPU/CUDA/TORCH (the check above
        # already rejected this combination on CUDA/TORCH). y may already be a
        # CuPy array; use safe conversion.
        self._y = self._to_numpy(y)

    X_arr = self._to_array(X, backend=backend_name)
    y_arr = self._to_array(y, backend=backend_name)

    # Route to appropriate backend
    if backend_name == "torch":
        self._fit_torch(X_arr, y_arr, sample_weight)
    elif device == Device.CUDA:
        self._fit_gpu(X_arr, y_arr, sample_weight)
    else:
        self._fit_cpu(X_arr, y_arr, sample_weight)

    # Methods in this set are not recomputed here (presumably handled inside
    # the GPU/Torch fit routines — confirm).
    skip_post_fit = {"gpu_ols_inference"}
    if self.inference_method == "debiased" and (
        device == Device.CUDA or backend_name == "torch"
    ):
        skip_post_fit.add("debiased")
    if self.compute_inference and self.inference_method not in skip_post_fit:
        self._compute_inference()
    if self.enable_simultaneous_inference:
        self._compute_simultaneous_inference()

    self._inference_cautions = self._build_inference_cautions()
    for msg in self._inference_cautions:
        warnings.warn(msg, UserWarning, stacklevel=2)
    self._fitted = True
    return self
|
|
292
|
+
|
|
293
|
+
def _validate_simultaneous_config(self):
    """Reject inconsistent simultaneous-inference settings before fitting.

    No-op unless ``enable_simultaneous_inference`` is set. Otherwise the
    following are checked in order, and the first failure is reported:
    ``compute_inference`` must be True, ``inference_method`` must be
    ``'debiased'``, ``simultaneous_method`` must be ``'maxz_bootstrap'``,
    ``simultaneous_alpha`` must lie strictly inside (0, 1), and
    ``simultaneous_n_bootstrap`` must be positive.

    Raises
    ------
    ValueError
        Describing the first violated requirement.
    """
    if not self.enable_simultaneous_inference:
        return

    problem = None
    if not self.compute_inference:
        problem = "enable_simultaneous_inference=True requires compute_inference=True."
    elif self.inference_method != "debiased":
        problem = (
            "enable_simultaneous_inference=True currently requires "
            "inference_method='debiased'."
        )
    elif self.simultaneous_method != "maxz_bootstrap":
        problem = "simultaneous_method must be 'maxz_bootstrap'."
    elif not (0.0 < self.simultaneous_alpha < 1.0):
        problem = "simultaneous_alpha must be in (0, 1)."
    elif self.simultaneous_n_bootstrap <= 0:
        problem = "simultaneous_n_bootstrap must be a positive integer."

    if problem is not None:
        raise ValueError(problem)
|
|
313
|
+
|
|
314
|
+
def _reset_simultaneous_outputs(self):
    """Clear every simultaneous-inference result left over from a previous fit.

    Sets ``_simultaneous_enabled`` to False and the simultaneous intervals,
    method, alpha, bootstrap count, critical value, target mask and cached
    debiasing matrix ``_debiased_M_cpu`` to None.
    """
    cleared_state = (
        ("_conf_int_simultaneous", None),
        ("_simultaneous_enabled", False),
        ("_simultaneous_method", None),
        ("_simultaneous_alpha", None),
        ("_simultaneous_n_bootstrap", None),
        ("_simultaneous_critical_value", None),
        ("_simultaneous_target_mask", None),
        ("_debiased_M_cpu", None),
    )
    for attr_name, reset_value in cleared_state:
        setattr(self, attr_name, reset_value)
|
|
323
|
+
|
|
324
|
+
def _build_inference_cautions(self):
    """Collect user-facing caveats about the inference produced by ``fit``.

    Returns
    -------
    list of str
        Empty when ``compute_inference`` is False. For the OLS-style methods
        (``'cpu_ols_inference'`` / ``'gpu_ols_inference'``) it holds a note
        that the post-selection intervals are heuristic. For ``'debiased'``
        it holds a note that the per-coefficient intervals are marginal.
        When simultaneous intervals were produced, it also describes their
        joint-coverage target set and, if the intercept is included, how its
        critical value was calibrated.
    """
    cautions = []
    if not self.compute_inference:
        return cautions

    if self.inference_method in ("cpu_ols_inference", "gpu_ols_inference"):
        cautions.append(
            "Lasso OLS-style post-selection intervals are heuristic and do not "
            "provide valid selective-inference confidence coverage."
        )

    if self.inference_method == "debiased":
        if not self._simultaneous_enabled:
            cautions.append(
                "Debiased Lasso currently reports per-coefficient (marginal) confidence "
                "intervals only; joint/multiple-testing coverage is not guaranteed."
            )
        else:
            # The per-coefficient intervals are still marginal, but joint
            # coverage IS targeted by the separate simultaneous intervals, so
            # saying it "is not guaranteed" would contradict the next caution.
            cautions.append(
                "Debiased Lasso per-coefficient (marginal) confidence intervals do "
                "not have joint coverage; use the simultaneous maxz_bootstrap "
                "intervals for joint/multiple-testing coverage."
            )
            include_intercept = self.fit_intercept and self.simultaneous_include_intercept
            target_txt = (
                "including intercept" if include_intercept else "excluding intercept"
            )
            cautions.append(
                "Simultaneous inference enabled via maxz_bootstrap with joint coverage "
                f"target set {target_txt}."
            )
            if include_intercept:
                cautions.append(
                    "Intercept is included using the same max-|Z| critical value "
                    "calibrated on feature coefficients."
                )

    return cautions
|
|
357
|
+
|
|
358
|
+
@staticmethod
def _get_cupy_fused_kernels():
    """Return the class-wide dictionary of CuPy kernels for the Lasso GPU path.

    Fusing several elementwise operations into one kernel launch cuts GPU
    launch overhead, which is especially beneficial for small-to-medium
    problems (n < 2000, p < 100).

    The dictionary is built on the first call that can import CuPy and is
    memoised on ``Lasso._cupy_fused_kernels``; later calls return that same
    object.

    Returns
    -------
    dict or None
        Keys ``'soft_threshold'``, ``'fista_momentum'`` and
        ``'kkt_violation'`` (``cp.fuse`` functions), and
        ``'elementwise_kernel'`` and ``'abs_delta_kernel'`` (float64
        ``cp.ElementwiseKernel`` objects). None if CuPy is not available.
    """
    cached = Lasso._cupy_fused_kernels
    if cached is not None:
        return cached

    try:
        import cupy as cp
    except ImportError:
        return None

    # sign(x) * max(|x| - gamma, 0), expressed with boolean masks.
    @cp.fuse()
    def _soft_threshold_fused(x, gamma):
        """Fused soft thresholding operator."""
        abs_x = abs(x)
        return (x > 0) * (abs_x > gamma) * (abs_x - gamma) - (x < 0) * (abs_x > gamma) * (abs_x - gamma)

    # FISTA extrapolation step: coef + beta * (coef - coef_old).
    @cp.fuse()
    def _fista_momentum_fused(coef, coef_old, beta):
        """Fused FISTA momentum update."""
        return coef + beta * (coef - coef_old)

    # Per-coordinate KKT violation: max(|grad| - alpha, 0).
    @cp.fuse()
    def _kkt_violation_fused(grad, alpha):
        """Fused KKT violation computation."""
        abs_grad = abs(grad)
        diff = abs_grad - alpha
        return (diff > 0) * diff

    kernels = {
        "soft_threshold": _soft_threshold_fused,
        "fista_momentum": _fista_momentum_fused,
        "kkt_violation": _kkt_violation_fused,
        # Hand-written float64 soft-thresholding kernel.
        "elementwise_kernel": cp.ElementwiseKernel(
            "float64 x, float64 gamma",
            "float64 y",
            '''
double abs_x = abs(x);
if (abs_x > gamma) {
y = (x > 0 ? 1.0 : -1.0) * (abs_x - gamma);
} else {
y = 0.0;
}
''',
            "lasso_soft_threshold",
        ),
        # Hand-written float64 |a - b| kernel (convergence checks).
        "abs_delta_kernel": cp.ElementwiseKernel(
            "float64 a, float64 b",
            "float64 y",
            '''
double diff = a - b;
y = (diff > 0 ? diff : -diff);
''',
            "lasso_abs_delta",
        ),
    }
    Lasso._cupy_fused_kernels = kernels
    return kernels
|
|
438
|
+
|
|
439
|
+
def _soft_threshold(self, x, gamma):
    """Apply the soft-thresholding operator ``sign(x) * max(|x| - gamma, 0)``.

    Works elementwise on NumPy arrays and also on scalars; the CPU
    coordinate-descent loop passes a scalar per coordinate.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - gamma, 0)
    direction = np.sign(x)
    return direction * shrunk_magnitude
|
|
442
|
+
|
|
443
|
+
def _fit_cpu(self, X, y, sample_weight=None):
    """Fit on CPU with coordinate descent (default) or FISTA.

    Solves ``min (1/(2n)) * sum_i s_i * (y_i - b - x_i @ w)^2 + alpha * ||w||_1``
    where ``s_i`` are the sample weights (all ones when ``sample_weight`` is
    None) and ``b`` is the intercept (0 when ``fit_intercept`` is False).

    With an intercept, ``X`` and ``y`` are centred on their (weighted) means
    *before* the sqrt-weight scaling. ``intercept_ = y_mean - X_mean @ coef_``
    is then the weighted intercept on the original data scale. Centring the
    already-scaled data instead gives a wrong intercept; for example,
    uniform weights of 4 doubled it.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
    y : array-like of shape (n_samples,) or (n_samples, 1)
    sample_weight : array-like of shape (n_samples,), optional
        Non-negative observation weights.

    Sets ``coef_``, ``intercept_``, ``n_iter_``, ``_params``, ``_nobs`` and
    ``_df_resid``. When ``compute_inference`` is True it also sets
    ``_X_design``, ``_resid`` and ``_scale``. With weights, the design
    (including the intercept column) and the residuals are both multiplied
    by ``sqrt(sample_weight)``.

    Raises
    ------
    ValueError
        If ``y`` is not a single response, or ``sample_weight`` has the wrong
        length or negative entries.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] == 1:
        # A column vector would otherwise broadcast to (n, n) against 1-D arrays.
        y = y.ravel()
    if y.ndim != 1:
        raise ValueError(
            "Lasso expects a single response: y must be 1-D or of shape (n_samples, 1)."
        )

    n_samples, n_features = X.shape
    self._nobs = n_samples

    sqrt_sw = None
    if sample_weight is not None:
        sw = np.asarray(sample_weight).ravel()
        if sw.shape[0] != n_samples:
            raise ValueError(
                f"sample_weight has {sw.shape[0]} entries; expected {n_samples}."
            )
        if np.any(sw < 0):
            raise ValueError("sample_weight must be non-negative.")
        sqrt_sw = np.sqrt(sw)

    if self.fit_intercept:
        if sqrt_sw is None:
            X_mean = np.mean(X, axis=0)
            y_mean = np.mean(y)
        else:
            # Weighted means on the original scale (not means of scaled data).
            X_mean = np.average(X, axis=0, weights=sw)
            y_mean = np.average(y, weights=sw)
        X_work = X - X_mean
        y_work = y - y_mean
    else:
        X_mean = None
        y_mean = 0.0
        X_work = X
        y_work = y

    if sqrt_sw is not None:
        X_work = X_work * sqrt_sw[:, np.newaxis]
        y_work = y_work * sqrt_sw

    Xty = X_work.T @ y_work
    XtX = X_work.T @ X_work

    def _converged(w, w_prev):
        """Apply the configured stopping rule to the iterate pair (w, w_prev)."""
        if self.stopping == "kkt":
            # grad_sse = (XtX @ w - Xty) / n. NOTE: this only tests
            # |grad_j| <= alpha (the optimality condition for zero coefficients),
            # the same measure as the CuPy 'kkt_violation' kernel.
            grad_sse = (XtX @ w - Xty) / n_samples
            violation = np.max(np.maximum(np.abs(grad_sse) - self.alpha, 0.0))
            return violation < self.tol
        # Legacy stopping: L1 norm of the coefficient change.
        return np.sum(np.abs(w - w_prev)) < self.tol

    coef = np.zeros(n_features)

    if self.cpu_solver in ("fista",):
        # Proximal gradient / FISTA for L1-regularized least squares.
        if self.lipschitz_L is not None:
            L = float(self.lipschitz_L)
        else:
            try:
                L = float(np.linalg.eigvalsh(XtX)[-1] / n_samples)
            except Exception:
                # e.g. eigvalsh failure, or no features (empty eigenvalue list):
                # fall back to the Frobenius-norm upper bound.
                L = float(np.sum(X_work**2) / n_samples)

        if L <= 0:
            coef = np.zeros(n_features)
            self.n_iter_ = 0
        else:
            step = 1.0 / L
            thresh = self.alpha * step
            y_k = coef.copy()
            t_k = 1.0
            for iteration in range(self.max_iter):
                coef_old = coef.copy()
                grad = (XtX @ y_k - Xty) / n_samples
                coef = self._soft_threshold(y_k - step * grad, thresh)
                # Nesterov momentum update.
                t_new = (1.0 + np.sqrt(1.0 + 4.0 * (t_k**2))) / 2.0
                beta = (t_k - 1.0) / t_new
                y_k = coef + beta * (coef - coef_old)
                t_k = t_new
                if _converged(coef, coef_old):
                    self.n_iter_ = iteration + 1
                    break
            else:
                self.n_iter_ = self.max_iter
    else:
        # Covariance-update coordinate descent.
        X_sq_norms = np.diag(XtX)
        for iteration in range(self.max_iter):
            coef_old = coef.copy()
            for j in range(n_features):
                # Partial residual correlation excluding feature j's own contribution.
                rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
                if X_sq_norms[j] > 1e-10:
                    coef[j] = self._soft_threshold(rho_j, self.alpha * n_samples) / X_sq_norms[j]
                else:
                    coef[j] = 0.0
            if _converged(coef, coef_old):
                self.n_iter_ = iteration + 1
                break
        else:
            self.n_iter_ = self.max_iter

    self.coef_ = coef
    if self.fit_intercept:
        self.intercept_ = float(y_mean - X_mean @ coef)
        self._params = np.concatenate([[self.intercept_], self.coef_])
    else:
        self.intercept_ = 0.0
        self._params = coef.copy()
    self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))

    if self.compute_inference:
        if self.fit_intercept:
            X_design = np.column_stack([np.ones(n_samples, dtype=X.dtype), X])
        else:
            X_design = X.copy()
        # Residuals on the original scale, from the unscaled design and response.
        resid = y - X_design @ self._params
        if sqrt_sw is not None:
            # Whiten design and residuals together so OLS-style formulas see a
            # consistent weighted-least-squares problem.
            X_design = X_design * sqrt_sw[:, np.newaxis]
            resid = resid * sqrt_sw
        self._X_design = X_design
        self._resid = resid
        if self._df_resid > 0:
            self._scale = np.sum(self._resid ** 2) / self._df_resid
        else:
            self._scale = np.nan
    else:
        self._X_design = None
        self._resid = None
        self._scale = np.nan
|
|
595
|
+
|
|
596
|
+
def _soft_threshold_cupy(self, x, gamma):
    """Apply the elementwise soft-thresholding operator to a CuPy array.

    Computes ``sign(x) * max(|x| - gamma, 0)``. When
    ``self._get_cupy_fused_kernels()`` returns a kernel table, its
    ``'elementwise_kernel'`` entry is called instead of composing the
    operator from individual CuPy ufuncs.
    """
    import cupy as cp

    kernels = self._get_cupy_fused_kernels()
    if kernels is None:
        # No fused kernels available: build the operator from basic ufuncs.
        shrunk = cp.maximum(cp.abs(x) - gamma, 0)
        return cp.sign(x) * shrunk
    return kernels['elementwise_kernel'](x, gamma)
|
|
612
|
+
|
|
613
|
+
def _cleanup_cuda_memory(self):
|
|
614
|
+
"""
|
|
615
|
+
Best-effort CUDA memory pool cleanup.
|
|
616
|
+
|
|
617
|
+
CuPy caches freed blocks in its memory pool for speed. Enable
|
|
618
|
+
`gpu_memory_cleanup=True` to return cached blocks after fit when
|
|
619
|
+
VRAM pressure is more important than repeated-fit throughput.
|
|
620
|
+
"""
|
|
621
|
+
if not self.gpu_memory_cleanup:
|
|
622
|
+
return
|
|
623
|
+
try:
|
|
624
|
+
import cupy as cp
|
|
625
|
+
cp.get_default_memory_pool().free_all_blocks()
|
|
626
|
+
cp.get_default_pinned_memory_pool().free_all_blocks()
|
|
627
|
+
except Exception:
|
|
628
|
+
pass
|
|
629
|
+
|
|
630
|
+
def _fit_gpu(self, X, y, sample_weight=None):
    """Fit using GPU solver.

    Dispatches to ``self._fit_gpu_admm`` when ``self.solver == "admm"``.
    Otherwise runs FISTA with CuPy on the Gram system
    ``XtX = Xc.T @ Xc`` / ``Xty = Xc.T @ yc`` (``Xc``/``yc`` centered when
    ``fit_intercept``), using gradient ``(XtX @ w - Xty) / n_samples`` and
    L1 threshold ``self.alpha``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Design matrix without an intercept column; converted with
        ``cp.asarray``.
    y : array-like
        Response; converted with ``cp.asarray`` and flattened.
    sample_weight : array-like, optional
        Per-sample weights. Rows of ``X`` and ``y`` are multiplied by
        ``sqrt(sample_weight)`` before fitting.

    Returns
    -------
    The return value of ``self._fit_gpu_admm`` for the ADMM solver;
    ``None`` for FISTA (results are stored on ``self``: ``coef_``,
    ``intercept_``, ``_params``, ``n_iter_``, ``_nobs``, ``_df_resid``,
    ``_scale``, ``_resid``, ``_X_design``; ``_rsquared_gpu`` except in the
    default inference branch; and ``_bse``/``_tvalues``/``_pvalues``/
    ``_conf_int`` for ``inference_method == "gpu_ols_inference"``).

    Raises
    ------
    ValueError
        If ``self.solver`` is not ``'fista'`` or ``'admm'``.
    """
    import cupy as cp
    # NOTE(review): compute_r2_gpu is imported but not referenced in this method.
    from statgpu.backends._gpu_inference_cupy import compute_r2_gpu

    if self.solver not in ("fista", "admm"):
        raise ValueError("solver must be one of: 'fista', 'admm'")

    if self.solver == "admm":
        return self._fit_gpu_admm(X, y, sample_weight=sample_weight)

    # Default: FISTA

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # Ensure CuPy arrays
    X = cp.asarray(X)
    y = cp.asarray(y)

    if sample_weight is not None:
        # Weighted least squares via row rescaling by sqrt(weight).
        sample_weight = cp.asarray(sample_weight)
        sqrt_sw = cp.sqrt(sample_weight)
        X = X * sqrt_sw[:, cp.newaxis]
        y = y * sqrt_sw

    # Ensure vector y on GPU
    y = y.reshape(-1)

    # Center X/y when fitting intercept to match sklearn Lasso convention.
    if self.fit_intercept:
        X_mean = cp.mean(X, axis=0)
        y_mean = cp.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_mean = cp.array(0.0, dtype=X.dtype)
        y_centered = y

    # Precompute XtX / Xty for FISTA gradient: grad(w) = (XtX @ w - Xty) / n
    XtX = X_centered.T @ X_centered
    Xty = X_centered.T @ y_centered

    # Lipschitz constant L for grad(w): L = lambda_max(XtX) / n
    # If user provides lipschitz_L, trust it (should be safe for convergence).
    if self.lipschitz_L is not None:
        L = cp.array(float(self.lipschitz_L), dtype=X.dtype)
    else:
        # trace(XtX) / n (the squared Frobenius norm of X_centered over n)
        # upper-bounds lambda_max(XtX) / n; used only if the eigen-solve fails.
        L_frob = cp.sum(X_centered ** 2) / n_samples
        try:
            eigvals = cp.linalg.eigvalsh(XtX)
            L = eigvals[-1] / n_samples
        except Exception:
            L = L_frob

    if L <= 0:
        # Degenerate case: return all-zero coefficients
        coef = cp.zeros(n_features, dtype=X.dtype)
        self.n_iter_ = 0
    else:
        step = 1.0 / L
        thresh = self.alpha * step

        # FISTA variables
        coef = cp.zeros(n_features, dtype=X.dtype)  # w_k
        y_k = coef.copy()  # y_k
        t_k = cp.array(1.0, dtype=X.dtype)

        # Get fused kernels for optimized FISTA iterations
        fused = self._get_cupy_fused_kernels()

        for iteration in range(self.max_iter):
            coef_old = coef

            # Gradient at y_k: (1/n) XtX @ y_k - (1/n) Xty
            grad = (XtX @ y_k - Xty) / n_samples

            # Prox step for L1
            coef = self._soft_threshold_cupy(y_k - step * grad, thresh)

            # Momentum update (use fused kernel when available)
            t_new = (1 + cp.sqrt(1 + 4 * (t_k ** 2))) / 2
            beta = (t_k - 1) / t_new
            if fused is not None:
                y_k = fused['fista_momentum'](coef, coef_old, beta)
            else:
                y_k = coef + beta * (coef - coef_old)
            t_k = t_new

            # Convergence test. Comparing a CuPy 0-d array in a Python `if`
            # forces a device-to-host synchronization every iteration.
            if self.stopping == "kkt":
                grad_sse = (XtX @ coef - Xty) / n_samples
                # NOTE(review): the fallback only checks |grad_j| <= alpha; it
                # does not verify grad_j == -alpha * sign(coef_j) on nonzero
                # coefficients, so it is a necessary (not sufficient)
                # optimality test — confirm this is intended.
                # Use fused KKT violation check when available
                if fused is not None:
                    violation = cp.max(fused['kkt_violation'](grad_sse, self.alpha))
                else:
                    violation = cp.max(cp.maximum(cp.abs(grad_sse) - self.alpha, 0.0))
                if violation < self.tol:
                    self.n_iter_ = iteration + 1
                    break
            else:
                # Legacy stopping: coefficient delta (fast but not guaranteed objective optimality)
                # Use fused delta kernel when available
                if fused is not None and 'abs_delta_kernel' in fused:
                    delta = cp.sum(fused['abs_delta_kernel'](coef, coef_old))
                else:
                    delta = cp.sum(cp.abs(coef - coef_old))
                if delta < self.tol:
                    self.n_iter_ = iteration + 1
                    break
        else:
            # for-else: reached only when the loop ran out without `break`.
            self.n_iter_ = self.max_iter

    # Build full coefficients and (optionally) residuals for inference/R^2
    if self.fit_intercept:
        # Intercept recovered from the centering: mean(y) - mean(X) @ coef.
        intercept_gpu = y_mean - X_mean @ coef
        coef_full = cp.concatenate([intercept_gpu.reshape(1), coef])
    else:
        coef_full = coef

    # Always transfer coefficients; remaining transfers depend on compute_inference.
    coef_full_np = coef_full.get()

    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
        self._params = coef_full_np
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np
        self._params = coef_full_np

    # Residual degrees of freedom count every feature (not just the active set).
    df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
    self._df_resid = df_resid

    # Inference/diagnostics require residuals and design matrix.
    if self.compute_inference:
        # Only build the design matrix when we need residuals/inference.
        if self.fit_intercept:
            X_design = cp.concatenate(
                [cp.ones((n_samples, 1), dtype=X.dtype), X], axis=1
            )
        else:
            X_design = X

        y_pred = X_design @ coef_full
        resid = y - y_pred

        if df_resid > 0:
            scale = cp.sum(resid ** 2) / df_resid
            self._scale = float(scale.get()) if not cp.isnan(scale) else np.nan
        else:
            # No residual degrees of freedom: downstream standard errors are NaN.
            self._scale = np.nan
            scale = cp.nan

        if self.inference_method == "gpu_ols_inference":
            # Compute inference fully on GPU, then transfer only small vectors.
            # Note: rebinds XtX to the Gram of the design (with intercept column).
            XtX = X_design.T @ X_design
            try:
                XtX_inv = cp.linalg.inv(XtX)
            except Exception:
                XtX_inv = cp.linalg.pinv(XtX)

            bse_gpu = cp.sqrt(scale * cp.diag(XtX_inv))

            # Inference vectors on GPU to avoid scipy/cpu cdf/ppf.
            params_gpu = coef_full  # includes intercept when fit_intercept=True
            tvalues_gpu = params_gpu / (bse_gpu + 1e-30)
            # NOTE(review): `t` is not defined or imported in this method;
            # presumably a module-level Student-t distribution object with
            # GPU-aware sf/ppf — confirm.
            # Two-tailed p-values from the Student-t survival function should
            # already lie in [0, 1]. We still clamp at 1.0 as a defensive
            # safeguard against tiny floating-point overshoots on GPU/backends.
            pvalues_gpu = cp.minimum(1.0, 2.0 * t.sf(cp.abs(tvalues_gpu), df=df_resid))

            alpha = 0.05  # two-tailed for 95% CI
            t_crit_gpu = t.ppf(1.0 - alpha / 2.0, df=df_resid)
            margin_gpu = t_crit_gpu * bse_gpu
            conf_int_gpu = cp.stack([params_gpu - margin_gpu, params_gpu + margin_gpu], axis=1)

            # Transfer only the small inference vectors back to CPU.
            self._bse = cp.asnumpy(bse_gpu)
            self._tvalues = cp.asnumpy(tvalues_gpu)
            self._pvalues = cp.asnumpy(pvalues_gpu)
            self._conf_int = cp.asnumpy(conf_int_gpu)

            # R^2 / keep diagnostics consistent without transferring residuals.
            y_mean_gpu = cp.mean(y)
            ss_tot = cp.sum((y - y_mean_gpu) ** 2)
            ss_res = cp.sum(resid ** 2)
            self._rsquared_gpu = float(cp.asnumpy(1 - ss_res / ss_tot)) if ss_tot > 0 else 0.0

            self._resid = None
            self._X_design = None
        elif self.inference_method == "debiased":
            # Debiased inference receives the (possibly weight-scaled) raw X/y
            # and the feature coefficients without the intercept.
            self._compute_inference_debiased_gpu(X, y, coef)

            y_mean_gpu = cp.mean(y)
            ss_tot = cp.sum((y - y_mean_gpu) ** 2)
            ss_res = cp.sum(resid ** 2)
            self._rsquared_gpu = float(cp.asnumpy(1 - ss_res / ss_tot)) if ss_tot > 0 else 0.0

            self._resid = None
            self._X_design = None
        else:
            # Default: transfer residuals and design to CPU.
            self._resid = resid.get()
            self._X_design = X_design.get()

    else:
        # Strict GPU mode: avoid large residual/host design transfers.
        self._scale = np.nan
        self._resid = None
        self._X_design = None
        # R^2 is optional; keep behavior as None when no residuals are available.
        self._rsquared_gpu = None

    # Drop large temporaries early (before optional pool cleanup).
    # Some names are unbound on certain paths (e.g. X_design/resid/y_pred when
    # compute_inference is False), hence each `del` is individually guarded.
    try:
        del X_design
    except Exception:
        pass
    try:
        del resid
    except Exception:
        pass
    try:
        del XtX
    except Exception:
        pass
    try:
        del Xty
    except Exception:
        pass
    try:
        del X_centered
    except Exception:
        pass
    try:
        del y_centered
    except Exception:
        pass
    try:
        del y_pred
    except Exception:
        pass
    try:
        del coef_full
    except Exception:
        pass
    self._cleanup_cuda_memory()
|
|
880
|
+
|
|
881
|
+
def _cleanup_torch_memory(self):
|
|
882
|
+
"""Best-effort Torch CUDA memory cleanup."""
|
|
883
|
+
if not self.gpu_memory_cleanup:
|
|
884
|
+
return
|
|
885
|
+
try:
|
|
886
|
+
import torch
|
|
887
|
+
if torch.cuda.is_available():
|
|
888
|
+
torch.cuda.empty_cache()
|
|
889
|
+
torch.cuda.synchronize()
|
|
890
|
+
except Exception:
|
|
891
|
+
pass
|
|
892
|
+
|
|
893
|
+
def _matrix_fingerprint_torch(self, X: "torch.Tensor") -> str:
|
|
894
|
+
"""Generate a fingerprint key for caching debiased M matrix (Torch version)."""
|
|
895
|
+
import torch
|
|
896
|
+
n, p = X.shape
|
|
897
|
+
r = min(24, n)
|
|
898
|
+
c = min(24, p)
|
|
899
|
+
sample = X[:r, :c].contiguous()
|
|
900
|
+
h = hashlib.sha1()
|
|
901
|
+
h.update(str((n, p, str(X.dtype))).encode("utf-8"))
|
|
902
|
+
h.update(sample.cpu().numpy().tobytes())
|
|
903
|
+
return h.hexdigest()
|
|
904
|
+
|
|
905
|
+
def _solve_lasso_path_torch_fista_multi_fold_from_gram(
|
|
906
|
+
self,
|
|
907
|
+
XtX_batch,
|
|
908
|
+
Xty_batch,
|
|
909
|
+
*,
|
|
910
|
+
n_samples_vec,
|
|
911
|
+
alphas_desc,
|
|
912
|
+
max_iter,
|
|
913
|
+
tol,
|
|
914
|
+
stopping,
|
|
915
|
+
lipschitz_L=None,
|
|
916
|
+
check_every=8,
|
|
917
|
+
):
|
|
918
|
+
"""Solve descending-alpha Lasso paths for all folds together on Torch GPU."""
|
|
919
|
+
import torch
|
|
920
|
+
|
|
921
|
+
n_folds = int(XtX_batch.shape[0])
|
|
922
|
+
n_features = int(XtX_batch.shape[1])
|
|
923
|
+
n_alphas = int(alphas_desc.shape[0])
|
|
924
|
+
dtype = XtX_batch.dtype
|
|
925
|
+
device = XtX_batch.device
|
|
926
|
+
|
|
927
|
+
coefs = torch.zeros((n_folds, n_features, n_alphas), dtype=dtype, device=device)
|
|
928
|
+
yk = coefs.clone()
|
|
929
|
+
tk = torch.ones((n_folds, n_alphas), dtype=dtype, device=device)
|
|
930
|
+
n_iters = torch.zeros((n_folds, n_alphas), dtype=torch.int32, device=device)
|
|
931
|
+
|
|
932
|
+
n_vec = torch.as_tensor(n_samples_vec, dtype=dtype, device=device).reshape(-1)
|
|
933
|
+
if n_vec.size != n_folds:
|
|
934
|
+
raise ValueError("n_samples_vec must have one entry per fold")
|
|
935
|
+
|
|
936
|
+
if lipschitz_L is not None:
|
|
937
|
+
L = torch.full((n_folds,), float(lipschitz_L), dtype=dtype, device=device)
|
|
938
|
+
else:
|
|
939
|
+
try:
|
|
940
|
+
eigvals = torch.linalg.eigvalsh(XtX_batch)
|
|
941
|
+
L = eigvals[:, -1] / n_vec
|
|
942
|
+
except Exception:
|
|
943
|
+
row_sum_bound = torch.max(torch.sum(torch.abs(XtX_batch), dim=2), dim=1)[0] / n_vec
|
|
944
|
+
L = torch.maximum(row_sum_bound, torch.tensor(1e-12, dtype=dtype, device=device))
|
|
945
|
+
|
|
946
|
+
step = 1.0 / L.reshape(n_folds, 1, 1)
|
|
947
|
+
alpha_gpu = torch.as_tensor(np.asarray(alphas_desc, dtype=np.float64), dtype=dtype, device=device).reshape(1, 1, n_alphas)
|
|
948
|
+
thresholds = alpha_gpu * step
|
|
949
|
+
|
|
950
|
+
Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
|
|
951
|
+
n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
|
|
952
|
+
stopping_name = str(stopping).lower()
|
|
953
|
+
check_every = max(1, int(check_every))
|
|
954
|
+
|
|
955
|
+
active_gpu = torch.ones((n_folds, n_alphas), dtype=torch.bool, device=device)
|
|
956
|
+
active_count = int(n_folds * n_alphas)
|
|
957
|
+
|
|
958
|
+
for iteration in range(int(max_iter)):
|
|
959
|
+
if active_count == 0:
|
|
960
|
+
break
|
|
961
|
+
|
|
962
|
+
active_expanded = active_gpu[:, None, :]
|
|
963
|
+
|
|
964
|
+
coef_old = coefs.clone()
|
|
965
|
+
grad = (torch.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded
|
|
966
|
+
coef_candidate = torch.sign(yk - step * grad) * torch.maximum(torch.abs(yk - step * grad) - thresholds, torch.tensor(0.0, dtype=dtype, device=device))
|
|
967
|
+
coefs = torch.where(active_expanded, coef_candidate, coefs)
|
|
968
|
+
|
|
969
|
+
t_old = tk
|
|
970
|
+
t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
|
|
971
|
+
beta = (t_old - 1.0) / t_new
|
|
972
|
+
y_candidate = coefs + beta[:, None, :] * (coefs - coef_old)
|
|
973
|
+
yk = torch.where(active_expanded, y_candidate, yk)
|
|
974
|
+
tk = torch.where(active_gpu, t_new, tk)
|
|
975
|
+
|
|
976
|
+
active_ratio = float(active_count) / float(max(1, n_folds * n_alphas))
|
|
977
|
+
check_every_eff = max(check_every, 1)
|
|
978
|
+
should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == int(max_iter))
|
|
979
|
+
if not should_check:
|
|
980
|
+
continue
|
|
981
|
+
|
|
982
|
+
if stopping_name == "kkt":
|
|
983
|
+
grad_sse = (torch.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
|
|
984
|
+
violation = torch.max(torch.maximum(torch.abs(grad_sse) - alpha_gpu, torch.tensor(0.0, dtype=dtype, device=device)), dim=1)[0]
|
|
985
|
+
converged_local_gpu = violation < float(tol)
|
|
986
|
+
else:
|
|
987
|
+
delta = torch.sum(torch.abs(coefs - coef_old), dim=1)
|
|
988
|
+
converged_local_gpu = delta < float(tol)
|
|
989
|
+
|
|
990
|
+
newly_done_gpu = active_gpu & converged_local_gpu
|
|
991
|
+
done_count = int(torch.count_nonzero(newly_done_gpu).item())
|
|
992
|
+
if done_count == 0:
|
|
993
|
+
continue
|
|
994
|
+
|
|
995
|
+
n_iters[newly_done_gpu] = int(iteration) + 1
|
|
996
|
+
yk = torch.where(newly_done_gpu[:, None, :], coefs, yk)
|
|
997
|
+
active_gpu = active_gpu & (~converged_local_gpu)
|
|
998
|
+
active_count -= done_count
|
|
999
|
+
|
|
1000
|
+
return coefs.transpose(1, 2), n_iters.cpu().numpy()
|
|
1001
|
+
|
|
1002
|
+
def _compute_inference_debiased_torch(self, X_torch, y_torch, coef_torch):
    """Torch GPU path for debiased Lasso inference.

    Builds an approximate inverse ``M`` of ``Sigma_hat = X'X / n`` from
    node-wise Lasso regressions (looked up in / stored to the module's
    debiased-M cache), forms the one-step debiased estimate
    ``theta = coef + M X' r / n`` and reports normal-theory standard errors,
    z-statistics, two-sided p-values and 95% confidence intervals.

    Parameters
    ----------
    X_torch : torch.Tensor, shape (n, p)
        Raw feature matrix on Torch GPU (no intercept column).
    y_torch : torch.Tensor, shape (n,)
        Response on Torch GPU.
    coef_torch : torch.Tensor, shape (p,)
        Lasso coefficients on Torch GPU (no intercept).

    Notes
    -----
    Sets ``_bse``, ``_tvalues``, ``_pvalues``, ``_conf_int``, ``_params`` and
    ``_debiased_M_cpu``. With ``fit_intercept`` the intercept row (based on
    ``self.intercept_``) is prepended to every output. Calls
    ``_compute_simultaneous_inference_torch`` when
    ``enable_simultaneous_inference`` is set.
    """
    import torch
    from statgpu.inference._distributions_backend import norm

    n, p = X_torch.shape
    dtype = torch.float64
    device = X_torch.device

    # Work in float64 throughout for numerically stable inference.
    if X_torch.dtype != dtype:
        X_torch = X_torch.to(dtype)
    if y_torch.dtype != dtype:
        y_torch = y_torch.to(dtype)
    if coef_torch.dtype != dtype:
        coef_torch = coef_torch.to(dtype)

    # Compute Sigma_hat = X'X / n
    Sigma_hat = X_torch.T @ X_torch / n

    # Lasso residuals; with an intercept, also remove the centered-model
    # offset mean(y) - mean(X) @ coef.
    resid_lasso = y_torch - X_torch @ coef_torch
    if self.fit_intercept:
        resid_lasso = resid_lasso - torch.mean(y_torch) + torch.mean(X_torch, dim=0) @ coef_torch

    # Noise variance sigma^2 with a degrees-of-freedom correction by the
    # support size s_hat (denominator floored at 1).
    s_hat = torch.sum(torch.abs(coef_torch) > 0).to(dtype)
    denom = torch.maximum(
        torch.tensor(1.0, dtype=dtype, device=device),
        torch.tensor(float(n), dtype=dtype, device=device) - s_hat,
    )
    sigma2 = torch.sum(resid_lasso ** 2) / denom

    # Node-wise Lasso penalty sqrt(2 log(max(p, 2)) / n).
    lam_nw = float(np.sqrt(2.0 * np.log(max(p, 2)) / n))
    alpha_nw = np.asarray([lam_nw], dtype=np.float64)
    tiny = 1e-30
    zero_t = torch.tensor(0.0, dtype=dtype, device=device)
    one_t = torch.tensor(1.0, dtype=dtype, device=device)

    # Cache key: shape, dtype, a small leading block of X, penalty, tolerance.
    X_sample = X_torch[: min(24, n), : min(24, p)].cpu().numpy()
    m_cache_key = _debiased_m_key_from_sample(
        n=n,
        p=p,
        dtype_name=str(dtype),
        sample_block=X_sample,
        lam_nw=lam_nw,
        tol=float(self.tol),
    )
    M_cached = _debiased_m_cache_get(m_cache_key)

    if M_cached is not None:
        M = torch.from_numpy(M_cached).to(dtype).to(device)
    else:
        M = torch.zeros((p, p), dtype=dtype, device=device)
        XtX_full = X_torch.T @ X_torch
        Sigma_diag = torch.diag(Sigma_hat)

        # Choose how many node-wise problems to batch from free GPU memory.
        try:
            if torch.cuda.is_available():
                free_mem = torch.cuda.mem_get_info(device)[0]
                bytes_per_fold = max(8, (p - 1) * (p - 1) * 8 * 2)
                chunk_size = int(max(4, min(64, free_mem // max(bytes_per_fold, 1))))
            else:
                chunk_size = 16
        except Exception:
            chunk_size = 16
        chunk_size = max(4, min(int(p), chunk_size))

        # Offsets 0..p-2; shifting those >= j by one yields "all columns but j".
        # int64 indices are accepted by every torch version for indexing.
        base = torch.arange(p - 1, dtype=torch.long, device=device).reshape(1, -1)

        for j0 in range(0, p, chunk_size):
            j1 = min(p, j0 + chunk_size)
            bsz = j1 - j0
            j_batch = torch.arange(j0, j1, dtype=torch.long, device=device)
            cols_batch = base + (base >= j_batch.reshape(-1, 1))

            if p > 1:
                # Gather batched Gram/Xty blocks for regressing X_j on X_{-j}.
                XtX_batch = XtX_full[
                    cols_batch[:, :, None],
                    cols_batch[:, None, :],
                ]
                Xty_batch = XtX_full[cols_batch, j_batch.reshape(-1, 1)].reshape(bsz, p - 1)

                # Solve node-wise Lasso problems
                coefs_batch_desc, _ = self._solve_lasso_path_torch_fista_multi_fold_from_gram(
                    XtX_batch,
                    Xty_batch,
                    n_samples_vec=np.full((bsz,), float(n), dtype=np.float64),
                    alphas_desc=alpha_nw,
                    max_iter=500,
                    tol=1e-5,
                    stopping="coef_delta",
                    lipschitz_L=None,
                    check_every=8,
                )
                # The solver returns a Torch tensor; keep it on-device. Routing
                # it through np.asarray fails for CUDA tensors.
                gamma_batch = coefs_batch_desc[:, 0, :].to(device=device, dtype=dtype)
                del XtX_batch, Xty_batch, coefs_batch_desc
            else:
                # A single feature has no other columns to regress on.
                gamma_batch = torch.zeros((bsz, 0), dtype=dtype, device=device)

            # C_j = Sigma_jj - Sigma_{j,-j} gamma_j
            sigma_j_cols = Sigma_hat[j_batch[:, None], cols_batch]
            C_batch = Sigma_diag[j_batch] - torch.sum(sigma_j_cols * gamma_batch, dim=1)

            # Near-zero C_j: diagonal falls back to 1, off-diagonals to 0.
            small_c = torch.abs(C_batch) < tiny
            inv_c = torch.where(small_c, zero_t, one_t / C_batch)
            M[j_batch, j_batch] = torch.where(small_c, one_t, inv_c)
            M[j_batch[:, None], cols_batch] = -gamma_batch * inv_c.reshape(-1, 1)

            del gamma_batch, sigma_j_cols

        _debiased_m_cache_put(m_cache_key, M.cpu().numpy())

    # Full residual using the fitted intercept when present.
    if self.fit_intercept:
        y_pred = X_torch @ coef_torch + torch.tensor(self.intercept_, dtype=dtype, device=device)
    else:
        y_pred = X_torch @ coef_torch
    resid_full = y_torch - y_pred

    # Debiased estimate: theta_db = coef + M @ X' @ resid / n.
    # Evaluate X' @ resid first to avoid forming the (p, n) product M @ X'.
    theta_db = coef_torch + (M @ (X_torch.T @ resid_full)) / n

    # Variance estimation: V = M @ Sigma_hat @ M'
    V = M @ Sigma_hat @ M.T
    se = torch.sqrt(sigma2 * torch.diag(V) / n)

    # z-statistics and two-sided p-values (clamped at 1).
    z_stats = theta_db / (se + 1e-30)
    pvalues = torch.minimum(one_t, 2.0 * norm.sf(torch.abs(z_stats)))

    # 95% confidence intervals
    alpha_ci = 0.05
    z_crit = norm.ppf(1.0 - alpha_ci / 2.0)
    ci = torch.stack([theta_db - z_crit * se, theta_db + z_crit * se], dim=1)

    # Intercept row: OLS-style standard error from (X_full' X_full)^{-1}[0, 0].
    if self.fit_intercept:
        X_full = torch.cat([torch.ones((n, 1), dtype=dtype, device=device), X_torch], dim=1)
        XtX_design = X_full.T @ X_full
        try:
            XtX_inv = torch.linalg.inv(XtX_design)
        except Exception:
            XtX_inv = torch.linalg.pinv(XtX_design)
        se_intercept = torch.sqrt(sigma2 * XtX_inv[0, 0])
        intercept_torch = torch.tensor(self.intercept_, dtype=dtype, device=device)
        z_intercept = intercept_torch / (se_intercept + 1e-30)
        p_intercept = torch.minimum(one_t, 2.0 * norm.sf(torch.abs(z_intercept).reshape(1)))
        ci_intercept = torch.stack([
            intercept_torch - z_crit * se_intercept,
            intercept_torch + z_crit * se_intercept,
        ]).reshape(1, 2)

        bse_torch = torch.cat([se_intercept.reshape(1), se])
        tvalues_torch = torch.cat([z_intercept.reshape(1), z_stats])
        pvalues_torch = torch.cat([p_intercept.reshape(1), pvalues])
        conf_int_torch = torch.cat([ci_intercept, ci], dim=0)
        params_torch = torch.cat([intercept_torch.reshape(1), theta_db])
    else:
        bse_torch = se
        tvalues_torch = z_stats
        pvalues_torch = pvalues
        conf_int_torch = ci
        params_torch = theta_db

    # Transfer to CPU
    self._bse = bse_torch.cpu().numpy()
    self._tvalues = tvalues_torch.cpu().numpy()
    self._pvalues = pvalues_torch.cpu().numpy()
    self._conf_int = conf_int_torch.cpu().numpy()
    self._params = params_torch.cpu().numpy()

    # Store M matrix for simultaneous inference
    self._debiased_M_cpu = M.cpu().numpy()

    # Simultaneous inference (max-|Z| bootstrap)
    if self.enable_simultaneous_inference:
        self._compute_simultaneous_inference_torch(
            params_torch, bse_torch, se, M, X_torch, resid_full, n
        )
|
|
1196
|
+
|
|
1197
|
+
def _compute_simultaneous_inference_torch(
|
|
1198
|
+
self, params_torch, bse_torch, se_feat_torch, M_torch, X_torch, resid_full_torch, n
|
|
1199
|
+
):
|
|
1200
|
+
"""Torch GPU implementation of simultaneous inference via max-|Z| bootstrap."""
|
|
1201
|
+
import torch
|
|
1202
|
+
|
|
1203
|
+
# Get target indices
|
|
1204
|
+
param_target_idx_np = self._get_simultaneous_target_indices(int(params_torch.shape[0]))
|
|
1205
|
+
param_target_idx_torch = torch.as_tensor(param_target_idx_np, dtype=torch.int32, device=params_torch.device)
|
|
1206
|
+
|
|
1207
|
+
if param_target_idx_torch.size == 0:
|
|
1208
|
+
raise RuntimeError("No coefficients selected for simultaneous inference target set.")
|
|
1209
|
+
|
|
1210
|
+
feature_offset = 1 if self.fit_intercept else 0
|
|
1211
|
+
feature_target_torch = param_target_idx_torch - feature_offset
|
|
1212
|
+
feature_target_torch = feature_target_torch[feature_target_torch >= 0]
|
|
1213
|
+
|
|
1214
|
+
if feature_target_torch.size == 0:
|
|
1215
|
+
raise RuntimeError("No feature coefficients selected for simultaneous inference target set.")
|
|
1216
|
+
|
|
1217
|
+
se_target_torch = torch.index_select(se_feat_torch, 0, feature_target_torch)
|
|
1218
|
+
M_target = torch.index_select(M_torch, 0, feature_target_torch)
|
|
1219
|
+
|
|
1220
|
+
B = int(self.simultaneous_n_bootstrap)
|
|
1221
|
+
if self.simultaneous_random_state is not None:
|
|
1222
|
+
torch.manual_seed(self.simultaneous_random_state)
|
|
1223
|
+
|
|
1224
|
+
# Bootstrap in chunks to manage memory
|
|
1225
|
+
try:
|
|
1226
|
+
# Try one-shot computation
|
|
1227
|
+
xi = torch.randn((B, n), dtype=torch.float64, device=X_torch.device)
|
|
1228
|
+
weighted = xi * resid_full_torch.reshape(1, -1)
|
|
1229
|
+
score_target = (weighted @ X_torch) @ M_target.T / float(max(n, 1))
|
|
1230
|
+
z_star_target = score_target / (se_target_torch.reshape(1, -1) + 1e-30)
|
|
1231
|
+
max_stats_torch = torch.max(torch.abs(z_star_target), dim=1)[0]
|
|
1232
|
+
except Exception:
|
|
1233
|
+
# Fallback to chunked computation
|
|
1234
|
+
max_stats_torch = torch.empty((B,), dtype=torch.float64, device=X_torch.device)
|
|
1235
|
+
chunk = min(B, 64)
|
|
1236
|
+
filled = 0
|
|
1237
|
+
while filled < B:
|
|
1238
|
+
bsz = min(chunk, B - filled)
|
|
1239
|
+
xi = torch.randn((bsz, n), dtype=torch.float64, device=X_torch.device)
|
|
1240
|
+
weighted = xi * resid_full_torch.reshape(1, -1)
|
|
1241
|
+
score_target = (weighted @ X_torch) @ M_target.T / float(max(n, 1))
|
|
1242
|
+
z_star_target = score_target / (se_target_torch.reshape(1, -1) + 1e-30)
|
|
1243
|
+
max_stats_torch[filled : filled + bsz] = torch.max(torch.abs(z_star_target), dim=1)[0]
|
|
1244
|
+
filled += bsz
|
|
1245
|
+
|
|
1246
|
+
# Compute critical value
|
|
1247
|
+
critical_torch = torch.quantile(max_stats_torch, 1.0 - float(self.simultaneous_alpha))
|
|
1248
|
+
|
|
1249
|
+
# Build simultaneous confidence intervals
|
|
1250
|
+
conf_sim_torch = conf_int_torch.clone()
|
|
1251
|
+
lower_torch = torch.index_select(params_torch, 0, param_target_idx_torch) - critical_torch * torch.index_select(bse_torch, 0, param_target_idx_torch)
|
|
1252
|
+
upper_torch = torch.index_select(params_torch, 0, param_target_idx_torch) + critical_torch * torch.index_select(bse_torch, 0, param_target_idx_torch)
|
|
1253
|
+
conf_sim_torch[param_target_idx_torch, 0] = lower_torch
|
|
1254
|
+
conf_sim_torch[param_target_idx_torch, 1] = upper_torch
|
|
1255
|
+
|
|
1256
|
+
# Store results
|
|
1257
|
+
target_mask = np.zeros(int(params_torch.shape[0]), dtype=bool)
|
|
1258
|
+
target_mask[param_target_idx_np] = True
|
|
1259
|
+
self._conf_int_simultaneous = conf_sim_torch.cpu().numpy()
|
|
1260
|
+
self._simultaneous_enabled = True
|
|
1261
|
+
self._simultaneous_method = self.simultaneous_method
|
|
1262
|
+
self._simultaneous_alpha = float(self.simultaneous_alpha)
|
|
1263
|
+
self._simultaneous_n_bootstrap = B
|
|
1264
|
+
self._simultaneous_critical_value = float(critical_torch.cpu().numpy())
|
|
1265
|
+
self._simultaneous_target_mask = target_mask
|
|
1266
|
+
|
|
1267
|
+
def _soft_threshold_torch(self, x, gamma):
|
|
1268
|
+
"""Soft thresholding operator for Torch tensors."""
|
|
1269
|
+
import torch
|
|
1270
|
+
return torch.sign(x) * torch.maximum(torch.abs(x) - gamma, torch.tensor(0.0, dtype=x.dtype, device=x.device))
|
|
1271
|
+
|
|
1272
|
+
def _fit_torch(self, X, y, sample_weight=None):
    """Fit the Lasso on the Torch backend with the FISTA solver.

    Minimises ``(1/(2n)) * ||y_c - X_c w||^2 + alpha * ||w||_1`` with the
    accelerated proximal-gradient method (FISTA), working on the precomputed
    Gram quantities ``X_c^T X_c`` and ``X_c^T y_c``. Here ``_c`` denotes
    column-centred data when ``fit_intercept`` is set; the intercept is then
    recovered as ``mean(y) - mean(X) @ w``.

    Parameters
    ----------
    X : array-like or torch.Tensor, shape (n_samples, n_features)
        Design matrix. Non-tensor input is moved to the ``'cuda'`` device;
        tensor input stays on its own device. Cast to float64.
    y : array-like or torch.Tensor
        Response. Moved to ``X``'s device, cast to float64 and flattened to
        shape ``(n_samples,)``.
    sample_weight : array-like or torch.Tensor, optional
        Per-sample weights, applied by scaling rows of ``X`` and ``y`` by
        ``sqrt(sample_weight)``. Flattened and moved to ``X``'s device/dtype.

    Raises
    ------
    ValueError
        If ``self.solver`` is neither ``'fista'`` nor ``'admm'``.
    NotImplementedError
        If ``self.solver == 'admm'`` (no Torch ADMM yet), or if
        ``compute_inference`` is set with an unsupported ``inference_method``.
    """
    import torch

    if self.solver not in ("fista", "admm"):
        raise ValueError("Torch backend currently only supports 'fista' solver")
    # 'admm' is a recognised solver name, but only FISTA exists for Torch.
    if self.solver == "admm":
        raise NotImplementedError("ADMM solver not yet implemented for Torch backend")

    # Host input goes to the default CUDA device; y and sample_weight follow
    # X's device so every operand lives on the same device.
    if not isinstance(X, torch.Tensor):
        X = torch.as_tensor(np.asarray(X), device="cuda")
    X = X.to(dtype=torch.float64)
    if not isinstance(y, torch.Tensor):
        y = torch.as_tensor(np.asarray(y))
    # Flatten *before* weighting: an (n, 1) response times an (n,) weight
    # vector would otherwise broadcast to an (n, n) matrix.
    y = y.to(device=X.device, dtype=torch.float64).reshape(-1)

    n_samples, n_features = X.shape
    self._nobs = n_samples

    if sample_weight is not None:
        if not isinstance(sample_weight, torch.Tensor):
            sample_weight = torch.as_tensor(np.asarray(sample_weight))
        sample_weight = sample_weight.to(device=X.device, dtype=X.dtype).reshape(-1)
        sqrt_sw = torch.sqrt(sample_weight)
        X = X * sqrt_sw[:, None]
        y = y * sqrt_sw

    # Center for intercept.
    if self.fit_intercept:
        X_mean = torch.mean(X, dim=0)
        y_mean = torch.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_centered = y

    # Precompute XtX / Xty for the FISTA gradient.
    XtX = X_centered.T @ X_centered
    Xty = X_centered.T @ y_centered

    # Lipschitz constant of the smooth part's gradient: lambda_max(XtX) / n.
    if self.lipschitz_L is not None:
        L = torch.tensor(float(self.lipschitz_L), dtype=X.dtype, device=X.device)
    else:
        try:
            L = torch.linalg.eigvalsh(XtX)[-1] / n_samples
        except RuntimeError:
            # trace(XtX) >= lambda_max(XtX) for a PSD matrix, so this is a
            # valid (if conservative) Lipschitz constant.
            L = torch.trace(XtX) / n_samples

    coef = torch.zeros(n_features, dtype=X.dtype, device=X.device)
    if L <= 0:
        # Non-positive L (e.g. a constant design or a user-supplied value):
        # no gradient step is possible, keep the zero vector.
        self.n_iter_ = 0
    else:
        step = 1.0 / L
        thresh = self.alpha * step
        y_k = coef.clone()
        t_k = torch.tensor(1.0, dtype=X.dtype, device=X.device)

        for iteration in range(self.max_iter):
            # `coef` is rebound to a fresh tensor below, so no copy is needed.
            coef_old = coef

            # Proximal-gradient step taken at the extrapolated point y_k.
            grad = (XtX @ y_k - Xty) / n_samples
            coef = self._soft_threshold_torch(y_k - step * grad, thresh)

            # Nesterov momentum update.
            t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_k ** 2))) / 2.0
            y_k = coef + ((t_k - 1.0) / t_new) * (coef - coef_old)
            t_k = t_new

            if self.stopping == "kkt":
                grad_sse = (XtX @ coef - Xty) / n_samples
                violation = torch.max(torch.clamp(torch.abs(grad_sse) - self.alpha, min=0.0))
                converged = violation < self.tol
            else:
                converged = torch.sum(torch.abs(coef - coef_old)) < self.tol
            if converged:
                self.n_iter_ = iteration + 1
                break
        else:
            self.n_iter_ = self.max_iter

    # Build full coefficients (intercept first when fitted).
    if self.fit_intercept:
        intercept_torch = y_mean - X_mean @ coef
        coef_full = torch.cat([intercept_torch.reshape(1), coef])
    else:
        coef_full = coef

    coef_full_np = coef_full.cpu().numpy()
    self._params = coef_full_np
    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np

    df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
    self._df_resid = df_resid

    X_design = resid = None
    if self.compute_inference:
        if self.fit_intercept:
            X_design = torch.cat(
                [torch.ones((n_samples, 1), dtype=X.dtype, device=X.device), X], dim=1
            )
        else:
            X_design = X
        resid = y - X_design @ coef_full

        if df_resid > 0:
            scale = torch.sum(resid ** 2) / df_resid
            self._scale = float(scale.cpu().numpy()) if not torch.isnan(scale) else np.nan
        else:
            self._scale = np.nan
            scale = torch.tensor(np.nan, dtype=X.dtype, device=X.device)

        if self.inference_method == "gpu_ols_inference":
            # OLS-style inference computed fully on device; only small vectors move to host.
            XtX_inf = X_design.T @ X_design
            try:
                XtX_inv = torch.linalg.inv(XtX_inf)
            except RuntimeError:
                XtX_inv = torch.linalg.pinv(XtX_inf)

            bse_gpu = torch.sqrt(scale * torch.diag(XtX_inv))
            tvalues_gpu = coef_full / (bse_gpu + 1e-30)

            from statgpu.inference._distributions_backend import get_distribution
            t_dist = get_distribution("t", backend="torch", device=str(X.device))
            pvalues_gpu = torch.minimum(
                torch.tensor(1.0, device=X.device),
                2.0 * t_dist.sf(torch.abs(tvalues_gpu), df=df_resid),
            )

            ci_alpha = 0.05
            t_crit_gpu = t_dist.ppf(1.0 - ci_alpha / 2.0, df=df_resid)
            margin_gpu = t_crit_gpu * bse_gpu
            conf_int_gpu = torch.stack([coef_full - margin_gpu, coef_full + margin_gpu], dim=1)

            self._bse = bse_gpu.cpu().numpy()
            self._tvalues = tvalues_gpu.cpu().numpy()
            self._pvalues = pvalues_gpu.cpu().numpy()
            self._conf_int = conf_int_gpu.cpu().numpy()
        elif self.inference_method == "debiased":
            # Debiased Lasso inference on Torch GPU.
            self._compute_inference_debiased_torch(X, y, coef)
        else:
            raise NotImplementedError(
                f"Lasso inference_method='{self.inference_method}' is not implemented "
                "for Torch without CPU fallback."
            )

        # R^2 on the (weighted) response, shared by both inference paths.
        ss_tot = torch.sum((y - torch.mean(y)) ** 2)
        ss_res = torch.sum(resid ** 2)
        self._rsquared_gpu = float((1 - ss_res / ss_tot).cpu().numpy()) if ss_tot > 0 else 0.0

        self._resid = None
        self._X_design = None
    else:
        self._scale = np.nan
        self._resid = None
        self._X_design = None
        self._rsquared_gpu = None

    # Drop references to large device temporaries before asking the
    # allocator to release cached memory.
    X_design = resid = coef_full = None
    XtX = Xty = X_centered = y_centered = None
    self._cleanup_torch_memory()
|
|
1505
|
+
|
|
1506
|
+
def _fit_gpu_admm(self, X, y, sample_weight=None):
    """Fit using GPU (CuPy) with the ADMM solver.

    Objective matches sklearn:
        (1/(2n)) * ||y - Xw||^2 + alpha * ||w||_1

    The problem is split as ``w = z`` with scaled dual variable ``u``:

    * w-update solves ``(XtX + rho*n*I) w = Xty + rho*n*(z - u)`` using a
      Cholesky factor of the left-hand side computed once before the loop;
    * z-update is soft-thresholding of ``w + u`` at ``alpha / rho``;
    * u-update is ``u += w - z``.

    The reported coefficients are the final z-iterate. It is the output of
    the l1 proximal step and so carries exact zeros, whereas the w-iterate
    only approaches sparsity as ``w - z -> 0``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Design matrix (converted with ``cp.asarray``).
    y : array-like
        Response; flattened to shape ``(n_samples,)``.
    sample_weight : array-like, optional
        Per-sample weights, applied by scaling rows of ``X`` and ``y`` by
        their square root; flattened to shape ``(n_samples,)``.

    Raises
    ------
    ValueError
        If ``self.admm_rho`` is not strictly positive.
    NotImplementedError
        If ``self.compute_inference`` is set and ``self.inference_method`` is
        neither ``'gpu_ols_inference'`` nor ``'debiased'``.
    """
    import cupy as cp
    import cupyx.scipy.linalg as cpx_linalg

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # Validate the ADMM penalty before doing any GPU work.
    rho = float(self.admm_rho)
    if rho <= 0:
        raise ValueError("admm_rho must be > 0")

    X = cp.asarray(X)
    # Flatten y *before* weighting: an (n, 1) response times an (n,) weight
    # vector would otherwise broadcast to an (n, n) matrix.
    y = cp.asarray(y).reshape(-1)
    if sample_weight is not None:
        sqrt_sw = cp.sqrt(cp.asarray(sample_weight).reshape(-1))
        X = X * sqrt_sw[:, cp.newaxis]
        y = y * sqrt_sw

    # Center for intercept; the intercept is recovered after the solve.
    if self.fit_intercept:
        X_mean = cp.mean(X, axis=0)
        y_mean = cp.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_centered = y

    XtX = X_centered.T @ X_centered
    Xty = X_centered.T @ y_centered

    # w-update system: (XtX + rho*n*I) w = Xty + rho*n*(z - u).
    rho_n = rho * n_samples
    lhs = XtX + rho_n * cp.eye(n_features, dtype=X.dtype)
    # XtX is PSD and rho*n > 0, so lhs is positive definite: factor it once.
    Lmat = cp.linalg.cholesky(lhs)

    def solve_w(rhs):
        # Solve Lmat @ (Lmat.T @ w) = rhs with two triangular solves.
        tmp = cpx_linalg.solve_triangular(Lmat, rhs, lower=True)
        return cpx_linalg.solve_triangular(Lmat.T, tmp, lower=False)

    thresh = self.alpha / rho
    coef = cp.zeros(n_features, dtype=X.dtype)  # w
    z = cp.zeros(n_features, dtype=X.dtype)  # z (sparse copy of w)
    u = cp.zeros(n_features, dtype=X.dtype)  # scaled dual

    for iteration in range(self.max_iter):
        coef_old = coef
        coef = solve_w(Xty + rho_n * (z - u))
        # z-update: prox of the l1 term.
        z = self._soft_threshold_cupy(coef + u, thresh)
        # Dual update.
        u = u + (coef - z)

        if self.stopping == "kkt":
            grad_sse = (XtX @ coef - Xty) / n_samples
            violation = cp.max(cp.maximum(cp.abs(grad_sse) - self.alpha, 0.0))
            converged = violation < self.tol
        else:
            # Legacy stopping: l1 change of the w-iterate.
            converged = cp.sum(cp.abs(coef - coef_old)) < self.tol
        if converged:
            self.n_iter_ = iteration + 1
            break
    else:
        self.n_iter_ = self.max_iter

    # Report the sparse z-iterate (see docstring).
    coef = z

    if self.fit_intercept:
        intercept_gpu = y_mean - X_mean @ coef
        coef_full = cp.concatenate([intercept_gpu.reshape(1), coef])
    else:
        coef_full = coef

    coef_full_np = coef_full.get()
    self._params = coef_full_np
    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np

    df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
    self._df_resid = df_resid

    X_design = resid = None
    if self.compute_inference:
        # The design with an explicit intercept column is only needed here.
        if self.fit_intercept:
            X_design = cp.concatenate([cp.ones((n_samples, 1), dtype=X.dtype), X], axis=1)
        else:
            X_design = X
        resid = y - X_design @ coef_full

        if df_resid > 0:
            scale = cp.sum(resid ** 2) / df_resid
            self._scale = float(scale.get()) if not cp.isnan(scale) else np.nan
        else:
            self._scale = np.nan
            scale = cp.nan

        if self.inference_method == "gpu_ols_inference":
            # Keep the inference path on GPU and transfer only small vectors.
            XtX_inf = X_design.T @ X_design
            try:
                XtX_inv = cp.linalg.inv(XtX_inf)
            except Exception:
                XtX_inv = cp.linalg.pinv(XtX_inf)

            bse_gpu = cp.sqrt(scale * cp.diag(XtX_inv))
            tvalues_gpu = coef_full / (bse_gpu + 1e-30)
            pvalues_gpu = cp.minimum(1.0, 2.0 * t.sf(cp.abs(tvalues_gpu), df=df_resid))

            ci_alpha = 0.05
            t_crit_gpu = t.ppf(1.0 - ci_alpha / 2.0, df=df_resid)
            margin_gpu = t_crit_gpu * bse_gpu
            conf_int_gpu = cp.stack([coef_full - margin_gpu, coef_full + margin_gpu], axis=1)

            self._bse = cp.asnumpy(bse_gpu)
            self._tvalues = cp.asnumpy(tvalues_gpu)
            self._pvalues = cp.asnumpy(pvalues_gpu)
            self._conf_int = cp.asnumpy(conf_int_gpu)
        elif self.inference_method == "debiased":
            self._compute_inference_debiased_gpu(X, y, coef)
        else:
            raise NotImplementedError(
                f"Lasso inference_method='{self.inference_method}' is not implemented "
                "for CuPy without CPU fallback."
            )

        # R^2 on the (weighted) response, shared by both inference paths.
        ss_tot = cp.sum((y - cp.mean(y)) ** 2)
        ss_res = cp.sum(resid ** 2)
        self._rsquared_gpu = float(cp.asnumpy(1 - ss_res / ss_tot)) if ss_tot > 0 else 0.0

        self._resid = None
        self._X_design = None
    else:
        self._scale = np.nan
        self._resid = None
        self._X_design = None
        self._rsquared_gpu = None

    # Drop large temporaries early (before optional pool cleanup). Rebinding
    # Lmat also releases it from the solve_w closure.
    X_design = resid = coef_full = None
    XtX = Xty = X_centered = y_centered = lhs = Lmat = None
    self._cleanup_cuda_memory()
|
|
1723
|
+
|
|
1724
|
+
def _compute_inference(self):
    """Compute standard errors, t-stats, p-values and 95% confidence intervals.

    Dispatches to ``_compute_inference_bootstrap`` or
    ``_compute_inference_debiased`` when those inference methods are
    selected. ``'gpu_ols_inference'`` is a no-op here because those
    statistics are filled in during the GPU fit. Otherwise the classical
    OLS formulas are applied to ``self._X_design``, ``self._params``,
    ``self._scale`` and ``self._df_resid``.

    Returns early, leaving the statistics untouched, when no design matrix
    or no finite residual scale is available.
    """
    if self.inference_method == "bootstrap":
        return self._compute_inference_bootstrap()
    if self.inference_method == "debiased":
        return self._compute_inference_debiased()
    if self.inference_method == "gpu_ols_inference":
        # Inference already computed on GPU during the fit.
        return
    if self._X_design is None or self._scale is None or np.isnan(self._scale):
        return

    X = self._X_design
    gram = X.T @ X
    try:
        XtX_inv = np.linalg.inv(gram)
    except np.linalg.LinAlgError:
        XtX_inv = np.linalg.pinv(gram)

    self._bse = np.sqrt(self._scale * np.diag(XtX_inv))
    self._tvalues = self._params / self._bse
    # 2 * CDF(-|t|) equals 2 * (1 - CDF(|t|)) but stays accurate in the far
    # tail, where 1 - CDF(|t|) would round to exactly 0.
    self._pvalues = 2 * stats.t.cdf(-np.abs(self._tvalues), self._df_resid)

    alpha = 0.05
    t_crit = stats.t.ppf(1 - alpha / 2, self._df_resid)
    self._conf_int = np.column_stack([
        self._params - t_crit * self._bse,
        self._params + t_crit * self._bse,
    ])
|
|
1753
|
+
|
|
1754
|
+
def _compute_inference_bootstrap(self) -> None:
    """
    Bootstrap inference for Lasso via residual resampling.

    For each of ``self.n_bootstrap`` replicates, residuals are resampled with
    replacement (using ``self.bootstrap_random_state``) and added to the fitted
    values. A CPU ``Lasso`` with the same hyper-parameters is then refit on the
    raw features. From the replicate parameter vectors it sets:

    * ``_bse``: sample standard deviation (ddof=1);
    * ``_pvalues``: two-sided sign-change probability, capped at 1;
    * ``_conf_int``: 95% percentile intervals;
    * ``_tvalues``: ``params / bse`` (approximate).

    Returns early if the design, residuals or response are missing, or if
    ``n_bootstrap <= 0``.

    Notes
    -----
    This is more robust than the naive OLS-based inference, but it is still
    not full "post-selection inference" for Lasso.
    """
    if self._X_design is None or self._resid is None or self._y is None:
        return

    if self.n_bootstrap <= 0:
        return

    rng = np.random.default_rng(self.bootstrap_random_state)
    X = self._X_design
    resid = self._resid
    y_pred = self._y - resid
    # Refits expect raw X (without the leading intercept column); loop-invariant.
    X_refit = X[:, 1:] if self.fit_intercept else X

    params_dim = self._params.shape[0]
    boot_params = np.zeros((self.n_bootstrap, params_dim), dtype=float)

    # Precompute the Lipschitz constant once for the refits' CPU FISTA.
    lipschitz_L = self.lipschitz_L
    if self.cpu_solver == "fista" and lipschitz_L is None:
        # L = lambda_max(G) / n, where G is the Gram matrix of the design the
        # refit's solver works on. It is centered only when an intercept is
        # fitted; centering an uncentered problem's design would underestimate
        # L and make the gradient step too large.
        if self.fit_intercept:
            X_lip = X_refit - X_refit.mean(axis=0, keepdims=True)
        else:
            X_lip = X_refit
        eigvals = np.linalg.eigvalsh(X_lip.T @ X_lip)
        lipschitz_L = float(eigvals[-1] / X_refit.shape[0])

    refit_kwargs = dict(
        alpha=self.alpha,
        fit_intercept=self.fit_intercept,
        max_iter=self.max_iter,
        tol=self.tol,
        stopping=self.stopping,
        inference_method="cpu_ols_inference",
        n_bootstrap=0,
        bootstrap_random_state=None,
        device="cpu",
        compute_inference=False,
        solver=self.solver,
        cpu_solver=self.cpu_solver,
        lipschitz_L=lipschitz_L,
        admm_rho=self.admm_rho,
    )

    for b in range(self.n_bootstrap):
        eps_star = rng.choice(resid, size=resid.shape[0], replace=True)
        refit = Lasso(**refit_kwargs)
        refit.fit(X_refit, y_pred + eps_star)
        boot_params[b, :] = refit._params

    # Standard errors and bootstrap-based p-values/CI.
    self._bse = np.std(boot_params, axis=0, ddof=1)
    self._params = np.asarray(self._params, dtype=float)

    # Two-sided p-values using the sign-change probability, per coefficient.
    p_lower = np.mean(boot_params <= 0.0, axis=0)
    p_upper = np.mean(boot_params >= 0.0, axis=0)
    self._pvalues = np.minimum(2.0 * np.minimum(p_lower, p_upper), 1.0)

    # 95% percentile confidence intervals.
    lower_q = 0.05 / 2.0
    upper_q = 1.0 - lower_q
    self._conf_int = np.column_stack([
        np.quantile(boot_params, lower_q, axis=0),
        np.quantile(boot_params, upper_q, axis=0),
    ])

    # t-stats (approx) from bootstrap SE.
    self._tvalues = self._params / (self._bse + 1e-30)
|
|
1842
|
+
|
|
1843
|
+
def _compute_inference_debiased(self) -> None:
    """Debiased Lasso inference (Javanmard-Montanari / Zhang-Zhang).

    Constructs the decorrelation matrix M via node-wise Lasso,
    then computes the debiased estimator, standard errors,
    z-statistics, p-values, and per-coefficient (marginal)
    95% confidence intervals.

    Steps:

    1. Noise variance ``sigma^2 = RSS / max(n - s_hat, 1)``, where ``s_hat``
       is the number of non-zero Lasso coefficients.
    2. Row j of ``M`` comes from a node-wise Lasso of column j on the other
       columns (penalty ``sqrt(2 log p / n)``). The matrix is looked up in,
       or stored to, a cache keyed on the design.
    3. ``theta = coef + M X^T r / n`` with ``se = sqrt(sigma^2 diag(M S M^T) / n)``,
       where ``S = X^T X / n``.

    When an intercept is fitted, the intercept keeps its original estimate
    and gets an OLS-formula standard error.

    Side effects: overwrites ``_params`` (with the debiased estimates),
    ``_bse``, ``_tvalues``, ``_pvalues``, ``_conf_int`` and
    ``_debiased_M_cpu``. Returns early if no design or residuals are stored.
    """
    if self._X_design is None or self._resid is None:
        return

    # Feature columns only: column 0 is the intercept column when fitted.
    X = self._X_design[:, 1:] if self.fit_intercept else self._X_design

    n, p = X.shape
    coef = self.coef_.copy()

    Sigma_hat = X.T @ X / n
    resid_lasso = self._resid

    # --- noise variance: sigma^2 = RSS / (n - s_hat) ---
    s_hat = int(np.sum(np.abs(coef) > 0))
    sigma2 = np.sum(resid_lasso ** 2) / max(n - s_hat, 1)

    # --- node-wise Lasso to build M (p x p), with cross-fit cache ---
    lam_nw = np.sqrt(2.0 * np.log(max(p, 2)) / n)
    m_cache_key = _debiased_m_key_from_numpy_design(
        X,
        n=n,
        p=p,
        lam_nw=lam_nw,
        tol=float(self.tol),
    )
    M_cached = _debiased_m_cache_get(m_cache_key)
    if M_cached is not None:
        M = np.asarray(M_cached, dtype=X.dtype)
    else:
        M = np.zeros((p, p), dtype=X.dtype)
        for j in range(p):
            cols = np.delete(np.arange(p), j)  # every column except j
            X_minus_j = X[:, cols]
            x_j = X[:, j]

            nw = Lasso(
                alpha=lam_nw,
                fit_intercept=False,
                max_iter=500,
                tol=1e-5,
                device="cpu",
                cpu_solver="fista",
                compute_inference=False,
            )
            nw.fit(X_minus_j, x_j)
            gamma_j = nw.coef_

            z_j = x_j - X_minus_j @ gamma_j
            C_j = z_j @ x_j / n

            if abs(C_j) < 1e-30:
                # Degenerate node-wise fit: fall back to an identity row.
                M[j, j] = 1.0
                continue

            M[j, j] = 1.0 / C_j
            M[j, cols] = -gamma_j / C_j
        _debiased_m_cache_put(m_cache_key, np.asarray(M, dtype=np.float64))

    # --- debiased estimates ---
    # Associate as M @ (X^T r): avoids materialising the p x n product M @ X^T.
    theta_db = coef + (M @ (X.T @ resid_lasso)) / n
    self._debiased_M_cpu = M

    # --- covariance and standard errors ---
    V = M @ Sigma_hat @ M.T
    se = np.sqrt(sigma2 * np.diag(V) / n)

    z_stats = theta_db / (se + 1e-30)
    # 2 * Phi(-|z|) equals 2 * (1 - Phi(|z|)) but does not underflow to 0
    # in the far tail.
    pvalues = 2.0 * _norm_dist.cdf(-np.abs(z_stats))

    alpha_ci = 0.05
    z_crit = _norm_dist.ppf(1.0 - alpha_ci / 2.0)
    ci = np.column_stack([theta_db - z_crit * se, theta_db + z_crit * se])

    if self.fit_intercept:
        # Intercept SE via OLS formula: sigma * sqrt([1/n + xbar' (X'X)^-1 xbar])
        X_full = self._X_design
        gram_full = X_full.T @ X_full
        try:
            XtX_inv = np.linalg.inv(gram_full)
        except np.linalg.LinAlgError:
            XtX_inv = np.linalg.pinv(gram_full)
        se_intercept = np.sqrt(sigma2 * XtX_inv[0, 0])
        z_intercept = self.intercept_ / (se_intercept + 1e-30)
        p_intercept = 2.0 * _norm_dist.cdf(-np.abs(z_intercept))
        ci_intercept = np.array([
            self.intercept_ - z_crit * se_intercept,
            self.intercept_ + z_crit * se_intercept,
        ])

        self._bse = np.concatenate([[se_intercept], se])
        self._tvalues = np.concatenate([[z_intercept], z_stats])
        self._pvalues = np.concatenate([[p_intercept], pvalues])
        self._conf_int = np.vstack([ci_intercept[np.newaxis, :], ci])
        self._params = np.concatenate([[self.intercept_], theta_db])
    else:
        self._bse = se
        self._tvalues = z_stats
        self._pvalues = pvalues
        self._conf_int = ci
        self._params = theta_db
|
|
1953
|
+
|
|
1954
|
+
def _compute_inference_debiased_gpu(self, X_gpu, y_gpu, coef_gpu):
|
|
1955
|
+
"""GPU path for debiased Lasso inference.
|
|
1956
|
+
|
|
1957
|
+
Parameters
|
|
1958
|
+
----------
|
|
1959
|
+
X_gpu : cupy.ndarray, shape (n, p)
|
|
1960
|
+
Raw feature matrix on GPU (no intercept column).
|
|
1961
|
+
y_gpu : cupy.ndarray, shape (n,)
|
|
1962
|
+
Response on GPU.
|
|
1963
|
+
coef_gpu : cupy.ndarray, shape (p,)
|
|
1964
|
+
Lasso coefficients on GPU (no intercept).
|
|
1965
|
+
"""
|
|
1966
|
+
import cupy as cp
|
|
1967
|
+
|
|
1968
|
+
n, p = X_gpu.shape
|
|
1969
|
+
Sigma_hat = X_gpu.T @ X_gpu / n
|
|
1970
|
+
|
|
1971
|
+
resid_lasso = y_gpu - X_gpu @ coef_gpu
|
|
1972
|
+
if self.fit_intercept:
|
|
1973
|
+
resid_lasso = resid_lasso - cp.mean(y_gpu) + cp.mean(X_gpu, axis=0) @ coef_gpu
|
|
1974
|
+
|
|
1975
|
+
s_hat_gpu = cp.sum(cp.abs(coef_gpu) > 0).astype(cp.float64)
|
|
1976
|
+
denom_gpu = cp.maximum(1.0, float(n) - s_hat_gpu)
|
|
1977
|
+
sigma2_gpu = cp.asarray(cp.sum(resid_lasso ** 2) / denom_gpu, dtype=cp.float64)
|
|
1978
|
+
|
|
1979
|
+
lam_nw = float(np.sqrt(2.0 * np.log(max(p, 2)) / n))
|
|
1980
|
+
alpha_nw = np.asarray([lam_nw], dtype=np.float64)
|
|
1981
|
+
tiny = X_gpu.dtype.type(1e-30)
|
|
1982
|
+
zero = X_gpu.dtype.type(0.0)
|
|
1983
|
+
one = X_gpu.dtype.type(1.0)
|
|
1984
|
+
|
|
1985
|
+
# Keep node-wise Lasso solves on GPU to avoid per-feature host round-trips.
|
|
1986
|
+
x_hasher = hashlib.blake2b(digest_size=32)
|
|
1987
|
+
x_hasher.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
|
|
1988
|
+
x_hasher.update(str(X_gpu.dtype).encode("utf-8"))
|
|
1989
|
+
x_hasher.update(np.asarray([float(lam_nw), float(self.tol)], dtype=np.float64).tobytes())
|
|
1990
|
+
row_chunk = max(1, min(int(n), _LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK))
|
|
1991
|
+
for start in range(0, int(n), row_chunk):
|
|
1992
|
+
stop = min(int(n), start + row_chunk)
|
|
1993
|
+
x_chunk = cp.asnumpy(X_gpu[start:stop])
|
|
1994
|
+
x_hasher.update(x_chunk.tobytes())
|
|
1995
|
+
m_cache_key = x_hasher.hexdigest()
|
|
1996
|
+
M_cached = _debiased_m_cache_get(m_cache_key)
|
|
1997
|
+
if M_cached is not None:
|
|
1998
|
+
M = cp.asarray(M_cached, dtype=X_gpu.dtype)
|
|
1999
|
+
else:
|
|
2000
|
+
M = cp.zeros((p, p), dtype=X_gpu.dtype)
|
|
2001
|
+
# Reuse full Gram to avoid repeated X_minus_j.T @ X_minus_j products.
|
|
2002
|
+
XtX_full = X_gpu.T @ X_gpu
|
|
2003
|
+
Sigma_diag = cp.diag(Sigma_hat)
|
|
2004
|
+
n_samp_vec_dtype = np.float64
|
|
2005
|
+
|
|
2006
|
+
# Batch node-wise problems so GPU can process many j's together.
|
|
2007
|
+
try:
|
|
2008
|
+
free_mem, _ = cp.cuda.Device().mem_info
|
|
2009
|
+
bytes_per_fold = int(max(8, (p - 1) * (p - 1) * 8 * 2))
|
|
2010
|
+
chunk_size = int(max(4, min(64, free_mem // max(bytes_per_fold, 1))))
|
|
2011
|
+
except Exception:
|
|
2012
|
+
chunk_size = 16
|
|
2013
|
+
chunk_size = max(4, min(int(p), chunk_size))
|
|
2014
|
+
|
|
2015
|
+
for j0 in range(0, p, chunk_size):
|
|
2016
|
+
j1 = min(p, j0 + chunk_size)
|
|
2017
|
+
bsz = j1 - j0
|
|
2018
|
+
j_batch = cp.arange(j0, j1, dtype=cp.int32)
|
|
2019
|
+
if int(j_batch.size) == 0:
|
|
2020
|
+
continue
|
|
2021
|
+
|
|
2022
|
+
# Build per-j "all except j" column index matrix of shape (bsz, p-1).
|
|
2023
|
+
base = cp.arange(p - 1, dtype=cp.int32).reshape(1, -1)
|
|
2024
|
+
cols_batch = base + (base >= j_batch.reshape(-1, 1))
|
|
2025
|
+
|
|
2026
|
+
# Gather batched Gram/Xty blocks.
|
|
2027
|
+
XtX_batch = XtX_full[
|
|
2028
|
+
cols_batch[:, :, cp.newaxis],
|
|
2029
|
+
cols_batch[:, cp.newaxis, :],
|
|
2030
|
+
]
|
|
2031
|
+
Xty_batch = XtX_full[cols_batch, j_batch.reshape(-1, 1)].reshape(bsz, p - 1)
|
|
2032
|
+
|
|
2033
|
+
coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram(
|
|
2034
|
+
XtX_batch,
|
|
2035
|
+
Xty_batch,
|
|
2036
|
+
n_samples_vec=np.full((bsz,), float(n), dtype=n_samp_vec_dtype),
|
|
2037
|
+
alphas_desc=alpha_nw,
|
|
2038
|
+
max_iter=500,
|
|
2039
|
+
tol=1e-5,
|
|
2040
|
+
stopping="coef_delta",
|
|
2041
|
+
lipschitz_L=None,
|
|
2042
|
+
check_every=8,
|
|
2043
|
+
)
|
|
2044
|
+
gamma_batch = cp.asarray(coefs_batch_desc[:, 0, :], dtype=X_gpu.dtype)
|
|
2045
|
+
|
|
2046
|
+
# C_j = Sigma_jj - Sigma_{j,-j} gamma_j
|
|
2047
|
+
sigma_j_cols = Sigma_hat[j_batch[:, cp.newaxis], cols_batch]
|
|
2048
|
+
C_batch = Sigma_diag[j_batch] - cp.sum(sigma_j_cols * gamma_batch, axis=1)
|
|
2049
|
+
|
|
2050
|
+
small_c = cp.abs(C_batch) < tiny
|
|
2051
|
+
inv_c = cp.where(small_c, zero, one / C_batch)
|
|
2052
|
+
M[j_batch, j_batch] = cp.where(small_c, one, inv_c)
|
|
2053
|
+
M[j_batch[:, cp.newaxis], cols_batch] = -gamma_batch * inv_c.reshape(-1, 1)
|
|
2054
|
+
|
|
2055
|
+
del XtX_batch
|
|
2056
|
+
del Xty_batch
|
|
2057
|
+
del coefs_batch_desc
|
|
2058
|
+
del gamma_batch
|
|
2059
|
+
del sigma_j_cols
|
|
2060
|
+
_debiased_m_cache_put(m_cache_key, cp.asnumpy(M))
|
|
2061
|
+
|
|
2062
|
+
# Recompute full residual from the original fit
|
|
2063
|
+
if self.fit_intercept:
|
|
2064
|
+
y_pred = X_gpu @ coef_gpu + cp.asarray(self.intercept_, dtype=X_gpu.dtype)
|
|
2065
|
+
else:
|
|
2066
|
+
y_pred = X_gpu @ coef_gpu
|
|
2067
|
+
resid_full = y_gpu - y_pred
|
|
2068
|
+
|
|
2069
|
+
theta_db = coef_gpu + (M @ X_gpu.T @ resid_full) / n
|
|
2070
|
+
|
|
2071
|
+
V = M @ Sigma_hat @ M.T
|
|
2072
|
+
se = cp.sqrt(sigma2_gpu * cp.diag(V) / n)
|
|
2073
|
+
|
|
2074
|
+
z_stats = theta_db / (se + 1e-30)
|
|
2075
|
+
pvalues = cp.minimum(1.0, 2.0 * norm.sf(cp.abs(z_stats)))
|
|
2076
|
+
|
|
2077
|
+
alpha_ci = 0.05
|
|
2078
|
+
z_crit = norm.ppf(1.0 - alpha_ci / 2.0)
|
|
2079
|
+
ci = cp.stack([theta_db - z_crit * se, theta_db + z_crit * se], axis=1)
|
|
2080
|
+
|
|
2081
|
+
if self.fit_intercept:
|
|
2082
|
+
X_full = cp.concatenate(
|
|
2083
|
+
[cp.ones((n, 1), dtype=X_gpu.dtype), X_gpu], axis=1
|
|
2084
|
+
)
|
|
2085
|
+
XtX_full = X_full.T @ X_full
|
|
2086
|
+
try:
|
|
2087
|
+
XtX_inv = cp.linalg.inv(XtX_full)
|
|
2088
|
+
except Exception:
|
|
2089
|
+
XtX_inv = cp.linalg.pinv(XtX_full)
|
|
2090
|
+
se_intercept = cp.sqrt(sigma2_gpu * XtX_inv[0, 0])
|
|
2091
|
+
intercept_gpu = cp.asarray(self.intercept_, dtype=cp.float64)
|
|
2092
|
+
z_intercept = intercept_gpu / (se_intercept + 1e-30)
|
|
2093
|
+
p_intercept = cp.minimum(1.0, 2.0 * norm.sf(cp.abs(z_intercept).reshape(1)))
|
|
2094
|
+
ci_intercept = cp.stack([
|
|
2095
|
+
intercept_gpu - z_crit * se_intercept,
|
|
2096
|
+
intercept_gpu + z_crit * se_intercept,
|
|
2097
|
+
]).reshape(1, 2)
|
|
2098
|
+
|
|
2099
|
+
bse_gpu = cp.concatenate([se_intercept.reshape(1), se])
|
|
2100
|
+
tvalues_gpu = cp.concatenate([z_intercept.reshape(1), z_stats])
|
|
2101
|
+
pvalues_gpu = cp.concatenate([p_intercept.reshape(1), pvalues])
|
|
2102
|
+
conf_int_gpu = cp.concatenate([ci_intercept, ci], axis=0)
|
|
2103
|
+
params_gpu = cp.concatenate([intercept_gpu.reshape(1), theta_db])
|
|
2104
|
+
else:
|
|
2105
|
+
bse_gpu = se
|
|
2106
|
+
tvalues_gpu = z_stats
|
|
2107
|
+
pvalues_gpu = pvalues
|
|
2108
|
+
conf_int_gpu = ci
|
|
2109
|
+
params_gpu = theta_db
|
|
2110
|
+
|
|
2111
|
+
if self.enable_simultaneous_inference:
|
|
2112
|
+
# GPU-native simultaneous CI via max-|Z| multiplier bootstrap.
|
|
2113
|
+
param_target_idx_np = self._get_simultaneous_target_indices(
|
|
2114
|
+
int(params_gpu.shape[0])
|
|
2115
|
+
)
|
|
2116
|
+
param_target_idx_gpu = cp.asarray(param_target_idx_np, dtype=cp.int32)
|
|
2117
|
+
if param_target_idx_gpu.size == 0:
|
|
2118
|
+
raise RuntimeError(
|
|
2119
|
+
"No coefficients selected for simultaneous inference target set."
|
|
2120
|
+
)
|
|
2121
|
+
|
|
2122
|
+
feature_offset = 1 if self.fit_intercept else 0
|
|
2123
|
+
feature_target_gpu = param_target_idx_gpu - feature_offset
|
|
2124
|
+
feature_target_gpu = feature_target_gpu[feature_target_gpu >= 0]
|
|
2125
|
+
if feature_target_gpu.size == 0:
|
|
2126
|
+
raise RuntimeError(
|
|
2127
|
+
"No feature coefficients selected for simultaneous inference target set."
|
|
2128
|
+
)
|
|
2129
|
+
|
|
2130
|
+
se_feat_gpu = se
|
|
2131
|
+
B = int(self.simultaneous_n_bootstrap)
|
|
2132
|
+
rng = cp.random.RandomState(self.simultaneous_random_state)
|
|
2133
|
+
se_target_gpu = cp.take(se_feat_gpu, feature_target_gpu)
|
|
2134
|
+
M_target = cp.take(M, feature_target_gpu, axis=0)
|
|
2135
|
+
# Run bootstrap in one shot when memory allows to reduce kernel-launch overhead.
|
|
2136
|
+
try:
|
|
2137
|
+
xi = rng.standard_normal(size=(B, n)).astype(cp.float64, copy=False)
|
|
2138
|
+
weighted = xi * resid_full.reshape(1, -1)
|
|
2139
|
+
score_target = (weighted @ X_gpu) @ M_target.T / float(max(n, 1))
|
|
2140
|
+
z_star_target = score_target / (se_target_gpu.reshape(1, -1) + 1e-30)
|
|
2141
|
+
max_stats_gpu = cp.max(cp.abs(z_star_target), axis=1)
|
|
2142
|
+
except Exception:
|
|
2143
|
+
free_mem, _ = cp.cuda.Device().mem_info
|
|
2144
|
+
bytes_per_row = max(8 * (3 * n + 2 * p + 64), 8)
|
|
2145
|
+
est_chunk = int(max(64, min(4096, free_mem // bytes_per_row)))
|
|
2146
|
+
chunk = min(B, max(64, est_chunk))
|
|
2147
|
+
max_stats_gpu = cp.empty((B,), dtype=cp.float64)
|
|
2148
|
+
filled = 0
|
|
2149
|
+
while filled < B:
|
|
2150
|
+
bsz = min(chunk, B - filled)
|
|
2151
|
+
xi = rng.standard_normal(size=(bsz, n)).astype(cp.float64, copy=False)
|
|
2152
|
+
weighted = xi * resid_full.reshape(1, -1)
|
|
2153
|
+
score_target = (weighted @ X_gpu) @ M_target.T / float(max(n, 1))
|
|
2154
|
+
z_star_target = score_target / (se_target_gpu.reshape(1, -1) + 1e-30)
|
|
2155
|
+
max_stats_gpu[filled : filled + bsz] = cp.max(
|
|
2156
|
+
cp.abs(z_star_target), axis=1
|
|
2157
|
+
)
|
|
2158
|
+
filled += bsz
|
|
2159
|
+
|
|
2160
|
+
critical_gpu = cp.quantile(
|
|
2161
|
+
max_stats_gpu, 1.0 - float(self.simultaneous_alpha)
|
|
2162
|
+
)
|
|
2163
|
+
conf_sim_gpu = cp.array(conf_int_gpu, copy=True)
|
|
2164
|
+
lower_gpu = cp.take(params_gpu, param_target_idx_gpu) - critical_gpu * cp.take(
|
|
2165
|
+
bse_gpu, param_target_idx_gpu
|
|
2166
|
+
)
|
|
2167
|
+
upper_gpu = cp.take(params_gpu, param_target_idx_gpu) + critical_gpu * cp.take(
|
|
2168
|
+
bse_gpu, param_target_idx_gpu
|
|
2169
|
+
)
|
|
2170
|
+
conf_sim_gpu[param_target_idx_gpu, 0] = lower_gpu
|
|
2171
|
+
conf_sim_gpu[param_target_idx_gpu, 1] = upper_gpu
|
|
2172
|
+
|
|
2173
|
+
target_mask = np.zeros(int(params_gpu.shape[0]), dtype=bool)
|
|
2174
|
+
target_mask[param_target_idx_np] = True
|
|
2175
|
+
self._conf_int_simultaneous = cp.asnumpy(conf_sim_gpu)
|
|
2176
|
+
self._simultaneous_enabled = True
|
|
2177
|
+
self._simultaneous_method = self.simultaneous_method
|
|
2178
|
+
self._simultaneous_alpha = float(self.simultaneous_alpha)
|
|
2179
|
+
self._simultaneous_n_bootstrap = B
|
|
2180
|
+
self._simultaneous_critical_value = float(cp.asnumpy(critical_gpu))
|
|
2181
|
+
self._simultaneous_target_mask = target_mask
|
|
2182
|
+
|
|
2183
|
+
self._bse = cp.asnumpy(bse_gpu)
|
|
2184
|
+
self._tvalues = cp.asnumpy(tvalues_gpu)
|
|
2185
|
+
self._pvalues = cp.asnumpy(pvalues_gpu)
|
|
2186
|
+
self._conf_int = cp.asnumpy(conf_int_gpu)
|
|
2187
|
+
self._params = cp.asnumpy(params_gpu)
|
|
2188
|
+
|
|
2189
|
+
def _get_simultaneous_target_indices(self, n_params: int):
|
|
2190
|
+
if self.fit_intercept and (not self.simultaneous_include_intercept):
|
|
2191
|
+
return np.arange(1, n_params, dtype=int)
|
|
2192
|
+
return np.arange(n_params, dtype=int)
|
|
2193
|
+
|
|
2194
|
+
def _compute_simultaneous_inference(self):
|
|
2195
|
+
if not self.enable_simultaneous_inference:
|
|
2196
|
+
return
|
|
2197
|
+
if self._simultaneous_enabled and self._conf_int_simultaneous is not None:
|
|
2198
|
+
return
|
|
2199
|
+
if self.inference_method != "debiased":
|
|
2200
|
+
return
|
|
2201
|
+
if self._params is None or self._bse is None or self._conf_int is None:
|
|
2202
|
+
return
|
|
2203
|
+
if self._X_design is None or self._resid is None:
|
|
2204
|
+
raise RuntimeError(
|
|
2205
|
+
"Simultaneous debiased inference requires accessible design/residual "
|
|
2206
|
+
"state; re-fit with compute_inference=True."
|
|
2207
|
+
)
|
|
2208
|
+
self._compute_simultaneous_ci_maxz_bootstrap()
|
|
2209
|
+
|
|
2210
|
+
def compute_debiased_inference(self):
    """Explicitly recompute debiased inference for a fitted model.

    Returns:
        ``self``, so calls can be chained.

    Raises:
        ValueError: if ``inference_method`` is not ``"debiased"``.
    """
    self._check_is_fitted()
    if self.inference_method == "debiased":
        self._compute_inference()
        return self
    raise ValueError("compute_debiased_inference requires inference_method='debiased'.")
|
|
2217
|
+
|
|
2218
|
+
def compute_debiased_inference_(self):
    """Deprecated alias for :meth:`compute_debiased_inference`.

    Emits a ``DeprecationWarning`` attributed to the caller, then forwards
    to the non-underscore method and returns its result.
    """
    message = (
        "compute_debiased_inference_ is deprecated and will be removed in a future "
        "release; use compute_debiased_inference instead."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.compute_debiased_inference()
|
|
2227
|
+
|
|
2228
|
+
def compute_simultaneous_inference(self):
    """Compute simultaneous inference for a fitted model on request.

    Delegates to ``_compute_simultaneous_inference``. That helper returns
    early when simultaneous intervals are already stored, so this call does
    not force a recomputation of existing results.

    Returns:
        ``self``, so calls can be chained.

    Raises:
        ValueError: if ``enable_simultaneous_inference`` is false.
    """
    self._check_is_fitted()
    if self.enable_simultaneous_inference:
        self._compute_simultaneous_inference()
        return self
    raise ValueError(
        "compute_simultaneous_inference requires enable_simultaneous_inference=True."
    )
|
|
2237
|
+
|
|
2238
|
+
def compute_simultaneous_inference_(self):
    """Deprecated alias for :meth:`compute_simultaneous_inference`.

    Emits a ``DeprecationWarning`` attributed to the caller, then forwards
    to the non-underscore method and returns its result.
    """
    message = (
        "compute_simultaneous_inference_ is deprecated and will be removed in a "
        "future release; use compute_simultaneous_inference instead."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.compute_simultaneous_inference()
|
|
2247
|
+
|
|
2248
|
+
def _compute_simultaneous_ci_maxz_bootstrap(self):
    """Compute simultaneous CIs using max-|Z| multiplier bootstrap.

    Reads the stored design ``self._X_design`` (first column dropped when
    ``fit_intercept``), residuals ``self._resid``, standard errors
    ``self._bse``, parameters ``self._params`` and per-parameter CIs
    ``self._conf_int``.

    The (p, p) matrix ``M`` is taken from ``self._debiased_M_cpu`` when it has
    the right shape. Otherwise it is rebuilt row by row from CPU node-wise
    Lasso regressions of each feature on the others, and then cached back on
    ``self._debiased_M_cpu``.

    ``simultaneous_n_bootstrap`` standard-normal multiplier vectors are drawn
    (seeded by ``simultaneous_random_state``). For each draw the maximum
    |Z*_j| over the targeted features is recorded. The
    ``1 - simultaneous_alpha`` quantile of those maxima becomes the critical
    value, and targeted rows of the CI matrix are replaced by
    ``params ± critical * bse``.

    Side effects: sets ``_conf_int_simultaneous``, ``_simultaneous_enabled``,
    ``_simultaneous_method``, ``_simultaneous_alpha``,
    ``_simultaneous_n_bootstrap``, ``_simultaneous_critical_value`` and
    ``_simultaneous_target_mask``.

    Raises:
        RuntimeError: if the design has no feature columns, or if no feature
            coefficient falls in the simultaneous target set.
    """
    # Feature-only design used by debiased estimator.
    if self.fit_intercept:
        X = np.asarray(self._X_design[:, 1:], dtype=float)
    else:
        X = np.asarray(self._X_design, dtype=float)
    resid = np.asarray(self._resid, dtype=float).reshape(-1)
    n, p = X.shape
    if p == 0:
        raise RuntimeError("Simultaneous inference requires at least one feature.")

    # Reuse M from debiased inference when available to avoid duplicate node-wise solves.
    M = self._debiased_M_cpu
    if M is None or M.shape != (p, p):
        # Node-wise penalty sqrt(2 log p / n); p is floored at 2 so log(p) > 0.
        lam_nw = np.sqrt(2.0 * np.log(max(p, 2)) / max(n, 1))
        M = np.zeros((p, p), dtype=float)
        for j in range(p):
            # Regress feature j on all other features (no intercept).
            cols = np.concatenate([np.arange(0, j), np.arange(j + 1, p)])
            X_minus_j = X[:, cols]
            x_j = X[:, j]
            nw = Lasso(
                alpha=lam_nw,
                fit_intercept=False,
                max_iter=500,
                tol=1e-5,
                device="cpu",
                cpu_solver="fista",
                compute_inference=False,
            )
            nw.fit(X_minus_j, x_j)
            gamma_j = nw.coef_
            # Node-wise residual z_j and normaliser c_j = <z_j, x_j> / n.
            z_j = x_j - X_minus_j @ gamma_j
            c_j = float(z_j @ x_j / max(n, 1))
            if abs(c_j) < 1e-30:
                # Degenerate normaliser: fall back to a unit diagonal and
                # leave the off-diagonal entries of row j at zero.
                M[j, j] = 1.0
                continue
            # Row j of M is (e_j - gamma_j placed at cols) / c_j.
            M[j, j] = 1.0 / c_j
            M[j, cols] = -gamma_j / c_j
        self._debiased_M_cpu = M

    # Bootstrap the studentized process max_j |Z*_j|.
    param_target_idx = self._get_simultaneous_target_indices(len(self._params))
    # Shift parameter indices to feature indices; the intercept (if targeted)
    # maps to -1 and is dropped from the bootstrap max.
    feature_target_idx = param_target_idx - (1 if self.fit_intercept else 0)
    feature_target_idx = feature_target_idx[feature_target_idx >= 0]
    if feature_target_idx.size == 0:
        raise RuntimeError(
            "No feature coefficients selected for simultaneous inference target set."
        )

    se_feat = np.asarray(self._bse[(1 if self.fit_intercept else 0):], dtype=float)
    eps = resid
    rng = np.random.default_rng(self.simultaneous_random_state)
    B = int(self.simultaneous_n_bootstrap)
    # Process draws in chunks of at most 256 rows to bound the size of each
    # (bsz, n) multiplier matrix.
    # NOTE(review): B == 0 leaves max_stats empty and np.quantile below would
    # fail; presumably simultaneous_n_bootstrap is validated upstream — confirm.
    chunk = min(256, B)
    max_stats = np.empty(B, dtype=float)
    filled = 0
    while filled < B:
        bsz = min(chunk, B - filled)
        xi = rng.standard_normal(size=(bsz, n))
        # Each row reweights the residuals by Gaussian multipliers; the score
        # is the bootstrap analogue of the correction term M X^T resid / n.
        weighted = xi * eps.reshape(1, -1)
        score = (weighted @ X) @ M.T / float(max(n, 1))
        # Studentize with the per-feature standard errors (1e-30 avoids 0/0).
        z_star = score / (se_feat.reshape(1, -1) + 1e-30)
        max_stats[filled:filled + bsz] = np.max(
            np.abs(z_star[:, feature_target_idx]), axis=1
        )
        filled += bsz

    critical = float(np.quantile(max_stats, 1.0 - self.simultaneous_alpha))
    params = np.asarray(self._params, dtype=float)
    bse = np.asarray(self._bse, dtype=float)
    # Start from the marginal CIs; only targeted rows are widened.
    # NOTE(review): when the intercept is targeted, its row is widened with the
    # critical value computed from feature statistics only.
    conf_sim = np.array(self._conf_int, copy=True, dtype=float)
    conf_sim[param_target_idx, 0] = params[param_target_idx] - critical * bse[param_target_idx]
    conf_sim[param_target_idx, 1] = params[param_target_idx] + critical * bse[param_target_idx]

    mask = np.zeros(len(params), dtype=bool)
    mask[param_target_idx] = True
    self._conf_int_simultaneous = conf_sim
    self._simultaneous_enabled = True
    self._simultaneous_method = self.simultaneous_method
    self._simultaneous_alpha = float(self.simultaneous_alpha)
    self._simultaneous_n_bootstrap = B
    self._simultaneous_critical_value = critical
    self._simultaneous_target_mask = mask
|
|
2332
|
+
|
|
2333
|
+
@property
def rsquared(self):
    """R-squared, ``1 - SS_res / SS_tot``, with SS_tot taken about mean(y).

    When residuals were not kept on the host, any precomputed
    ``_rsquared_gpu`` value is returned instead (None if absent). Returns
    None without stored targets, and 0.0 when y has zero total variation.
    """
    residuals = self._resid
    if residuals is None:
        # In compute_inference=False GPU mode we may avoid transferring residuals.
        return getattr(self, "_rsquared_gpu", None)

    targets = self._y
    if targets is None:
        return None

    total = np.sum((targets - np.mean(targets)) ** 2)
    unexplained = np.sum(residuals ** 2)
    if total > 0:
        return 1 - unexplained / total
    return 0.0
|
|
2347
|
+
|
|
2348
|
+
@property
def rsquared_adj(self):
    """Adjusted R-squared, ``1 - (1 - R^2) * (n - 1) / df_resid``.

    Returns None when the observation count or R^2 is unavailable, or when
    the residual degrees of freedom are missing or non-positive (the ratio
    would otherwise raise or be meaningless).
    """
    if self._nobs is None:
        return None
    r2 = self.rsquared
    if r2 is None:
        return None
    df_resid = self._df_resid
    if df_resid is None or df_resid <= 0:
        return None
    return 1 - (1 - r2) * (self._nobs - 1) / df_resid
|
|
2358
|
+
|
|
2359
|
+
@property
def fvalue(self):
    """Overall F-statistic ``(SS_reg / k) / (SS_res / df_resid)``.

    ``k`` is the number of feature coefficients (``len(self.coef_)``).
    Uses the host-side targets and residuals when both are stored. Otherwise
    (GPU mode may skip transferring residuals) it uses the equivalent R^2
    form ``(R^2 / k) / ((1 - R^2) / df_resid)``.

    Returns:
        None when there are no feature coefficients, when the residual
        degrees of freedom are missing or non-positive, or when R^2 is
        unavailable in the fallback path. ``np.inf`` for a perfect
        (zero-residual) fit.
    """
    k = len(self.coef_)
    df_resid = self._df_resid
    # Both branches share this guard: without regressors or residual degrees
    # of freedom the F-statistic is undefined (and f_pvalue reports None).
    if k <= 0 or df_resid is None or df_resid <= 0:
        return None

    if self._y is not None and self._resid is not None:
        y_mean = np.mean(self._y)
        ss_tot = np.sum((self._y - y_mean) ** 2)
        ss_res = np.sum(self._resid ** 2)
        if ss_res <= 0:
            return np.inf
        ss_reg = ss_tot - ss_res
        return (ss_reg / k) / (ss_res / df_resid)

    # GPU inference mode may skip transferring residual vectors to host.
    r2 = self.rsquared
    if r2 is None:
        return None
    if r2 >= 1.0:
        return np.inf
    return (r2 / k) / ((1.0 - r2) / df_resid)
|
|
2382
|
+
|
|
2383
|
+
@property
def f_pvalue(self):
    """Upper-tail p-value of the F-statistic on (k, df_resid) degrees of freedom.

    ``k`` is the number of feature coefficients.

    Returns:
        None when there are no feature coefficients, no positive residual
        degrees of freedom, no F-statistic, or a non-finite p-value.
        0.0 for an infinite F-statistic. Otherwise a float in [0, 1].
    """
    k = len(self.coef_)
    if k <= 0 or self._df_resid is None or self._df_resid <= 0:
        return None
    fv = self.fvalue
    if fv is None:
        return None
    if np.isposinf(fv):
        # An infinite F-statistic corresponds to a perfect-fit / zero-residual
        # case, so the upper-tail probability tends to 0.
        return 0.0
    # sf() evaluates the upper tail directly. 1 - cdf() cancels to exactly
    # 0.0 once the CDF rounds to 1, losing very small p-values.
    pval = stats.f.sf(fv, k, self._df_resid)
    if not np.isfinite(pval):
        return None
    return float(np.clip(pval, 0.0, 1.0))
|
|
2402
|
+
|
|
2403
|
+
@property
def aic(self):
    """Akaike Information Criterion, ``-2 * llf + 2 * len(params)``.

    Returns None, instead of raising, when the observation count is missing,
    the scale is missing or NaN, the log-likelihood is unavailable, or no
    parameter vector is stored.
    """
    if self._nobs is None or self._scale is None or np.isnan(self._scale):
        return None
    llf = self.llf
    # llf can itself be None (e.g. a non-positive variance estimate).
    if llf is None or self._params is None:
        return None
    return -2 * llf + 2 * len(self._params)
|
|
2409
|
+
|
|
2410
|
+
@property
def bic(self):
    """Bayesian Information Criterion, ``-2 * llf + k * log(n)``.

    ``k`` is ``len(self._params)`` and ``n`` is the number of observations.
    Returns None, instead of raising, when the observation count is missing,
    the scale is missing or NaN, the log-likelihood is unavailable, or no
    parameter vector is stored.
    """
    if self._nobs is None or self._scale is None or np.isnan(self._scale):
        return None
    llf = self.llf
    # llf can itself be None (e.g. a non-positive variance estimate).
    if llf is None or self._params is None:
        return None
    n = self._nobs
    k = len(self._params)
    return -2 * llf + k * np.log(n)
|
|
2418
|
+
|
|
2419
|
+
@property
def llf(self):
    """Gaussian log-likelihood evaluated at the MLE of the error variance.

    The variance is ``sum(resid**2) / n`` when residuals are stored.
    Otherwise it is recovered as ``scale * df_resid / n``.

    Returns:
        ``-n/2 * log(2*pi*sigma2) - n/2``. Returns None when the observation
        count is missing, neither residuals nor a usable scale/df_resid pair
        is available, or the variance estimate is non-positive.
    """
    n = self._nobs
    if n is None:
        return None

    def _mle_variance():
        # Prefer the residual vector; otherwise rescale the stored variance.
        if self._resid is not None:
            return np.sum(self._resid ** 2) / n
        if self._scale is None or np.isnan(self._scale):
            return None
        if self._df_resid is None or self._df_resid <= 0:
            return None
        return (self._scale * self._df_resid) / n

    sigma2 = _mle_variance()
    if sigma2 is None or sigma2 <= 0:
        return None
    half_n = n / 2
    return -half_n * np.log(2 * np.pi * sigma2) - half_n
|
|
2436
|
+
|
|
2437
|
+
def summary(self):
    """Print summary table.

    Writes a plain-text report to stdout and returns nothing. The report
    contains, in order:

    - any inference caution notes;
    - the title and ``alpha``;
    - fit statistics: observations, residual degrees of freedom, iterations,
      R^2, adjusted R^2, F-statistic and its p-value, log-likelihood, AIC
      and BIC;
    - a per-parameter table with coef, std err, z/t statistic, p-value and
      CI bounds;
    - when stored, the simultaneous-inference settings.

    Raises:
        RuntimeError: if the model is not fitted, or if standard errors,
            p-values or confidence intervals were not computed.
    """
    if not self._fitted:
        raise RuntimeError("Model has not been fitted yet.")

    if self._bse is None or self._pvalues is None or self._conf_int is None:
        raise RuntimeError(
            "compute_inference=False: inference statistics are not available. "
            "Re-fit with compute_inference=True (default) to use summary()."
        )

    # Generic names x1..xp; row 0 is the intercept when fit_intercept.
    if self.fit_intercept:
        feature_names = ['(Intercept)'] + [f'x{i+1}' for i in range(len(self.coef_))]
    else:
        feature_names = [f'x{i+1}' for i in range(len(self.coef_))]

    # Debiased inference is labelled with z statistics; otherwise t.
    is_debiased = self.inference_method == "debiased"
    title = "Debiased Lasso Results" if is_debiased else "Lasso Regression Results"
    stat_label = "z" if is_debiased else "t"
    pval_label = "P>|z|" if is_debiased else "P>|t|"

    def _fmt_stat(value, fmt_spec: str) -> str:
        # Format an optional fit statistic. None, unconvertible values and NaN
        # print as 'nan' and infinities as 'inf'/'-inf', all right-aligned in
        # 15 columns; finite values use fmt_spec.
        if value is None:
            return f"{'nan':>15}"
        try:
            value_f = float(value)
        except Exception:
            return f"{'nan':>15}"
        if np.isnan(value_f):
            return f"{'nan':>15}"
        if np.isposinf(value_f):
            return f"{'inf':>15}"
        if np.isneginf(value_f):
            return f"{'-inf':>15}"
        return format(value_f, fmt_spec)

    print("=" * 80)
    if self._inference_cautions:
        print("Notes:")
        for note in self._inference_cautions:
            print(f"- {note}")
        print("=" * 80)
    print(f" {title}")
    # NOTE(review): ':.4f' assumes self.alpha is numeric here — confirm.
    print(f" (alpha = {self.alpha:.4f})")
    print("=" * 80)
    # NOTE(review): these three fields bypass _fmt_stat; a None value would
    # raise TypeError when formatted with '>15'.
    print(f"No. Observations: {self._nobs:>15}")
    print(f"Degrees of Freedom: {self._df_resid:>15}")
    print(f"Iterations: {self.n_iter_:>15}")
    print(f"R-squared: {_fmt_stat(self.rsquared, '>15.4f')}")
    print(f"Adj. R-squared: {_fmt_stat(self.rsquared_adj, '>15.4f')}")
    print(f"F-statistic: {_fmt_stat(self.fvalue, '>15.4f')}")
    print(f"Prob (F-statistic): {_fmt_stat(self.f_pvalue, '>15.4e')}")
    print(f"Log-Likelihood: {_fmt_stat(self.llf, '>15.4f')}")
    print(f"AIC: {_fmt_stat(self.aic, '>15.4f')}")
    print(f"BIC: {_fmt_stat(self.bic, '>15.4f')}")
    print("-" * 80)
    # NOTE(review): CI headers are fixed at [0.025, 0.975]; this assumes
    # _conf_int holds 95% intervals — confirm against how it was computed.
    print(f"{'':<15} {'coef':>12} {'std err':>12} {stat_label:>10} {pval_label:>10} {'[0.025':>12} {'0.975]':>12}")
    print("-" * 80)

    # One row per parameter: estimate, std err, statistic, p-value, CI.
    for i, name in enumerate(feature_names):
        print(f"{name:<15} {self._params[i]:>12.4f} {self._bse[i]:>12.4f} "
              f"{self._tvalues[i]:>10.3f} {self._pvalues[i]:>10.4f} "
              f"{self._conf_int[i, 0]:>12.4f} {self._conf_int[i, 1]:>12.4f}")

    # Simultaneous-inference settings, when they were computed and stored.
    if self._simultaneous_enabled and self._conf_int_simultaneous is not None:
        target_txt = (
            "include_intercept=True"
            if (self.fit_intercept and self.simultaneous_include_intercept)
            else "include_intercept=False"
        )
        print("-" * 80)
        print("Simultaneous inference")
        print(f"method: {self._simultaneous_method}")
        print(f"alpha: {self._simultaneous_alpha:.6f}")
        print(f"n_bootstrap: {self._simultaneous_n_bootstrap}")
        print(f"critical value (max|Z|): {self._simultaneous_critical_value:.6f}")
        print(f"target set: {target_txt}")

    print("=" * 80)
|
|
2516
|
+
|
|
2517
|
+
def predict(self, X):
    """Predict using the Lasso model: ``X @ coef_ + intercept_``.

    The computation runs on the model's compute device. It returns a CuPy
    array for CUDA, a float64 torch tensor (on the input's device) for
    TORCH, and a NumPy array otherwise.
    """
    self._check_is_fitted()
    device = self._get_compute_device()

    if device == Device.CUDA:
        import cupy as cp

        design = cp.asarray(self._to_array(X, Device.CUDA))
        weights = cp.asarray(self.coef_)
        offset = cp.asarray(self.intercept_, dtype=weights.dtype)
        return design @ weights + offset

    if device == Device.TORCH:
        import torch

        design = self._to_array(X, Device.TORCH, backend="torch").to(torch.float64)
        placement = {"dtype": design.dtype, "device": design.device}
        weights = torch.as_tensor(self.coef_, **placement)
        offset = torch.as_tensor(self.intercept_, **placement)
        return design @ weights + offset

    design = np.asarray(self._to_array(X, Device.CPU))
    return design @ self.coef_ + self.intercept_
|
|
2540
|
+
|
|
2541
|
+
def score(self, X, y):
    """Return R^2 score of ``self.predict(X)`` against ``y``.

    Computes ``1 - SS_res / SS_tot`` on the model's compute device and
    returns a Python float. Returns 0.0 when ``y`` has zero total sum of
    squares.

    Both ``y`` and the predictions are flattened to 1-D first. Without this,
    a target passed as an (n, 1) column would broadcast against (n,)
    predictions into an (n, n) residual matrix and yield a wrong score.
    """
    y_pred = self.predict(X)
    device = self._get_compute_device()
    if device == Device.CUDA:
        import cupy as cp

        yb = cp.asarray(self._to_array(y, Device.CUDA)).reshape(-1)
        y_hat = y_pred.reshape(-1)
        ss_res = cp.sum((yb - y_hat) ** 2)
        ss_tot = cp.sum((yb - cp.mean(yb)) ** 2)
        return float((1 - ss_res / ss_tot).item()) if float(ss_tot.item()) > 0 else 0.0
    if device == Device.TORCH:
        import torch

        yb = self._to_array(y, Device.TORCH, backend="torch").to(y_pred.dtype).reshape(-1)
        y_hat = y_pred.reshape(-1)
        ss_res = torch.sum((yb - y_hat) ** 2)
        ss_tot = torch.sum((yb - torch.mean(yb)) ** 2)
        return float((1 - ss_res / ss_tot).item()) if float(ss_tot.item()) > 0 else 0.0
    y_hat = np.asarray(y_pred).reshape(-1)
    y_true = np.asarray(self._to_numpy(y)).reshape(-1)
    ss_res = np.sum((y_true - y_hat) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    # float() so the CPU path returns the same type as the device paths.
    return float(1 - ss_res / ss_tot) if ss_tot > 0 else 0.0
|
|
2564
|
+
|
|
2565
|
+
|
|
2566
|
+
def _lasso_alpha_heuristic(y_centered: np.ndarray, n_features: int) -> float:
|
|
2567
|
+
n_samples = int(y_centered.shape[0])
|
|
2568
|
+
if n_samples > 1:
|
|
2569
|
+
sigma_hat = float(np.std(y_centered, ddof=1))
|
|
2570
|
+
else:
|
|
2571
|
+
sigma_hat = float(np.std(y_centered))
|
|
2572
|
+
sigma_hat = max(sigma_hat, 1e-8)
|
|
2573
|
+
penalty_scale = np.sqrt(2.0 * np.log(max(2, int(n_features))) / max(1, n_samples))
|
|
2574
|
+
return float(sigma_hat * penalty_scale)
|
|
2575
|
+
|
|
2576
|
+
|
|
2577
|
+
def _default_lasso_alpha_grid(
    X: np.ndarray,
    y: np.ndarray,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Build a decreasing geometric grid of Lasso penalties (NumPy version).

    The top of the grid is the largest of ``max_j |X_j^T y| / n``, the
    ``_lasso_alpha_heuristic`` value and 1e-6. The bottom is
    ``max(alpha_min_ratio * top, 1e-6)``. When ``n_alphas <= 1``, only the
    top value is returned.

    Returns:
        1-D float64 array of ``n_alphas`` penalties in decreasing order.
    """
    n_obs = int(X.shape[0])
    scores = np.abs(X.T @ y) / float(max(1, n_obs))
    top = float(np.max(scores)) if scores.size else 1.0
    top = max(top, _lasso_alpha_heuristic(y, n_features=int(X.shape[1])), 1e-6)

    count = int(n_alphas)
    if count <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=count).astype(np.float64)
|
|
2594
|
+
|
|
2595
|
+
|
|
2596
|
+
def _default_lasso_alpha_grid_backend(
|
|
2597
|
+
X,
|
|
2598
|
+
y,
|
|
2599
|
+
backend,
|
|
2600
|
+
n_alphas: int = 12,
|
|
2601
|
+
alpha_min_ratio: float = 1e-3,
|
|
2602
|
+
) -> np.ndarray:
|
|
2603
|
+
"""Generate default alpha grid for Lasso using backend abstraction."""
|
|
2604
|
+
X_arr = backend.asarray(X, dtype=backend.float64)
|
|
2605
|
+
y_arr = backend.asarray(y, dtype=backend.float64).reshape(-1)
|
|
2606
|
+
|
|
2607
|
+
n_samples = int(X_arr.shape[0])
|
|
2608
|
+
corr = backend.abs(X_arr.T @ y_arr) / float(max(1, n_samples))
|
|
2609
|
+
# Use shape to check size - works for both numpy and torch
|
|
2610
|
+
corr_size = int(corr.shape[0]) if hasattr(corr, 'shape') else len(corr)
|
|
2611
|
+
alpha_max = float(backend.to_numpy(backend.max(corr))) if corr_size > 0 else 1.0
|
|
2612
|
+
|
|
2613
|
+
if n_samples > 1:
|
|
2614
|
+
y_std = backend.sqrt(backend.mean((y_arr - backend.mean(y_arr)) ** 2))
|
|
2615
|
+
sigma_hat = float(backend.to_numpy(y_std))
|
|
2616
|
+
else:
|
|
2617
|
+
sigma_hat = 0.0
|
|
2618
|
+
|
|
2619
|
+
sigma_hat = max(sigma_hat, 1e-8)
|
|
2620
|
+
penalty_scale = np.sqrt(2.0 * np.log(max(2, int(X_arr.shape[1]))) / max(1, n_samples))
|
|
2621
|
+
alpha_max = max(alpha_max, float(sigma_hat * penalty_scale), 1e-6)
|
|
2622
|
+
|
|
2623
|
+
if int(n_alphas) <= 1:
|
|
2624
|
+
return np.asarray([alpha_max], dtype=np.float64)
|
|
2625
|
+
|
|
2626
|
+
alpha_min = max(float(alpha_min_ratio) * alpha_max, 1e-6)
|
|
2627
|
+
return np.geomspace(alpha_max, alpha_min, num=int(n_alphas)).astype(np.float64)
|
|
2628
|
+
|
|
2629
|
+
|
|
2630
|
+
def _default_lasso_alpha_grid_cupy(
    X,
    y,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Build the default decreasing geometric Lasso alpha grid with CuPy.

    The top of the grid is the largest of ``max_j |X_j^T y| / n``,
    ``sigma_hat * sqrt(2 log(p) / n)`` (``sigma_hat`` is the std of y with
    ddof=1 when n > 1, floored at 1e-8) and 1e-6. The bottom is
    ``max(alpha_min_ratio * top, 1e-6)``. The reductions run on the GPU;
    the returned grid is a host float64 NumPy array.
    """
    import cupy as cp

    X_dev = cp.asarray(X, dtype=cp.float64)
    y_dev = cp.asarray(y, dtype=cp.float64).reshape(-1)
    n_obs = int(X_dev.shape[0])

    scores = cp.abs(X_dev.T @ y_dev) / float(max(1, n_obs))
    top = float(cp.max(scores).item()) if int(scores.size) > 0 else 1.0

    ddof = 1 if n_obs > 1 else 0
    noise = max(float(cp.std(y_dev, ddof=ddof).item()), 1e-8)
    scale = np.sqrt(2.0 * np.log(max(2, int(X_dev.shape[1]))) / max(1, n_obs))
    top = max(top, float(noise * scale), 1e-6)

    count = int(n_alphas)
    if count <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=count).astype(np.float64)
|
|
2659
|
+
|
|
2660
|
+
|
|
2661
|
+
def _kfold_indices(n_samples: int, n_splits: int, random_state: Optional[int]):
|
|
2662
|
+
n = int(n_samples)
|
|
2663
|
+
k = max(2, min(int(n_splits), n))
|
|
2664
|
+
|
|
2665
|
+
rng = np.random.default_rng(random_state)
|
|
2666
|
+
indices = rng.permutation(n)
|
|
2667
|
+
|
|
2668
|
+
fold_sizes = np.full(k, n // k, dtype=np.int64)
|
|
2669
|
+
fold_sizes[: n % k] += 1
|
|
2670
|
+
|
|
2671
|
+
folds = []
|
|
2672
|
+
current = 0
|
|
2673
|
+
for fold_size in fold_sizes:
|
|
2674
|
+
start, stop = current, current + int(fold_size)
|
|
2675
|
+
val_idx = indices[start:stop]
|
|
2676
|
+
train_idx = np.concatenate([indices[:start], indices[stop:]])
|
|
2677
|
+
current = stop
|
|
2678
|
+
if train_idx.size == 0 or val_idx.size == 0:
|
|
2679
|
+
continue
|
|
2680
|
+
folds.append((train_idx, val_idx))
|
|
2681
|
+
|
|
2682
|
+
if len(folds) == 0:
|
|
2683
|
+
all_idx = np.arange(n, dtype=np.int64)
|
|
2684
|
+
return [(all_idx, all_idx)]
|
|
2685
|
+
|
|
2686
|
+
return folds
|
|
2687
|
+
|
|
2688
|
+
|
|
2689
|
+
def _normalize_cv_splits(cv_splits, n_samples: int):
|
|
2690
|
+
if cv_splits is None:
|
|
2691
|
+
return None
|
|
2692
|
+
|
|
2693
|
+
n = int(n_samples)
|
|
2694
|
+
folds = []
|
|
2695
|
+
|
|
2696
|
+
for split in cv_splits:
|
|
2697
|
+
if not isinstance(split, (tuple, list)) or len(split) != 2:
|
|
2698
|
+
raise ValueError("Each cv_splits entry must be a (train_idx, val_idx) pair")
|
|
2699
|
+
|
|
2700
|
+
train_idx = np.asarray(split[0], dtype=np.int64).reshape(-1)
|
|
2701
|
+
val_idx = np.asarray(split[1], dtype=np.int64).reshape(-1)
|
|
2702
|
+
|
|
2703
|
+
if train_idx.size == 0 or val_idx.size == 0:
|
|
2704
|
+
continue
|
|
2705
|
+
|
|
2706
|
+
if (
|
|
2707
|
+
bool(np.any(train_idx < 0))
|
|
2708
|
+
or bool(np.any(train_idx >= n))
|
|
2709
|
+
or bool(np.any(val_idx < 0))
|
|
2710
|
+
or bool(np.any(val_idx >= n))
|
|
2711
|
+
):
|
|
2712
|
+
raise ValueError("cv_splits indices are out of range")
|
|
2713
|
+
|
|
2714
|
+
folds.append((train_idx, val_idx))
|
|
2715
|
+
|
|
2716
|
+
if len(folds) == 0:
|
|
2717
|
+
raise ValueError("cv_splits must contain at least one non-empty split")
|
|
2718
|
+
|
|
2719
|
+
return folds
|
|
2720
|
+
|
|
2721
|
+
|
|
2722
|
+
def _folds_are_complements(folds, n_samples: int) -> bool:
|
|
2723
|
+
"""Return True when each fold uses train as the exact complement of validation."""
|
|
2724
|
+
n = int(n_samples)
|
|
2725
|
+
for train_idx, val_idx in folds:
|
|
2726
|
+
train_arr = np.asarray(train_idx, dtype=np.int64).reshape(-1)
|
|
2727
|
+
val_arr = np.asarray(val_idx, dtype=np.int64).reshape(-1)
|
|
2728
|
+
|
|
2729
|
+
if int(train_arr.size + val_arr.size) != n:
|
|
2730
|
+
return False
|
|
2731
|
+
|
|
2732
|
+
mask = np.zeros((n,), dtype=np.int8)
|
|
2733
|
+
mask[train_arr] = 1
|
|
2734
|
+
if bool(np.any(mask[val_arr] != 0)):
|
|
2735
|
+
return False
|
|
2736
|
+
mask[val_arr] = 1
|
|
2737
|
+
if bool(np.any(mask == 0)):
|
|
2738
|
+
return False
|
|
2739
|
+
|
|
2740
|
+
return True
|
|
2741
|
+
|
|
2742
|
+
|
|
2743
|
+
def _array_identity_token(x: Any) -> Tuple[Any, ...]:
|
|
2744
|
+
if x is None:
|
|
2745
|
+
return ("none",)
|
|
2746
|
+
|
|
2747
|
+
try:
|
|
2748
|
+
import cupy as cp
|
|
2749
|
+
|
|
2750
|
+
if isinstance(x, cp.ndarray):
|
|
2751
|
+
return ("cupy", int(x.data.ptr), tuple(int(v) for v in x.shape), str(x.dtype))
|
|
2752
|
+
except Exception:
|
|
2753
|
+
pass
|
|
2754
|
+
|
|
2755
|
+
# Check for Torch tensors
|
|
2756
|
+
try:
|
|
2757
|
+
import torch
|
|
2758
|
+
|
|
2759
|
+
if isinstance(x, torch.Tensor):
|
|
2760
|
+
# For GPU tensors, use the data pointer; for CPU, use storage pointer
|
|
2761
|
+
if x.is_cuda:
|
|
2762
|
+
ptr = int(x.data_ptr())
|
|
2763
|
+
else:
|
|
2764
|
+
# CPU tensor - use underlying storage pointer
|
|
2765
|
+
ptr = int(x.untyped_storage().data_ptr()) if hasattr(x, 'untyped_storage') else id(x)
|
|
2766
|
+
return ("torch", ptr, tuple(int(v) for v in x.shape), str(x.dtype))
|
|
2767
|
+
except Exception:
|
|
2768
|
+
pass
|
|
2769
|
+
|
|
2770
|
+
arr = np.asarray(x)
|
|
2771
|
+
ptr = int(arr.__array_interface__["data"][0]) if int(arr.size) > 0 else 0
|
|
2772
|
+
return ("numpy", ptr, tuple(int(v) for v in arr.shape), str(arr.dtype))
|
|
2773
|
+
|
|
2774
|
+
|
|
2775
|
+
def _alphas_signature(alphas: np.ndarray) -> str:
|
|
2776
|
+
arr = np.ascontiguousarray(np.asarray(alphas, dtype=np.float64).reshape(-1))
|
|
2777
|
+
return hashlib.blake2b(arr.tobytes(), digest_size=16).hexdigest()
|
|
2778
|
+
|
|
2779
|
+
|
|
2780
|
+
def _folds_signature(folds) -> str:
|
|
2781
|
+
hasher = hashlib.blake2b(digest_size=16)
|
|
2782
|
+
for train_idx, val_idx in folds:
|
|
2783
|
+
train_arr = np.ascontiguousarray(np.asarray(train_idx, dtype=np.int64).reshape(-1))
|
|
2784
|
+
val_arr = np.ascontiguousarray(np.asarray(val_idx, dtype=np.int64).reshape(-1))
|
|
2785
|
+
hasher.update(train_arr.tobytes())
|
|
2786
|
+
hasher.update(b"|")
|
|
2787
|
+
hasher.update(val_arr.tobytes())
|
|
2788
|
+
hasher.update(b";")
|
|
2789
|
+
return hasher.hexdigest()
|
|
2790
|
+
|
|
2791
|
+
|
|
2792
|
+
def _make_lasso_cv_auto_cache_key(
    *,
    X,
    y,
    sample_weight,
    alpha_grid: np.ndarray,
    folds,
    fit_intercept: bool,
    use_gpu: bool,
    max_iter: int,
    tol: float,
    cpu_solver: str,
    cv_method: str,
    cd_kkt_check_every: Optional[int],
    gpu_cv_mixed_precision: bool,
) -> Tuple[Any, ...]:
    """Build the hashable cache key for an automatic Lasso-CV alpha search.

    The key combines a version tag, identity tokens for X, y and
    sample_weight, digests of the alpha grid and the folds, and the
    normalised solver settings (booleans, numbers, and lower-cased solver
    and method names).
    """
    data_tokens = (
        _array_identity_token(X),
        _array_identity_token(y),
        _array_identity_token(sample_weight),
    )
    search_space = (
        _alphas_signature(alpha_grid),
        _folds_signature(folds),
    )
    solver_settings = (
        bool(fit_intercept),
        bool(use_gpu),
        int(max_iter),
        float(tol),
        str(cpu_solver).lower(),
        str(cv_method).lower(),
        None if cd_kkt_check_every is None else int(cd_kkt_check_every),
        bool(gpu_cv_mixed_precision),
    )
    return ("lasso_cv_auto_v1",) + data_tokens + search_space + solver_settings
|
|
2824
|
+
|
|
2825
|
+
|
|
2826
|
+
def _clone_lasso_cv_cache_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
2827
|
+
return {
|
|
2828
|
+
"alpha": float(payload["alpha"]),
|
|
2829
|
+
"alphas": np.asarray(payload["alphas"], dtype=np.float64).copy(),
|
|
2830
|
+
"mse_path": np.asarray(payload["mse_path"], dtype=np.float64).copy(),
|
|
2831
|
+
"mean_mse": np.asarray(payload["mean_mse"], dtype=np.float64).copy(),
|
|
2832
|
+
}
|
|
2833
|
+
|
|
2834
|
+
|
|
2835
|
+
def _lasso_cv_cache_get(cache_key: Optional[Tuple[Any, ...]]) -> Optional[Dict[str, Any]]:
    """Look up a cached LassoCV result and return a private copy of it.

    Returns ``None`` when the key is ``None``, when the cache is disabled
    (``_LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0``), or on a miss. A hit is moved to
    the most-recently-used end of ``_LASSO_CV_ALPHA_CACHE``.
    """
    caching_off = cache_key is None or _LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0
    if caching_off:
        return None

    entry = _LASSO_CV_ALPHA_CACHE.get(cache_key)
    if entry is None:
        return None

    # Refresh recency so LRU eviction in the put path spares this entry.
    _LASSO_CV_ALPHA_CACHE.move_to_end(cache_key)
    return _clone_lasso_cv_cache_payload(entry)
|
|
2845
|
+
|
|
2846
|
+
|
|
2847
|
+
def _lasso_cv_cache_put(cache_key: Optional[Tuple[Any, ...]], payload: Dict[str, Any]) -> None:
    """Store a copy of ``payload`` under ``cache_key`` with LRU eviction.

    Does nothing when the key is ``None`` or the cache is disabled
    (``_LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0``). After inserting, the oldest
    entries are dropped until the cache holds at most
    ``_LASSO_CV_ALPHA_CACHE_MAXSIZE`` items.
    """
    if cache_key is None or _LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0:
        return

    store = _LASSO_CV_ALPHA_CACHE
    store[cache_key] = _clone_lasso_cv_cache_payload(payload)
    store.move_to_end(cache_key)

    capacity = int(_LASSO_CV_ALPHA_CACHE_MAXSIZE)
    while len(store) > capacity:
        # last=False pops from the least-recently-used end.
        store.popitem(last=False)
|
|
2856
|
+
|
|
2857
|
+
|
|
2858
|
+
def _adaptive_gpu_check_every(
|
|
2859
|
+
*,
|
|
2860
|
+
base_check_every: int,
|
|
2861
|
+
iteration: int,
|
|
2862
|
+
max_iter: int,
|
|
2863
|
+
active_ratio: float,
|
|
2864
|
+
) -> int:
|
|
2865
|
+
"""Adaptive cadence for expensive global convergence checks on GPU."""
|
|
2866
|
+
base = max(1, int(base_check_every))
|
|
2867
|
+
ratio = float(max(0.0, min(1.0, active_ratio)))
|
|
2868
|
+
|
|
2869
|
+
if ratio >= 0.75:
|
|
2870
|
+
interval = max(base, 16)
|
|
2871
|
+
elif ratio >= 0.40:
|
|
2872
|
+
interval = max(base, 12)
|
|
2873
|
+
elif ratio >= 0.15:
|
|
2874
|
+
interval = max(4, base)
|
|
2875
|
+
else:
|
|
2876
|
+
interval = max(2, base // 2)
|
|
2877
|
+
|
|
2878
|
+
progress = float(iteration + 1) / float(max(1, int(max_iter)))
|
|
2879
|
+
if progress >= 0.90:
|
|
2880
|
+
interval = min(interval, 2)
|
|
2881
|
+
elif progress >= 0.75:
|
|
2882
|
+
interval = min(interval, 4)
|
|
2883
|
+
|
|
2884
|
+
return max(1, int(interval))
|
|
2885
|
+
|
|
2886
|
+
|
|
2887
|
+
def _soft_threshold_numpy(x: np.ndarray, gamma: float) -> np.ndarray:
|
|
2888
|
+
gamma_arr = np.asarray(gamma, dtype=np.float64)
|
|
2889
|
+
return np.sign(x) * np.maximum(np.abs(x) - gamma_arr, 0.0)
|
|
2890
|
+
|
|
2891
|
+
|
|
2892
|
+
def _soft_threshold_scalar(x: float, gamma: float) -> float:
|
|
2893
|
+
ax = abs(float(x))
|
|
2894
|
+
g = float(gamma)
|
|
2895
|
+
if ax <= g:
|
|
2896
|
+
return 0.0
|
|
2897
|
+
return float(np.sign(x) * (ax - g))
|
|
2898
|
+
|
|
2899
|
+
|
|
2900
|
+
if _NUMBA_AVAILABLE:
    # Numba-compiled kernels. They are only defined when `_NUMBA_AVAILABLE` is
    # true, so callers must check the same flag before using them.

    # Scalar soft-threshold for use inside nopython code.
    # Returns 0.0 when |x| <= gamma. Otherwise returns |x| - gamma with the
    # sign of x; x == 0.0 takes the non-negative branch.
    @njit(cache=True)
    def _soft_threshold_scalar_numba(x: float, gamma: float) -> float:
        ax = abs(x)
        if ax <= gamma:
            return 0.0
        if x >= 0.0:
            return ax - gamma
        return -(ax - gamma)


    # Coordinate-descent Lasso path computed from Gram quantities only.
    #
    # For each alpha in `alphas_desc` (in order), with warm starts carried over
    # from the previous alpha, it minimises
    #     0.5 * b' XtX b - b' Xty + (alpha * n_samples) * ||b||_1
    # by cyclic coordinate updates on an active set. `grad` is kept equal to
    # XtX @ coef - Xty incrementally after every coordinate change.
    #
    # Parameters:
    #   XtX, Xty            : Gram matrix and cross-product. The Python wrapper
    #                         `_solve_lasso_path_cpu_cd_numba` passes them as
    #                         C-contiguous float64.
    #   n_samples           : scales alpha and the KKT gradient (floored at 1).
    #   alphas_desc         : alpha grid. The strong-rule threshold assumes it
    #                         is descending.
    #   max_iter, tol       : per-alpha pass cap and convergence tolerance.
    #   stopping_is_kkt     : True  -> stop once a KKT scan finds violation < tol.
    #                         False -> stop once the L1 coefficient change < tol
    #                                  and no inactive feature was just added.
    #   cd_kkt_check_every  : run a full KKT scan every this many passes. A scan
    #                         is also forced when the change is small and on
    #                         the last pass.
    # Returns:
    #   coefs_path : (n_alphas, n_features) float64 coefficients per alpha.
    #   n_iters    : (n_alphas,) int32 passes used per alpha (max_iter if the
    #                stopping test never passed).
    @njit(cache=True)
    def _solve_lasso_path_cpu_cd_numba_impl(
        XtX: np.ndarray,
        Xty: np.ndarray,
        n_samples: int,
        alphas_desc: np.ndarray,
        max_iter: int,
        tol: float,
        stopping_is_kkt: bool,
        cd_kkt_check_every: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        n_features = XtX.shape[0]
        n_alphas = alphas_desc.shape[0]

        coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
        n_iters = np.zeros((n_alphas,), dtype=np.int32)

        # Path state shared by all alphas (warm start).
        coef = np.zeros((n_features,), dtype=np.float64)
        # Gradient XtX @ coef - Xty evaluated at coef == 0.
        grad = -Xty.copy()

        X_sq_norms = np.empty((n_features,), dtype=np.float64)
        for j in range(n_features):
            X_sq_norms[j] = XtX[j, j]

        n_samp = float(max(1, n_samples))
        # Penalty weight in the unscaled objective: alpha * n_samples.
        alpha_scaled_desc = np.empty((n_alphas,), dtype=np.float64)
        for idx in range(n_alphas):
            alpha_scaled_desc[idx] = alphas_desc[idx] * n_samp

        # Active set only grows along the path (never cleared between alphas).
        active_mask = np.zeros((n_features,), dtype=np.bool_)
        check_every = max(1, int(cd_kkt_check_every))

        for alpha_idx in range(n_alphas):
            alpha = float(alphas_desc[alpha_idx])
            alpha_scaled = float(alpha_scaled_desc[alpha_idx])
            if alpha_idx > 0:
                prev_alpha_scaled = float(alpha_scaled_desc[alpha_idx - 1])
            else:
                prev_alpha_scaled = alpha_scaled

            # Strong-rule style threshold 2*alpha_k - alpha_{k-1}. Note it is
            # compared against |Xty| (the gradient at coef == 0), not the
            # current gradient. Missed features are picked up by the KKT scan.
            strong_thresh = 2.0 * alpha_scaled - prev_alpha_scaled
            if strong_thresh < 0.0:
                strong_thresh = 0.0

            any_active = False
            max_abs_xty = -1.0
            max_abs_xty_idx = 0
            for j in range(n_features):
                abs_xty = abs(Xty[j])
                if abs_xty >= strong_thresh:
                    active_mask[j] = True
                    any_active = True
                if abs_xty > max_abs_xty:
                    max_abs_xty = abs_xty
                    max_abs_xty_idx = j

            # `any_active` reflects only this alpha's screen (not the mask
            # accumulated from earlier alphas).
            # NOTE(review): when n_features == 0 this writes index 0 of an
            # empty array; numba does not bounds-check by default. Callers
            # presumably never pass an empty Gram; confirm.
            if not any_active:
                active_mask[max_abs_xty_idx] = True

            converged = False

            for iteration in range(int(max_iter)):
                coef_delta_l1 = 0.0

                # One cyclic pass over the active coordinates.
                for j in range(n_features):
                    if not active_mask[j]:
                        continue

                    denom = float(X_sq_norms[j])
                    old_val = float(coef[j])

                    # Near-zero diagonal (constant / empty column): pin to 0.
                    if denom > 1e-10:
                        rho_j = -float(grad[j]) + denom * old_val
                        new_val = _soft_threshold_scalar_numba(rho_j, alpha_scaled) / denom
                    else:
                        new_val = 0.0

                    delta = new_val - old_val
                    if delta != 0.0:
                        coef[j] = new_val
                        coef_delta_l1 += abs(delta)
                        # Rank-1 gradient update: grad += XtX[:, j] * delta.
                        for row_idx in range(n_features):
                            grad[row_idx] += XtX[row_idx, j] * delta

                should_kkt_scan = (
                    ((iteration + 1) % check_every == 0)
                    or (coef_delta_l1 < float(tol))
                    or (iteration + 1 == int(max_iter))
                )

                violation = 0.0
                has_inactive_violation = False

                if should_kkt_scan:
                    # Violation measures only how far |grad_j / n| exceeds
                    # alpha. It does not check that nonzero coefficients satisfy
                    # grad_j / n == -alpha * sign(coef_j).
                    for j in range(n_features):
                        v = abs(grad[j] / n_samp) - alpha
                        if v < 0.0:
                            v = 0.0
                        if v > violation:
                            violation = v
                        if v > float(tol) and (not active_mask[j]):
                            active_mask[j] = True
                            has_inactive_violation = True

                if stopping_is_kkt:
                    if should_kkt_scan and violation < float(tol):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break
                else:
                    if coef_delta_l1 < float(tol) and (not has_inactive_violation):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break

            if not converged:
                n_iters[alpha_idx] = int(max_iter)

            # Record this alpha's solution and keep its support active for the
            # next (smaller) alpha.
            for j in range(n_features):
                coefs_path[alpha_idx, j] = coef[j]
                if abs(coef[j]) > 0.0:
                    active_mask[j] = True

        return coefs_path, n_iters
|
|
3036
|
+
|
|
3037
|
+
|
|
3038
|
+
def _solve_lasso_path_cpu_cd_numba(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cd_kkt_check_every: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Python-side adapter for the numba coordinate-descent path kernel.

    Coerces the inputs to C-contiguous float64 arrays and plain Python
    scalars (as the compiled kernel expects). ``stopping`` is turned into a
    boolean ("kkt", case-insensitive, means KKT-based stopping).

    Returns ``(coefs_path, n_iters)`` exactly as produced by
    ``_solve_lasso_path_cpu_cd_numba_impl``.
    """
    gram = np.ascontiguousarray(XtX, dtype=np.float64)
    rhs = np.ascontiguousarray(Xty, dtype=np.float64)
    alpha_grid = np.ascontiguousarray(np.asarray(alphas_desc, dtype=np.float64))
    use_kkt = str(stopping).lower() == "kkt"

    kernel_args = (
        gram,
        rhs,
        int(n_samples),
        alpha_grid,
        int(max_iter),
        float(tol),
        bool(use_kkt),
        int(cd_kkt_check_every),
    )
    return _solve_lasso_path_cpu_cd_numba_impl(*kernel_args)
|
|
3063
|
+
|
|
3064
|
+
|
|
3065
|
+
def _normalize_lassocv_method(method: str) -> str:
|
|
3066
|
+
"""Normalize CV optimization profile name."""
|
|
3067
|
+
key = str(method).strip().lower()
|
|
3068
|
+
alias_map = {
|
|
3069
|
+
"default": "standard",
|
|
3070
|
+
"classic": "standard",
|
|
3071
|
+
"glmnet_cv": "glmnet",
|
|
3072
|
+
"glmnet.cv": "glmnet",
|
|
3073
|
+
}
|
|
3074
|
+
key = alias_map.get(key, key)
|
|
3075
|
+
if key not in ("standard", "glmnet"):
|
|
3076
|
+
raise ValueError("method must be one of: 'standard', 'glmnet'")
|
|
3077
|
+
return key
|
|
3078
|
+
|
|
3079
|
+
|
|
3080
|
+
def _normalize_cd_kkt_check_every(cd_kkt_check_every: Optional[int]) -> Optional[int]:
|
|
3081
|
+
"""Validate optional coordinate-descent global KKT scan cadence."""
|
|
3082
|
+
if cd_kkt_check_every is None:
|
|
3083
|
+
return None
|
|
3084
|
+
value = int(cd_kkt_check_every)
|
|
3085
|
+
if value <= 0:
|
|
3086
|
+
raise ValueError("cd_kkt_check_every must be a positive integer or None")
|
|
3087
|
+
return value
|
|
3088
|
+
|
|
3089
|
+
|
|
3090
|
+
def _solve_lasso_path_cpu_fista_batched_from_gram(
|
|
3091
|
+
XtX: np.ndarray,
|
|
3092
|
+
Xty: np.ndarray,
|
|
3093
|
+
*,
|
|
3094
|
+
n_samples: int,
|
|
3095
|
+
alphas_desc: np.ndarray,
|
|
3096
|
+
max_iter: int,
|
|
3097
|
+
tol: float,
|
|
3098
|
+
stopping: str,
|
|
3099
|
+
lipschitz_L: Optional[float] = None,
|
|
3100
|
+
check_every: int = 2,
|
|
3101
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
3102
|
+
"""Solve descending-alpha Lasso path with a batched CPU FISTA update."""
|
|
3103
|
+
n_features = int(XtX.shape[0])
|
|
3104
|
+
n_alphas = int(alphas_desc.shape[0])
|
|
3105
|
+
|
|
3106
|
+
coefs = np.zeros((n_features, n_alphas), dtype=np.float64)
|
|
3107
|
+
yk = coefs.copy()
|
|
3108
|
+
tk = np.ones((n_alphas,), dtype=np.float64)
|
|
3109
|
+
n_iters = np.zeros((n_alphas,), dtype=np.int32)
|
|
3110
|
+
|
|
3111
|
+
if lipschitz_L is not None:
|
|
3112
|
+
L = float(lipschitz_L)
|
|
3113
|
+
else:
|
|
3114
|
+
try:
|
|
3115
|
+
eigvals = np.linalg.eigvalsh(XtX)
|
|
3116
|
+
L = float(eigvals[-1] / float(max(1, n_samples)))
|
|
3117
|
+
except Exception:
|
|
3118
|
+
row_sum_bound = float(np.max(np.sum(np.abs(XtX), axis=1)) / float(max(1, n_samples)))
|
|
3119
|
+
L = max(row_sum_bound, 1e-12)
|
|
3120
|
+
|
|
3121
|
+
if L <= 0.0:
|
|
3122
|
+
return coefs.T, n_iters
|
|
3123
|
+
|
|
3124
|
+
n_samp = float(max(1, n_samples))
|
|
3125
|
+
step = 1.0 / L
|
|
3126
|
+
alphas_desc = np.asarray(alphas_desc, dtype=np.float64)
|
|
3127
|
+
thresholds = alphas_desc * step
|
|
3128
|
+
stopping_name = str(stopping).lower()
|
|
3129
|
+
check_every = max(1, int(check_every))
|
|
3130
|
+
|
|
3131
|
+
active = np.arange(n_alphas, dtype=np.int64)
|
|
3132
|
+
|
|
3133
|
+
for iteration in range(int(max_iter)):
|
|
3134
|
+
if active.size == 0:
|
|
3135
|
+
break
|
|
3136
|
+
|
|
3137
|
+
y_active = yk[:, active]
|
|
3138
|
+
coef_old = coefs[:, active]
|
|
3139
|
+
|
|
3140
|
+
grad = (XtX @ y_active - Xty.reshape(-1, 1)) / n_samp
|
|
3141
|
+
thresh = thresholds[active].reshape(1, -1)
|
|
3142
|
+
coef_new = _soft_threshold_numpy(y_active - step * grad, thresh)
|
|
3143
|
+
|
|
3144
|
+
t_old = tk[active]
|
|
3145
|
+
t_new = (1.0 + np.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
|
|
3146
|
+
beta = (t_old - 1.0) / t_new
|
|
3147
|
+
y_new = coef_new + beta.reshape(1, -1) * (coef_new - coef_old)
|
|
3148
|
+
|
|
3149
|
+
coefs[:, active] = coef_new
|
|
3150
|
+
yk[:, active] = y_new
|
|
3151
|
+
tk[active] = t_new
|
|
3152
|
+
|
|
3153
|
+
should_check = ((iteration + 1) % check_every == 0) or (iteration + 1 == int(max_iter))
|
|
3154
|
+
if not should_check:
|
|
3155
|
+
continue
|
|
3156
|
+
|
|
3157
|
+
if stopping_name == "kkt":
|
|
3158
|
+
grad_sse = (XtX @ coef_new - Xty.reshape(-1, 1)) / n_samp
|
|
3159
|
+
viol = np.max(
|
|
3160
|
+
np.maximum(
|
|
3161
|
+
np.abs(grad_sse) - alphas_desc[active].reshape(1, -1),
|
|
3162
|
+
0.0,
|
|
3163
|
+
),
|
|
3164
|
+
axis=0,
|
|
3165
|
+
)
|
|
3166
|
+
converged_local = viol < float(tol)
|
|
3167
|
+
else:
|
|
3168
|
+
delta = np.sum(np.abs(coef_new - coef_old), axis=0)
|
|
3169
|
+
converged_local = delta < float(tol)
|
|
3170
|
+
|
|
3171
|
+
if not np.any(converged_local):
|
|
3172
|
+
continue
|
|
3173
|
+
|
|
3174
|
+
done = active[converged_local]
|
|
3175
|
+
n_iters[done] = int(iteration) + 1
|
|
3176
|
+
yk[:, done] = coefs[:, done]
|
|
3177
|
+
active = active[~converged_local]
|
|
3178
|
+
|
|
3179
|
+
if active.size > 0:
|
|
3180
|
+
n_iters[active] = int(max_iter)
|
|
3181
|
+
|
|
3182
|
+
return coefs.T, n_iters
|
|
3183
|
+
|
|
3184
|
+
|
|
3185
|
+
def _solve_lasso_path_gpu_fista_batched_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso path with a batched GPU FISTA update.

    GPU (CuPy) counterpart of the batched CPU FISTA path solver. ``XtX`` and
    ``Xty`` are CuPy arrays, and all work uses ``XtX.dtype``. Convergence checks
    use an adaptive cadence from ``_adaptive_gpu_check_every``.

    Parameters
    ----------
    XtX, Xty : cupy.ndarray
        Gram matrix and cross-product on the device.
    n_samples : int
        Sample count used to scale the gradient (floored at 1).
    alphas_desc : array-like
        Alpha grid. A list, NumPy array, or CuPy array is accepted.
    max_iter, tol, stopping, lipschitz_L, check_every
        As in the CPU batched FISTA solver. ``check_every`` is the base
        cadence passed to ``_adaptive_gpu_check_every``.

    Returns
    -------
    coefs_path : cupy.ndarray, shape (n_alphas, n_features)
    n_iters : numpy.ndarray of int32, shape (n_alphas,)
    """
    import cupy as cp

    dtype = XtX.dtype
    # cp.asarray accepts lists, NumPy and CuPy arrays alike. np.asarray
    # would reject CuPy input, since implicit device->host conversion is
    # disallowed.
    alpha_gpu = cp.asarray(alphas_desc, dtype=dtype).reshape(-1)

    n_features = int(XtX.shape[0])
    n_alphas = int(alpha_gpu.shape[0])

    coefs = cp.zeros((n_features, n_alphas), dtype=dtype)
    yk = coefs.copy()
    tk = cp.ones((n_alphas,), dtype=dtype)
    n_iters_gpu = cp.zeros((n_alphas,), dtype=cp.int32)

    n_samp = float(max(1, n_samples))
    if lipschitz_L is not None:
        L = cp.array(float(lipschitz_L), dtype=dtype)
    else:
        try:
            eigvals = cp.linalg.eigvalsh(XtX)
            L = eigvals[-1] / n_samp
        except Exception:
            row_sum_bound = cp.max(cp.sum(cp.abs(XtX), axis=1)) / n_samp
            L = cp.maximum(row_sum_bound, cp.asarray(1e-12, dtype=dtype))

    # One host sync to decide whether the problem is degenerate.
    if float(cp.asnumpy(L)) <= 0.0:
        return coefs.T, np.zeros((n_alphas,), dtype=np.int32)

    step = 1.0 / L
    thresholds = alpha_gpu * step
    stopping_name = str(stopping).lower()
    check_every = max(1, int(check_every))
    tol_f = float(tol)
    max_iter_i = int(max_iter)
    xty_col = Xty.reshape(-1, 1)

    active_gpu = cp.arange(n_alphas, dtype=cp.int32)

    for iteration in range(max_iter_i):
        if int(active_gpu.size) == 0:
            break

        y_active = yk[:, active_gpu]
        coef_old = coefs[:, active_gpu]

        # Proximal gradient step; the gradient step is computed once and reused.
        grad = (XtX @ y_active - xty_col) / n_samp
        y_step = y_active - step * grad
        thresh = thresholds[active_gpu].reshape(1, -1)
        coef_new = cp.sign(y_step) * cp.maximum(cp.abs(y_step) - thresh, 0.0)

        # Nesterov momentum update.
        t_old = tk[active_gpu]
        t_new = (1.0 + cp.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
        beta = (t_old - 1.0) / t_new
        y_new = coef_new + beta.reshape(1, -1) * (coef_new - coef_old)

        coefs[:, active_gpu] = coef_new
        yk[:, active_gpu] = y_new
        tk[active_gpu] = t_new

        active_ratio = float(int(active_gpu.size)) / float(max(1, n_alphas))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=max_iter_i,
            active_ratio=active_ratio,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == max_iter_i)
        if not should_check:
            continue

        if stopping_name == "kkt":
            grad_sse = (XtX @ coef_new - xty_col) / n_samp
            viol = cp.max(
                cp.maximum(cp.abs(grad_sse) - alpha_gpu[active_gpu].reshape(1, -1), 0.0),
                axis=0,
            )
            converged_local_gpu = viol < tol_f
        else:
            delta = cp.sum(cp.abs(coef_new - coef_old), axis=0)
            converged_local_gpu = delta < tol_f

        done_gpu = active_gpu[converged_local_gpu]
        if int(done_gpu.size) == 0:
            continue

        n_iters_gpu[done_gpu] = iteration + 1
        # Freeze finished columns: momentum point equals the final iterate.
        yk[:, done_gpu] = coefs[:, done_gpu]
        active_gpu = active_gpu[~converged_local_gpu]

    if int(active_gpu.size) > 0:
        n_iters_gpu[active_gpu] = max_iter_i

    return coefs.T, cp.asnumpy(n_iters_gpu)
|
|
3289
|
+
|
|
3290
|
+
|
|
3291
|
+
def _solve_lasso_path_gpu_fista_multi_fold_from_gram(
    XtX_batch,
    Xty_batch,
    *,
    n_samples_vec,
    alphas_desc,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso paths for all folds together on GPU.

    Every (fold, alpha) pair is a separate FISTA problem advanced in one
    batched update. The state has shape ``(n_folds, n_features, n_alphas)``.
    Pairs that pass the convergence test are frozen via a boolean mask.

    Parameters
    ----------
    XtX_batch : cupy.ndarray, shape (n_folds, n_features, n_features)
    Xty_batch : cupy.ndarray, shape (n_folds, n_features)
    n_samples_vec : array-like (NumPy or CuPy), one entry per fold
    alphas_desc : array-like (list, NumPy or CuPy) alpha grid
    max_iter, tol, stopping, lipschitz_L, check_every
        As in the single-fold GPU solver. The per-fold Lipschitz constant
        is floored at 1e-12 so that an all-zero fold cannot produce NaNs.

    Returns
    -------
    coefs_path : cupy.ndarray, shape (n_folds, n_alphas, n_features)
    n_iters : numpy.ndarray of int32, shape (n_folds, n_alphas)

    Raises
    ------
    ValueError
        If ``n_samples_vec`` does not have exactly one entry per fold.

    Note: Fused kernel optimization is disabled for multi-fold solver due to
    dtype complexity. The single-fold Lasso solver uses fused kernels.
    """
    import cupy as cp

    dtype = XtX_batch.dtype
    n_folds = int(XtX_batch.shape[0])
    n_features = int(XtX_batch.shape[1])

    # Convert alphas first (CuPy via .get(), otherwise np.asarray) so that
    # list/tuple grids work before .shape is read.
    if hasattr(alphas_desc, 'get'):
        alphas_cpu = alphas_desc.get().astype(np.float64).reshape(-1)
    else:
        alphas_cpu = np.asarray(alphas_desc, dtype=np.float64).reshape(-1)
    n_alphas = int(alphas_cpu.shape[0])

    coefs = cp.zeros((n_folds, n_features, n_alphas), dtype=dtype)
    yk = coefs.copy()
    tk = cp.ones((n_folds, n_alphas), dtype=dtype)
    n_iters_gpu = cp.zeros((n_folds, n_alphas), dtype=cp.int32)

    # Convert n_samples_vec to numpy using .get() if it's a CuPy array
    if hasattr(n_samples_vec, 'get'):
        n_vec_cpu = n_samples_vec.get().astype(np.float64).reshape(-1)
    else:
        n_vec_cpu = np.asarray(n_samples_vec, dtype=np.float64).reshape(-1)
    if n_vec_cpu.size != n_folds:
        raise ValueError("n_samples_vec must have one entry per fold")
    n_vec = cp.asarray(n_vec_cpu, dtype=dtype)

    if lipschitz_L is not None:
        L = cp.full((n_folds,), float(lipschitz_L), dtype=dtype)
    else:
        try:
            eigvals = cp.linalg.eigvalsh(XtX_batch)
            L = eigvals[:, -1] / n_vec
        except Exception:
            L = cp.max(cp.sum(cp.abs(XtX_batch), axis=2), axis=1) / n_vec
    # Floor L for every branch. An all-zero fold gives L == 0, so step = inf
    # and inf * 0 = NaN; with the floor its coefficients correctly stay at 0.
    L = cp.maximum(L, cp.asarray(1e-12, dtype=dtype))

    step = 1.0 / L.reshape(n_folds, 1, 1)
    alpha_gpu = cp.asarray(alphas_cpu, dtype=dtype).reshape(1, 1, n_alphas)
    thresholds = alpha_gpu * step

    Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
    n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
    stopping_name = str(stopping).lower()
    check_every = max(1, int(check_every))
    tol_f = float(tol)
    max_iter_i = int(max_iter)

    active_gpu = cp.ones((n_folds, n_alphas), dtype=cp.bool_)
    active_count = int(n_folds * n_alphas)
    total_problems = float(max(1, n_folds * n_alphas))

    for iteration in range(max_iter_i):
        if active_count == 0:
            break

        active_expanded = active_gpu[:, cp.newaxis, :]

        # `coefs` is only ever rebound (cp.where returns a new array), never
        # written in place, so holding a reference is enough here.
        coef_old = coefs
        grad = (cp.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded

        # Proximal step: soft thresholding
        yk_step = yk - step * grad
        coef_candidate = cp.sign(yk_step) * cp.maximum(cp.abs(yk_step) - thresholds, 0.0)
        coefs = cp.where(active_expanded, coef_candidate, coefs)

        # Nesterov momentum update (only active pairs advance).
        t_new = (1.0 + cp.sqrt(1.0 + 4.0 * (tk ** 2))) / 2.0
        beta = (tk - 1.0) / t_new
        y_candidate = coefs + beta[:, cp.newaxis, :] * (coefs - coef_old)
        yk = cp.where(active_expanded, y_candidate, yk)
        tk = cp.where(active_gpu, t_new, tk)

        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=max_iter_i,
            active_ratio=float(active_count) / total_problems,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == max_iter_i)
        if not should_check:
            continue

        if stopping_name == "kkt":
            grad_sse = (cp.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
            violation = cp.max(cp.maximum(cp.abs(grad_sse) - alpha_gpu, 0.0), axis=1)
            converged_local_gpu = violation < tol_f
        else:
            delta = cp.sum(cp.abs(coefs - coef_old), axis=1)
            converged_local_gpu = delta < tol_f

        newly_done_gpu = active_gpu & converged_local_gpu
        done_count = int(cp.count_nonzero(newly_done_gpu).item())
        if done_count == 0:
            continue

        n_iters_gpu[newly_done_gpu] = iteration + 1
        # Freeze finished pairs: momentum point equals the final iterate.
        yk = cp.where(newly_done_gpu[:, cp.newaxis, :], coefs, yk)
        active_gpu = active_gpu & (~converged_local_gpu)
        active_count -= done_count

    n_iters_gpu[active_gpu] = max_iter_i

    return cp.transpose(coefs, (0, 2, 1)), cp.asnumpy(n_iters_gpu)
|
|
3413
|
+
|
|
3414
|
+
|
|
3415
|
+
def _solve_lasso_path_cpu_from_gram(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cpu_solver: str,
    lipschitz_L: Optional[float] = None,
    cd_kkt_check_every: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """Solve a descending-alpha Lasso path on CPU using one precomputed Gram matrix.

    Dispatch on ``cpu_solver`` (case-insensitive):

    * ``"fista"``: batched FISTA (``_solve_lasso_path_cpu_fista_batched_from_gram``).
    * ``"coordinate_descent"``: the numba kernel when available and not
      previously disabled. If it raises, a RuntimeWarning is emitted,
      ``_NUMBA_CD_DISABLED`` is set, and the pure-Python coordinate descent
      below is used instead.
    * anything else: pure-Python coordinate descent.

    The coordinate descent minimises
    ``0.5 * b' XtX b - b' Xty + alpha * n_samples * ||b||_1`` per alpha, with
    warm starts, strong-rule screening on ``|Xty|``, and periodic KKT scans
    every ``cd_kkt_check_every`` passes.

    Returns
    -------
    coefs_path : ndarray, shape (n_alphas, n_features)
    n_iters : ndarray of int32, shape (n_alphas,)
        Passes used per alpha (``max_iter`` when not converged).
    """
    global _NUMBA_CD_DISABLED

    # Convert before reading .shape so list/tuple grids are accepted.
    alphas_desc = np.asarray(alphas_desc, dtype=np.float64).reshape(-1)
    solver_name = str(cpu_solver).lower()

    if solver_name == "fista":
        return _solve_lasso_path_cpu_fista_batched_from_gram(
            XtX,
            Xty,
            n_samples=n_samples,
            alphas_desc=alphas_desc,
            max_iter=max_iter,
            tol=tol,
            stopping=stopping,
            lipschitz_L=lipschitz_L,
            check_every=2,
        )

    use_numba_cd = (
        _NUMBA_AVAILABLE
        and (not _NUMBA_CD_DISABLED)
        and solver_name == "coordinate_descent"
    )
    if use_numba_cd:
        try:
            return _solve_lasso_path_cpu_cd_numba(
                XtX,
                Xty,
                n_samples=n_samples,
                alphas_desc=alphas_desc,
                max_iter=max_iter,
                tol=tol,
                stopping=stopping,
                cd_kkt_check_every=cd_kkt_check_every,
            )
        except Exception as exc:
            # Best-effort acceleration: disable the numba path for later calls
            # and fall through to pure Python, but do not hide the failure.
            _NUMBA_CD_DISABLED = True
            import warnings

            warnings.warn(
                f"Numba coordinate-descent Lasso solver failed ({exc!r}); "
                "falling back to the pure-Python coordinate-descent solver.",
                RuntimeWarning,
                stacklevel=2,
            )

    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])
    stopping_name = str(stopping).lower()
    tol_f = float(tol)
    max_iter_i = int(max_iter)
    n_samp = float(max(1, n_samples))
    check_every = max(1, int(cd_kkt_check_every))

    coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
    n_iters = np.zeros(n_alphas, dtype=np.int32)

    # Coordinate descent with incremental gradient updates.
    # `grad` is maintained as XtX @ coef - Xty throughout.
    coef = np.zeros(n_features, dtype=np.float64)
    X_sq_norms = np.diag(XtX).astype(np.float64, copy=False)
    grad = XtX @ coef - Xty
    alpha_scaled_desc = alphas_desc * n_samp
    abs_xty = np.abs(Xty)  # Xty is never modified below
    active_mask = np.zeros((n_features,), dtype=bool)

    for alpha_idx in range(n_alphas):
        alpha = float(alphas_desc[alpha_idx])
        alpha_scaled = float(alpha_scaled_desc[alpha_idx])
        prev_alpha_scaled = float(alpha_scaled_desc[alpha_idx - 1]) if alpha_idx > 0 else alpha_scaled

        # Strong rule screening: expand active set before cyclic updates.
        strong_thresh = max(0.0, 2.0 * alpha_scaled - prev_alpha_scaled)
        active_mask |= abs_xty >= strong_thresh
        if not bool(np.any(active_mask)):
            active_mask[int(np.argmax(abs_xty))] = True

        converged = False

        for iteration in range(max_iter_i):
            coef_delta_l1 = 0.0

            for j in np.flatnonzero(active_mask):
                denom = float(X_sq_norms[j])
                old_val = float(coef[j])

                if denom > 1e-10:
                    rho_j = -float(grad[j]) + denom * old_val
                    new_val = _soft_threshold_scalar(rho_j, alpha_scaled) / denom
                else:
                    new_val = 0.0

                delta = new_val - old_val
                if abs(delta) > 0.0:
                    coef[j] = new_val
                    grad += XtX[:, j] * delta
                    coef_delta_l1 += abs(delta)

            # glmnet-style optimization can skip full inactive KKT scans on every pass,
            # then force a check when updates become small.
            should_kkt_scan = (
                ((iteration + 1) % check_every == 0)
                or (coef_delta_l1 < tol_f)
                or (iteration + 1 == max_iter_i)
            )
            violation = float("inf")
            n_new_violators = 0

            if should_kkt_scan:
                violation_vec = np.maximum(np.abs(grad / n_samp) - alpha, 0.0)
                newly_violating = np.flatnonzero((violation_vec > tol_f) & (~active_mask))
                n_new_violators = int(newly_violating.size)
                if n_new_violators > 0:
                    active_mask[newly_violating] = True
                violation = float(np.max(violation_vec))

            if stopping_name == "kkt":
                if should_kkt_scan and violation < tol_f:
                    n_iters[alpha_idx] = iteration + 1
                    converged = True
                    break
            else:
                if coef_delta_l1 < tol_f and n_new_violators == 0:
                    n_iters[alpha_idx] = iteration + 1
                    converged = True
                    break

        if not converged:
            n_iters[alpha_idx] = max_iter_i

        coefs_path[alpha_idx, :] = coef
        # Keep this alpha's support active for the next (smaller) alpha.
        active_mask |= np.abs(coef) > 0.0

    return coefs_path, n_iters
|
|
3551
|
+
|
|
3552
|
+
|
|
3553
|
+
def _solve_lasso_path_gpu_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve a descending-alpha Lasso path on GPU using one precomputed Gram matrix.

    Thin entry point. Every argument is forwarded unchanged to
    ``_solve_lasso_path_gpu_fista_batched_from_gram``, and its result is
    returned as-is.
    """
    solver_options = {
        "n_samples": n_samples,
        "alphas_desc": alphas_desc,
        "max_iter": max_iter,
        "tol": tol,
        "stopping": stopping,
        "lipschitz_L": lipschitz_L,
        "check_every": check_every,
    }
    return _solve_lasso_path_gpu_fista_batched_from_gram(XtX, Xty, **solver_options)
|
|
3577
|
+
|
|
3578
|
+
|
|
3579
|
+
def _batch_mse_numpy(
|
|
3580
|
+
X_val: np.ndarray,
|
|
3581
|
+
y_val: np.ndarray,
|
|
3582
|
+
coefs_path: np.ndarray,
|
|
3583
|
+
intercepts_path: np.ndarray,
|
|
3584
|
+
sample_weight_val: Optional[np.ndarray],
|
|
3585
|
+
) -> np.ndarray:
|
|
3586
|
+
preds = X_val @ coefs_path.T + intercepts_path.reshape(1, -1)
|
|
3587
|
+
sq_err = (y_val.reshape(-1, 1) - preds) ** 2
|
|
3588
|
+
|
|
3589
|
+
if sample_weight_val is None:
|
|
3590
|
+
return np.mean(sq_err, axis=0)
|
|
3591
|
+
|
|
3592
|
+
denom = float(np.sum(sample_weight_val))
|
|
3593
|
+
if denom <= 0.0:
|
|
3594
|
+
return np.mean(sq_err, axis=0)
|
|
3595
|
+
|
|
3596
|
+
return np.sum(sample_weight_val.reshape(-1, 1) * sq_err, axis=0) / denom
|
|
3597
|
+
|
|
3598
|
+
|
|
3599
|
+
def _batch_mse(
|
|
3600
|
+
X_val,
|
|
3601
|
+
y_val,
|
|
3602
|
+
coefs_path,
|
|
3603
|
+
intercepts_path,
|
|
3604
|
+
backend,
|
|
3605
|
+
sample_weight_val,
|
|
3606
|
+
) -> np.ndarray:
|
|
3607
|
+
"""
|
|
3608
|
+
Compute MSE for multiple coefficient vectors.
|
|
3609
|
+
|
|
3610
|
+
Parameters
|
|
3611
|
+
----------
|
|
3612
|
+
X_val : array-like
|
|
3613
|
+
Validation design matrix.
|
|
3614
|
+
y_val : array-like
|
|
3615
|
+
Validation response.
|
|
3616
|
+
coefs_path : array-like
|
|
3617
|
+
Coefficient matrix (n_alphas, n_features).
|
|
3618
|
+
intercepts_path : array-like
|
|
3619
|
+
Intercept vector (n_alphas,).
|
|
3620
|
+
backend : BackendBase
|
|
3621
|
+
Backend instance (CuPyBackend or TorchBackend).
|
|
3622
|
+
sample_weight_val : array-like or None
|
|
3623
|
+
Sample weights.
|
|
3624
|
+
|
|
3625
|
+
Returns
|
|
3626
|
+
-------
|
|
3627
|
+
mse : ndarray
|
|
3628
|
+
MSE for each alpha.
|
|
3629
|
+
"""
|
|
3630
|
+
preds = X_val @ coefs_path.T + intercepts_path.reshape(1, -1)
|
|
3631
|
+
sq_err = (y_val.reshape(-1, 1) - preds) ** 2
|
|
3632
|
+
|
|
3633
|
+
if sample_weight_val is None:
|
|
3634
|
+
mse = backend.mean(sq_err, axis=0)
|
|
3635
|
+
else:
|
|
3636
|
+
denom = backend.sum(sample_weight_val)
|
|
3637
|
+
if float(backend.to_numpy(denom)) <= 0.0:
|
|
3638
|
+
mse = backend.mean(sq_err, axis=0)
|
|
3639
|
+
else:
|
|
3640
|
+
mse = backend.sum(sample_weight_val.reshape(-1, 1) * sq_err, axis=0) / denom
|
|
3641
|
+
|
|
3642
|
+
return backend.to_numpy(mse)
|
|
3643
|
+
|
|
3644
|
+
|
|
3645
|
+
def _soft_threshold_torch(x, gamma):
|
|
3646
|
+
"""Soft thresholding operator for Torch tensors."""
|
|
3647
|
+
import torch
|
|
3648
|
+
return torch.sign(x) * torch.maximum(torch.abs(x) - gamma, torch.tensor(0.0, dtype=x.dtype, device=x.device))
|
|
3649
|
+
|
|
3650
|
+
|
|
3651
|
+
def _center_and_weight_lasso_data(X_arr, y_arr, sw, fit_intercept):
    """
    Prepare a design/response pair for a Gram-based (optionally weighted) Lasso solve.

    Minimizing ``sum_i w_i * (y_i - b - x_i @ beta) ** 2`` over the intercept
    ``b`` gives ``b = y_w - x_w @ beta``, where ``x_w`` and ``y_w`` are the
    *weighted* means. The data is therefore centered with weighted means
    first, and only then scaled row-wise by ``sqrt(w)``. The returned means
    are on the original (unscaled) data scale, so they can be used directly
    to recover the intercept.

    Only ``@``, ``.mean(0)``, ``.mean()``, ``.sum()``, ``** 0.5``,
    ``[:, None]`` indexing and broadcasting are used, so NumPy, CuPy and
    Torch arrays all work. ``sw`` must be 1-D (or None). For Torch it must
    also share ``X_arr``'s dtype and device.

    Returns
    -------
    X_work, y_work, X_mean, y_mean
        ``X_work`` and ``y_work`` are centered (if ``fit_intercept``) and
        sqrt-weighted (if ``sw`` is given). ``X_mean`` and ``y_mean`` are the
        centering means, or ``None`` when ``fit_intercept`` is False. If the
        weights sum to <= 0, the plain (unweighted) means are used.
    """
    X_mean = None
    y_mean = None
    X_work = X_arr
    y_work = y_arr

    if fit_intercept:
        sw_sum = None if sw is None else sw.sum()
        if sw_sum is not None and float(sw_sum) > 0.0:
            X_mean = (sw @ X_arr) / sw_sum
            y_mean = (sw @ y_arr) / sw_sum
        else:
            X_mean = X_arr.mean(0)
            y_mean = y_arr.mean()
        X_work = X_arr - X_mean
        y_work = y_arr - y_mean

    if sw is not None:
        sqrt_sw = sw ** 0.5
        X_work = X_work * sqrt_sw[:, None]
        y_work = y_work * sqrt_sw

    return X_work, y_work, X_mean, y_mean


def _fit_lasso_single_alpha_fast(
    X,
    y,
    *,
    alpha: float,
    fit_intercept: bool,
    max_iter: int,
    tol: float,
    stopping: str,
    device: str,
    cpu_solver: str,
    cd_kkt_check_every: int = 1,
    sample_weight=None,
) -> Dict[str, object]:
    """
    Fast single-alpha Lasso fit using optimized Gram-based path solvers.

    Dispatch:

    - ``device='cuda'`` with a Torch tensor ``X``: a FISTA loop run directly
      on the Torch tensors.
    - ``device='cuda'`` otherwise: CuPy arrays and
      ``_solve_lasso_path_gpu_from_gram``.
    - any other device: NumPy arrays and ``_solve_lasso_path_cpu_from_gram``.

    With ``sample_weight``, the data is centered with weighted means (when
    ``fit_intercept``) and then row-scaled by ``sqrt(w)``. The intercept is
    therefore on the original response scale.

    Returns
    -------
    dict
        Keys: ``"coef"`` (float64 ndarray), ``"intercept"`` (float),
        ``"n_iter"`` (int), ``"n_samples"`` (int) and ``"n_features"`` (int).
    """
    device_name = str(device).lower()
    alpha_vec = np.asarray([float(alpha)], dtype=np.float64)
    fit_intercept = bool(fit_intercept)

    try:
        import torch
    except Exception:  # torch is optional (missing or broken install)
        torch = None
    is_torch_gpu = (
        torch is not None
        and device_name == Device.CUDA.value
        and isinstance(X, torch.Tensor)
    )

    if device_name == Device.CUDA.value and not is_torch_gpu:
        # CuPy GPU path
        import cupy as cp

        X_arr = cp.asarray(X)
        y_arr = cp.asarray(y).reshape(-1)
        sw = None if sample_weight is None else cp.asarray(sample_weight).reshape(-1)
        X_work, y_work, X_mean, y_mean = _center_and_weight_lasso_data(
            X_arr, y_arr, sw, fit_intercept
        )

        XtX = X_work.T @ X_work
        Xty = X_work.T @ y_work

        coefs_desc, n_iters = _solve_lasso_path_gpu_from_gram(
            XtX,
            Xty,
            n_samples=int(X_arr.shape[0]),
            alphas_desc=alpha_vec,
            max_iter=int(max_iter),
            tol=float(tol),
            stopping=str(stopping),
            lipschitz_L=None,
            check_every=8,
        )

        coef_gpu = coefs_desc[0]
        intercept = float(cp.asnumpy(y_mean - X_mean @ coef_gpu)) if fit_intercept else 0.0

        return {
            "coef": np.asarray(cp.asnumpy(coef_gpu), dtype=np.float64),
            "intercept": float(intercept),
            "n_iter": int(n_iters[0]),
            "n_samples": int(X_arr.shape[0]),
            "n_features": int(X_arr.shape[1]),
        }

    if is_torch_gpu:
        # Torch GPU path: FISTA directly on GPU tensors. y and the weights are
        # cast to X's dtype/device, because torch matmul rejects mixed dtypes.
        X_arr = X
        y_arr = torch.as_tensor(y, dtype=X_arr.dtype, device=X_arr.device).reshape(-1)
        sw = None
        if sample_weight is not None:
            sw = torch.as_tensor(
                sample_weight, dtype=X_arr.dtype, device=X_arr.device
            ).reshape(-1)
        X_work, y_work, X_mean, y_mean = _center_and_weight_lasso_data(
            X_arr, y_arr, sw, fit_intercept
        )

        n_samples = int(X_arr.shape[0])
        n_features = int(X_arr.shape[1])

        # Precompute the Gram matrix and X'y for the FISTA gradient.
        XtX = X_work.T @ X_work
        Xty = X_work.T @ y_work

        # Lipschitz constant of the smooth part: largest eigenvalue of XtX / n.
        # If eigvalsh fails, fall back to the (looser) trace bound.
        try:
            L = torch.linalg.eigvalsh(XtX)[-1] / n_samples
        except Exception:
            L = torch.sum(X_work ** 2) / n_samples
        L = torch.clamp(L, min=1e-10)

        step = 1.0 / L
        alpha_f = float(alpha)
        tol_f = float(tol)
        thresh = alpha_f * step
        use_kkt = str(stopping).lower() == "kkt"

        # FISTA initialization
        coef = torch.zeros(n_features, dtype=X_arr.dtype, device=X_arr.device)
        z = coef.clone()
        t = torch.tensor(1.0, dtype=X_arr.dtype, device=X_arr.device)

        # Defined up front so that max_iter <= 0 reports 0 iterations
        # instead of reading an unbound loop variable.
        n_iter = 0
        for n_iter in range(1, int(max_iter) + 1):
            # coef is rebound below, never mutated in place, so no clone is needed.
            coef_old = coef

            # Proximal gradient step at the extrapolated point z.
            grad = (XtX @ z - Xty) / n_samples
            coef = _soft_threshold_torch(z - step * grad, thresh)

            # Nesterov momentum update.
            t_new = (1.0 + torch.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
            z = coef + ((t - 1.0) / t_new) * (coef - coef_old)
            t = t_new

            if use_kkt:
                # Full Lasso KKT residual:
                # - active coefficients need grad_j == -alpha * sign(coef_j);
                # - zero coefficients need |grad_j| <= alpha.
                grad_sse = (XtX @ coef - Xty) / n_samples
                violation = torch.where(
                    coef != 0,
                    torch.abs(grad_sse + alpha_f * torch.sign(coef)),
                    torch.clamp(torch.abs(grad_sse) - alpha_f, min=0.0),
                )
                if violation.numel() == 0 or float(torch.max(violation)) < tol_f:
                    break
            elif torch.sum(torch.abs(coef - coef_old)) < tol_f:
                break

        intercept = float((y_mean - X_mean @ coef).item()) if fit_intercept else 0.0

        return {
            "coef": np.asarray(coef.detach().cpu().numpy(), dtype=np.float64),
            "intercept": float(intercept),
            "n_iter": int(n_iter),
            "n_samples": n_samples,
            "n_features": n_features,
        }

    # CPU (NumPy) path
    X_arr = np.asarray(X)
    y_arr = np.asarray(y).reshape(-1)
    sw = None if sample_weight is None else np.asarray(sample_weight).reshape(-1)
    X_work, y_work, X_mean, y_mean = _center_and_weight_lasso_data(
        X_arr, y_arr, sw, fit_intercept
    )

    XtX = X_work.T @ X_work
    Xty = X_work.T @ y_work

    coefs_desc, n_iters = _solve_lasso_path_cpu_from_gram(
        XtX,
        Xty,
        n_samples=int(X_arr.shape[0]),
        alphas_desc=alpha_vec,
        max_iter=int(max_iter),
        tol=float(tol),
        stopping=str(stopping),
        cpu_solver=str(cpu_solver),
        lipschitz_L=None,
        cd_kkt_check_every=int(cd_kkt_check_every),
    )

    coef = np.asarray(coefs_desc[0], dtype=np.float64)
    intercept = float(y_mean - X_mean @ coef) if fit_intercept else 0.0

    return {
        "coef": coef,
        "intercept": float(intercept),
        "n_iter": int(n_iters[0]),
        "n_samples": int(X_arr.shape[0]),
        "n_features": int(X_arr.shape[1]),
    }
|
|
3871
|
+
|
|
3872
|
+
|
|
3873
|
+
def _select_lasso_alpha_cv(
    X,
    y,
    *,
    alphas=None,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
    cv_folds: int = 5,
    cv_splits=None,
    random_state: Optional[int] = None,
    sample_weight=None,
    fit_intercept: bool = False,
    device: Union[str, Device] = Device.CPU,
    max_iter: int = 3000,
    tol: float = 1e-4,
    cpu_solver: str = "coordinate_descent",
    method: str = "standard",
    cd_kkt_check_every: Optional[int] = None,
    gpu_cv_mixed_precision: bool = True,
    return_details: bool = False,
    cache_key: Optional[Tuple[Any, ...]] = None,
):
    """
    Select alpha via K-fold CV using statgpu's own Lasso implementation.

    For every fold, the whole alpha path is solved (in descending alpha order,
    for warm starts) from the training fold's Gram statistics. Each alpha is
    then scored by the (optionally weighted) MSE on the validation fold. The
    alpha with the lowest mean MSE across folds is selected. When several
    alphas tie, the first one in ``alphas`` order wins.

    Parameters
    ----------
    X, y : array-like
        Design matrix (2-D) and response. With ``device='cuda'``, CuPy arrays
        or Torch tensors are used in place. Anything else goes through NumPy.
    alphas : array-like or None
        Candidate alphas. Non-finite and non-positive values are dropped. If
        nothing remains (or ``alphas`` is None), a default grid of
        ``n_alphas`` values down to ``alpha_min_ratio * alpha_max`` is used.
    cv_folds, cv_splits, random_state
        Number of K-fold splits, explicit ``(train_idx, val_idx)`` splits
        (these take precedence), and the seed for the K-fold shuffle.
    sample_weight : array-like or None
        Per-row weights, used both for fitting and for the validation MSE.
    fit_intercept : bool
        Fit an unpenalized intercept. With weights, each training fold is
        centered with weighted means before the ``sqrt(w)`` row scaling.
    device : str or Device
        ``'cuda'`` runs the GPU path strictly: there is no CPU fallback.
    max_iter, tol, cpu_solver, method, cd_kkt_check_every
        Solver controls. ``method='glmnet'`` forces coordinate descent on
        the CPU path, and defaults the KKT check interval to 4.
    gpu_cv_mixed_precision : bool
        Run the GPU CV in float32 instead of float64.
    return_details : bool
        Return the detail dict instead of just the selected alpha.
    cache_key : tuple or None
        Explicit cache key. If None and caching is enabled, a key is derived
        from the inputs.

    Returns
    -------
    float or dict
        The selected alpha or, with ``return_details=True``, a dict with keys
        ``"alpha"``, ``"alphas"``, ``"mse_path"`` (n_alphas, n_folds) and
        ``"mean_mse"`` (n_alphas,).

    Raises
    ------
    ValueError
        If X is not 2-D, or if y / sample_weight row counts differ from X.
    RuntimeError
        If ``device='cuda'`` and the GPU path fails.

    Notes
    -----
    - Does not depend on sklearn.
    - Supports GPU path by setting ``device='cuda'``.
    """
    device_name = str(device).lower()
    use_gpu = device_name == Device.CUDA.value

    gpu_input_cupy = False
    gpu_input_torch = False
    if use_gpu:
        # Check whether the inputs already live on the GPU (CuPy or Torch).
        # X, y and (if given) sample_weight must all share the array type.
        try:
            import cupy as cp

            gpu_input_cupy = isinstance(X, cp.ndarray) and isinstance(y, cp.ndarray)
            if sample_weight is not None and not isinstance(sample_weight, cp.ndarray):
                gpu_input_cupy = False
        except Exception:
            pass

        if not gpu_input_cupy:
            try:
                import torch

                gpu_input_torch = isinstance(X, torch.Tensor) and isinstance(y, torch.Tensor)
                if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
                    gpu_input_torch = False
            except Exception:
                pass
    gpu_input = gpu_input_cupy or gpu_input_torch

    X_np = None
    y_np = None
    sample_weight_np = None
    input_backend = None

    if gpu_input:
        # Pick the backend that matches the input array type once, and reuse it.
        input_backend = get_backend(
            backend='torch' if gpu_input_torch else 'cupy', device='cuda'
        )
        if len(tuple(X.shape)) != 2:
            raise ValueError("X must be a 2D array")
        n_samples = int(X.shape[0])
        y_check = input_backend.asarray(y).reshape(-1)
        if int(y_check.shape[0]) != n_samples:
            raise ValueError("y must have the same number of rows as X")
        if sample_weight is not None:
            sw_check = input_backend.asarray(sample_weight).reshape(-1)
            if int(sw_check.shape[0]) != n_samples:
                raise ValueError("sample_weight must have the same number of rows as X")
    else:
        X_np = np.asarray(X, dtype=np.float64)
        y_np = np.asarray(y, dtype=np.float64).reshape(-1)
        if sample_weight is not None:
            sample_weight_np = np.asarray(sample_weight, dtype=np.float64).reshape(-1)
        if X_np.ndim != 2:
            raise ValueError("X must be a 2D array")
        if y_np.shape[0] != X_np.shape[0]:
            raise ValueError("y must have the same number of rows as X")
        if sample_weight_np is not None and sample_weight_np.shape[0] != X_np.shape[0]:
            raise ValueError("sample_weight must have the same number of rows as X")
        n_samples = int(X_np.shape[0])

    cv_method = _normalize_lassocv_method(method)
    requested_cd_kkt_check_every = _normalize_cd_kkt_check_every(cd_kkt_check_every)

    def _default_grid():
        """Build the default alpha grid on the inputs' native backend."""
        if gpu_input:
            return _default_lasso_alpha_grid_backend(
                X,
                y,
                input_backend,
                n_alphas=n_alphas,
                alpha_min_ratio=alpha_min_ratio,
            )
        return _default_lasso_alpha_grid(
            X_np,
            y_np,
            n_alphas=n_alphas,
            alpha_min_ratio=alpha_min_ratio,
        )

    if alphas is None:
        alpha_grid = _default_grid()
    else:
        alpha_grid = np.asarray(alphas, dtype=np.float64).reshape(-1)
        alpha_grid = alpha_grid[np.isfinite(alpha_grid)]
        alpha_grid = alpha_grid[alpha_grid > 0.0]
        if alpha_grid.size == 0:
            alpha_grid = _default_grid()

    user_folds = _normalize_cv_splits(cv_splits, n_samples=n_samples)
    effective_n_folds = int(len(user_folds)) if user_folds is not None else int(cv_folds)

    # Too little data or nothing to choose from: return the first candidate.
    if int(n_samples) < 4 or int(alpha_grid.size) == 1 or int(effective_n_folds) < 2:
        alpha0 = float(alpha_grid[0])
        if not return_details:
            return alpha0
        return {
            "alpha": alpha0,
            "alphas": alpha_grid.astype(np.float64, copy=False),
            "mse_path": np.full((int(alpha_grid.size), 1), np.nan, dtype=np.float64),
            "mean_mse": np.full(int(alpha_grid.size), np.nan, dtype=np.float64),
        }

    if user_folds is not None:
        folds = user_folds
    else:
        folds = _kfold_indices(
            n_samples=int(n_samples),
            n_splits=int(cv_folds),
            random_state=random_state,
        )

    folds_are_complements = _folds_are_complements(folds, n_samples=int(n_samples))

    alpha_grid = alpha_grid.astype(np.float64, copy=False)
    n_alpha = int(alpha_grid.size)
    n_folds = int(len(folds))

    cache_key_eff = cache_key
    if cache_key_eff is None and _LASSO_CV_ALPHA_CACHE_MAXSIZE > 0:
        cache_key_eff = _make_lasso_cv_auto_cache_key(
            X=X,
            y=y,
            sample_weight=sample_weight,
            alpha_grid=alpha_grid,
            folds=folds,
            fit_intercept=bool(fit_intercept),
            use_gpu=bool(use_gpu),
            max_iter=int(max_iter),
            tol=float(tol),
            cpu_solver=str(cpu_solver),
            cv_method=str(cv_method),
            cd_kkt_check_every=requested_cd_kkt_check_every,
            gpu_cv_mixed_precision=bool(gpu_cv_mixed_precision),
        )

    cached_details = _lasso_cv_cache_get(cache_key_eff)
    if cached_details is not None:
        if return_details:
            return cached_details
        return float(cached_details["alpha"])

    # Evaluate alpha path in descending order for warm-start efficiency.
    alpha_order_desc = np.argsort(-alpha_grid)
    alpha_desc = alpha_grid[alpha_order_desc]

    mse_path = np.full((n_alpha, n_folds), np.nan, dtype=np.float64)

    if use_gpu:
        try:
            if gpu_input:
                backend = input_backend
            else:
                backend = get_backend(backend='auto', device='cuda')
            xp = backend.xp
            use_torch_solver = hasattr(xp, '__name__') and 'torch' in xp.__name__.lower()

            cv_dtype = backend.float32 if bool(gpu_cv_mixed_precision) else backend.float64

            # Convert inputs to backend arrays in the CV dtype.
            if gpu_input:
                X_full = backend.asarray(X, dtype=cv_dtype)
                y_full = backend.asarray(y, dtype=cv_dtype).reshape(-1)
                sw_source = sample_weight
            else:
                X_full = backend.asarray(X_np, dtype=cv_dtype)
                y_full = backend.asarray(y_np, dtype=cv_dtype)
                sw_source = sample_weight_np
            sw_full = None
            if sw_source is not None:
                sw_full = backend.asarray(sw_source, dtype=cv_dtype).reshape(-1)

            XtX_folds = []
            Xty_folds = []
            n_train_folds = []
            X_mean_folds = []
            y_mean_folds = []
            fold_eval_payload = []

            # Unweighted, complementary folds: each training Gram equals the
            # full Gram minus the validation Gram, so the training rows never
            # need to be gathered.
            fast_fold_stats = (sw_full is None) and bool(folds_are_complements)
            if fast_fold_stats:
                n_total = int(X_full.shape[0])
                XtX_full = X_full.T @ X_full
                Xty_full = X_full.T @ y_full
                if bool(fit_intercept):
                    X_sum_full = backend.sum(X_full, axis=0)
                    y_sum_full = backend.sum(y_full)

            for train_idx, val_idx in folds:
                train_idx_gpu = backend.asarray(train_idx)
                val_idx_gpu = backend.asarray(val_idx)

                X_val = X_full[val_idx_gpu]
                y_val = y_full[val_idx_gpu]
                sw_val = None if sw_full is None else sw_full[val_idx_gpu]

                if fast_fold_stats:
                    n_val = int(val_idx_gpu.shape[0])
                    n_train = int(n_total - n_val)

                    XtX_raw = XtX_full - X_val.T @ X_val
                    Xty_raw = Xty_full - X_val.T @ y_val

                    if bool(fit_intercept):
                        X_sum_train = X_sum_full - backend.sum(X_val, axis=0)
                        y_sum_train = y_sum_full - backend.sum(y_val)

                        # Centered Gram from raw sums:
                        # X_c'X_c = X'X - s s' / n  and  X_c'y_c = X'y - s * y_mean.
                        inv_n = backend.asarray(1.0 / float(max(1, n_train)), dtype=X_full.dtype)
                        X_mean = X_sum_train * inv_n
                        y_mean = y_sum_train * inv_n
                        XtX = XtX_raw - backend.outer(X_sum_train, X_sum_train) * inv_n
                        Xty = Xty_raw - X_sum_train * y_mean
                    else:
                        X_mean = backend.zeros((X_full.shape[1],), dtype=X_full.dtype)
                        y_mean = backend.array(0.0, dtype=X_full.dtype)
                        XtX = XtX_raw
                        Xty = Xty_raw
                else:
                    X_train = X_full[train_idx_gpu]
                    y_train = y_full[train_idx_gpu]
                    sw_train = None if sw_full is None else sw_full[train_idx_gpu]

                    if bool(fit_intercept):
                        # The weighted-LS intercept is y_w - x_w @ beta (weighted
                        # means), so center with weighted means before scaling by
                        # sqrt(w). This keeps the intercepts on the original scale.
                        sw_sum = None if sw_train is None else backend.sum(sw_train)
                        if sw_sum is not None and float(backend.to_numpy(sw_sum)) > 0.0:
                            X_mean = (sw_train @ X_train) / sw_sum
                            y_mean = (sw_train @ y_train) / sw_sum
                        else:
                            X_mean = backend.mean(X_train, axis=0)
                            y_mean = backend.mean(y_train)
                        X_centered = X_train - X_mean
                        y_centered = y_train - y_mean
                    else:
                        X_mean = backend.zeros((X_train.shape[1],), dtype=X_train.dtype)
                        y_mean = backend.array(0.0, dtype=X_train.dtype)
                        X_centered = X_train
                        y_centered = y_train

                    if sw_train is not None:
                        sqrt_sw = backend.sqrt(sw_train)
                        X_centered = X_centered * sqrt_sw[:, None]
                        y_centered = y_centered * sqrt_sw

                    XtX = X_centered.T @ X_centered
                    Xty = X_centered.T @ y_centered
                    n_train = int(X_train.shape[0])

                XtX_folds.append(XtX)
                Xty_folds.append(Xty)
                n_train_folds.append(int(n_train))
                X_mean_folds.append(X_mean)
                y_mean_folds.append(y_mean)
                fold_eval_payload.append((X_val, y_val, sw_val))

            XtX_batch = backend.stack(XtX_folds, axis=0)
            Xty_batch = backend.stack(Xty_folds, axis=0)

            if use_torch_solver:
                # Native Torch FISTA solver for the Torch backend.
                import torch

                n_samples_vec = torch.tensor(
                    np.asarray(n_train_folds, dtype=np.int32),
                    device=XtX_batch.device,
                    dtype=XtX_batch.dtype,
                )
                coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
                    XtX_batch,
                    Xty_batch,
                    n_samples_vec=n_samples_vec,
                    alphas_desc=alpha_desc,
                    max_iter=int(max_iter),
                    tol=float(tol),
                    stopping="coef_delta",
                    lipschitz_L=None,
                    check_every=8,
                )
            else:
                # CuPy backend - use existing solver directly
                import cupy as cp

                n_samples_vec = cp.asarray(np.asarray(n_train_folds, dtype=np.int32))
                coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram(
                    XtX_batch,
                    Xty_batch,
                    n_samples_vec=n_samples_vec,
                    alphas_desc=alpha_desc,
                    max_iter=int(max_iter),
                    tol=float(tol),
                    stopping="coef_delta",
                    lipschitz_L=None,
                    check_every=8,
                )

            for fold_idx in range(n_folds):
                coefs_desc = coefs_batch_desc[fold_idx]
                if use_torch_solver:
                    # The Torch solver's per-fold coefficients may be host
                    # (NumPy) arrays and/or float64. Move them onto the backend
                    # in the fold dtype, because torch matmul accepts neither
                    # mixed devices nor mixed dtypes.
                    coefs_desc = backend.asarray(coefs_desc, dtype=X_full.dtype)

                if bool(fit_intercept):
                    intercepts_desc = y_mean_folds[fold_idx] - X_mean_folds[fold_idx] @ coefs_desc.T
                else:
                    intercepts_desc = backend.zeros((coefs_desc.shape[0],), dtype=coefs_desc.dtype)

                X_val, y_val, sw_val = fold_eval_payload[fold_idx]
                mse_desc = _batch_mse(X_val, y_val, coefs_desc, intercepts_desc, backend, sw_val)

                mse_path[alpha_order_desc, fold_idx] = mse_desc

        except Exception as exc:
            raise RuntimeError(
                "GPU path failed in _select_lasso_alpha_cv with device='cuda'; "
                "CPU fallback is disabled for strict CUDA execution."
            ) from exc
    else:
        cpu_solver_name = str(cpu_solver).lower()
        if cv_method == "glmnet":
            # glmnet-like CV profile: coordinate-descent path with periodic full KKT scans.
            cpu_solver_name = "coordinate_descent"

        if requested_cd_kkt_check_every is None:
            cd_kkt_check_every_effective = 4 if cv_method == "glmnet" else 1
        else:
            cd_kkt_check_every_effective = int(requested_cd_kkt_check_every)

        # See the GPU branch: unweighted complementary folds reuse the full Gram.
        fast_fold_stats = (sample_weight_np is None) and bool(folds_are_complements)
        if fast_fold_stats:
            n_total = int(X_np.shape[0])
            XtX_full = X_np.T @ X_np
            Xty_full = X_np.T @ y_np
            if bool(fit_intercept):
                X_sum_full = np.sum(X_np, axis=0)
                y_sum_full = float(np.sum(y_np))

        for fold_idx, (train_idx, val_idx) in enumerate(folds):
            X_val = X_np[val_idx]
            y_val = y_np[val_idx]
            sw_val = None if sample_weight_np is None else sample_weight_np[val_idx]

            if fast_fold_stats:
                n_val = int(np.asarray(val_idx, dtype=np.int64).reshape(-1).size)
                n_train = int(n_total - n_val)

                XtX_raw = XtX_full - X_val.T @ X_val
                Xty_raw = Xty_full - X_val.T @ y_val

                if bool(fit_intercept):
                    X_sum_train = X_sum_full - np.sum(X_val, axis=0)
                    y_sum_train = y_sum_full - float(np.sum(y_val))

                    inv_n = 1.0 / float(max(1, n_train))
                    X_mean = X_sum_train * inv_n
                    y_mean = y_sum_train * inv_n
                    XtX = XtX_raw - np.outer(X_sum_train, X_sum_train) * inv_n
                    Xty = Xty_raw - X_sum_train * y_mean
                else:
                    X_mean = np.zeros((X_np.shape[1],), dtype=np.float64)
                    y_mean = 0.0
                    XtX = XtX_raw
                    Xty = Xty_raw
            else:
                X_train = X_np[train_idx]
                y_train = y_np[train_idx]
                sw_train = None if sample_weight_np is None else sample_weight_np[train_idx]

                if bool(fit_intercept):
                    # Weighted means first, sqrt(w) scaling second (see GPU branch).
                    sw_sum = None if sw_train is None else float(np.sum(sw_train))
                    if sw_sum is not None and sw_sum > 0.0:
                        X_mean = (sw_train @ X_train) / sw_sum
                        y_mean = float(sw_train @ y_train) / sw_sum
                    else:
                        X_mean = np.mean(X_train, axis=0)
                        y_mean = float(np.mean(y_train))
                    X_centered = X_train - X_mean
                    y_centered = y_train - y_mean
                else:
                    X_mean = np.zeros((X_train.shape[1],), dtype=np.float64)
                    y_mean = 0.0
                    X_centered = X_train
                    y_centered = y_train

                if sw_train is not None:
                    sqrt_sw = np.sqrt(sw_train)
                    X_centered = X_centered * sqrt_sw[:, np.newaxis]
                    y_centered = y_centered * sqrt_sw

                XtX = X_centered.T @ X_centered
                Xty = X_centered.T @ y_centered
                n_train = int(X_train.shape[0])

            coefs_desc, _ = _solve_lasso_path_cpu_from_gram(
                XtX,
                Xty,
                n_samples=int(n_train),
                alphas_desc=alpha_desc,
                max_iter=int(max_iter),
                tol=float(tol),
                stopping="coef_delta",
                cpu_solver=cpu_solver_name,
                lipschitz_L=None,
                cd_kkt_check_every=cd_kkt_check_every_effective,
            )

            if bool(fit_intercept):
                intercepts_desc = y_mean - X_mean @ coefs_desc.T
            else:
                intercepts_desc = np.zeros((coefs_desc.shape[0],), dtype=np.float64)

            mse_desc = _batch_mse_numpy(
                X_val,
                y_val,
                coefs_desc,
                intercepts_desc,
                sw_val,
            )

            mse_path[alpha_order_desc, fold_idx] = np.asarray(mse_desc, dtype=np.float64)

    # Mean MSE per alpha over the folds with a finite score (NaN if none).
    mean_mse_vec = np.full(n_alpha, np.nan, dtype=np.float64)
    for alpha_idx in range(n_alpha):
        valid = np.isfinite(mse_path[alpha_idx])
        if bool(np.any(valid)):
            mean_mse_vec[alpha_idx] = float(np.mean(mse_path[alpha_idx, valid]))

    # First alpha (in grid order) with the strictly lowest mean MSE wins.
    # NaN entries never compare as smaller.
    best_alpha = float(alpha_grid[0])
    best_mse = float("inf")
    for alpha_idx in range(n_alpha):
        if mean_mse_vec[alpha_idx] < best_mse:
            best_mse = float(mean_mse_vec[alpha_idx])
            best_alpha = float(alpha_grid[alpha_idx])

    details = {
        "alpha": float(best_alpha),
        "alphas": alpha_grid.astype(np.float64, copy=False),
        "mse_path": mse_path,
        "mean_mse": mean_mse_vec,
    }

    _lasso_cv_cache_put(cache_key_eff, details)

    if return_details:
        return details

    return float(details["alpha"])
|
|
4402
|
+
|
|
4403
|
+
|
|
4404
|
+
class LassoCV(CVEstimatorBase):
|
|
4405
|
+
"""
|
|
4406
|
+
Cross-validated Lasso built on top of statgpu's own ``Lasso`` implementation.
|
|
4407
|
+
|
|
4408
|
+
This class mirrors the common sklearn-style usage pattern while keeping
|
|
4409
|
+
backend/device behavior consistent with statgpu models.
|
|
4410
|
+
"""
|
|
4411
|
+
|
|
4412
|
+
def __init__(
|
|
4413
|
+
self,
|
|
4414
|
+
alphas=None,
|
|
4415
|
+
n_alphas: int = 12,
|
|
4416
|
+
alpha_min_ratio: float = 1e-3,
|
|
4417
|
+
cv: int = 5,
|
|
4418
|
+
cv_splits=None,
|
|
4419
|
+
fit_intercept: bool = True,
|
|
4420
|
+
max_iter: int = 3000,
|
|
4421
|
+
tol: float = 1e-4,
|
|
4422
|
+
stopping: str = "coef_delta",
|
|
4423
|
+
inference_method: str = "cpu_ols_inference",
|
|
4424
|
+
n_bootstrap: int = 200,
|
|
4425
|
+
bootstrap_random_state: Optional[int] = None,
|
|
4426
|
+
device: Union[str, Device] = Device.AUTO,
|
|
4427
|
+
n_jobs: Optional[int] = None,
|
|
4428
|
+
compute_inference: bool = True,
|
|
4429
|
+
solver: str = "fista",
|
|
4430
|
+
cpu_solver: str = "coordinate_descent",
|
|
4431
|
+
method: str = "standard",
|
|
4432
|
+
cd_kkt_check_every: Optional[int] = None,
|
|
4433
|
+
lipschitz_L: Optional[float] = None,
|
|
4434
|
+
admm_rho: float = 1.0,
|
|
4435
|
+
gpu_memory_cleanup: bool = False,
|
|
4436
|
+
gpu_cv_mixed_precision: bool = True,
|
|
4437
|
+
random_state: Optional[int] = None,
|
|
4438
|
+
):
|
|
4439
|
+
super().__init__(
|
|
4440
|
+
cv=cv,
|
|
4441
|
+
random_state=random_state,
|
|
4442
|
+
device=device,
|
|
4443
|
+
n_jobs=n_jobs,
|
|
4444
|
+
)
|
|
4445
|
+
self.alphas = alphas
|
|
4446
|
+
self.n_alphas = int(n_alphas)
|
|
4447
|
+
self.alpha_min_ratio = float(alpha_min_ratio)
|
|
4448
|
+
self.cv = int(cv)
|
|
4449
|
+
self.cv_splits = cv_splits
|
|
4450
|
+
self.fit_intercept = bool(fit_intercept)
|
|
4451
|
+
self.max_iter = int(max_iter)
|
|
4452
|
+
self.tol = float(tol)
|
|
4453
|
+
self.stopping = str(stopping)
|
|
4454
|
+
self.inference_method = str(inference_method)
|
|
4455
|
+
self.n_bootstrap = int(n_bootstrap)
|
|
4456
|
+
self.bootstrap_random_state = bootstrap_random_state
|
|
4457
|
+
self.compute_inference = bool(compute_inference)
|
|
4458
|
+
self.solver = str(solver)
|
|
4459
|
+
self.cpu_solver = str(cpu_solver)
|
|
4460
|
+
self.method = _normalize_lassocv_method(method)
|
|
4461
|
+
self.cd_kkt_check_every = _normalize_cd_kkt_check_every(cd_kkt_check_every)
|
|
4462
|
+
self.lipschitz_L = lipschitz_L
|
|
4463
|
+
self.admm_rho = float(admm_rho)
|
|
4464
|
+
self.gpu_memory_cleanup = bool(gpu_memory_cleanup)
|
|
4465
|
+
self.gpu_cv_mixed_precision = bool(gpu_cv_mixed_precision)
|
|
4466
|
+
self.random_state = random_state
|
|
4467
|
+
|
|
4468
|
+
self.alpha_ = None
|
|
4469
|
+
self.alphas_ = None
|
|
4470
|
+
self.mse_path_ = None
|
|
4471
|
+
self.mean_mse_ = None
|
|
4472
|
+
self.best_score_ = None
|
|
4473
|
+
self.coef_ = None
|
|
4474
|
+
self.intercept_ = None
|
|
4475
|
+
self.n_iter_ = None
|
|
4476
|
+
self.estimator_ = None
|
|
4477
|
+
|
|
4478
|
+
def fit(self, X, y, sample_weight=None):
    """Select ``alpha_`` by cross-validation, then refit a ``Lasso`` at that alpha.

    Parameters
    ----------
    X, y
        Training data. Forwarded unchanged to ``_select_lasso_alpha_cv`` and
        to the refit (``_fit_lasso_single_alpha_fast`` or ``Lasso.fit``).
    sample_weight : optional
        Forwarded to both the CV search and the refit.

    Returns
    -------
    self
        With ``alpha_``, ``alphas_``, ``mse_path_``, ``mean_mse_``,
        ``best_score_``, ``estimator_``, ``coef_``, ``intercept_`` and
        ``n_iter_`` populated.
    """
    device_name = self._get_compute_device().value
    # The "glmnet" method forces coordinate descent for the CPU solver,
    # overriding whatever ``cpu_solver`` was configured.
    effective_cpu_solver = (
        "coordinate_descent" if str(self.method).lower() == "glmnet" else str(self.cpu_solver)
    )

    # CV search over the alpha grid; ``return_details=True`` makes the helper
    # return a mapping with (at least) "alpha", "alphas", "mse_path" and
    # "mean_mse", which are read below.
    details = _select_lasso_alpha_cv(
        X,
        y,
        alphas=self.alphas,
        n_alphas=self.n_alphas,
        alpha_min_ratio=self.alpha_min_ratio,
        cv_folds=self.cv,
        cv_splits=self.cv_splits,
        random_state=self.random_state,
        sample_weight=sample_weight,
        fit_intercept=self.fit_intercept,
        device=device_name,
        max_iter=self.max_iter,
        tol=self.tol,
        cpu_solver=effective_cpu_solver,
        method=self.method,
        cd_kkt_check_every=self.cd_kkt_check_every,
        gpu_cv_mixed_precision=self.gpu_cv_mixed_precision,
        return_details=True,
    )

    # Resolve the default KKT check interval (4 for glmnet, 1 otherwise).
    # Only the fast refit path below consumes the resolved value; the CV
    # helper above receives the raw (possibly None) setting.
    effective_cd_kkt_check_every = self.cd_kkt_check_every
    if effective_cd_kkt_check_every is None:
        effective_cd_kkt_check_every = 4 if str(self.method).lower() == "glmnet" else 1

    self.alpha_ = float(details["alpha"])
    self.alphas_ = np.asarray(details["alphas"], dtype=np.float64)
    self.mse_path_ = np.asarray(details["mse_path"], dtype=np.float64)
    self.mean_mse_ = np.asarray(details["mean_mse"], dtype=np.float64)

    # Best (lowest) mean CV MSE, ignoring NaNs; NaN when no entry is finite.
    if np.any(np.isfinite(self.mean_mse_)):
        self.best_score_ = float(np.nanmin(self.mean_mse_))
    else:
        self.best_score_ = np.nan

    # Refit estimator at the selected alpha. ``Lasso`` is resolved from module
    # globals at call time.
    estimator = Lasso(
        alpha=self.alpha_,
        fit_intercept=self.fit_intercept,
        max_iter=self.max_iter,
        tol=self.tol,
        stopping=self.stopping,
        inference_method=self.inference_method,
        n_bootstrap=self.n_bootstrap,
        bootstrap_random_state=self.bootstrap_random_state,
        device=self.device,
        n_jobs=self.n_jobs,
        compute_inference=self.compute_inference,
        solver=self.solver,
        cpu_solver=effective_cpu_solver,
        lipschitz_L=self.lipschitz_L,
        admm_rho=self.admm_rho,
        gpu_memory_cleanup=self.gpu_memory_cleanup,
    )

    # The fast path skips ``Lasso.fit`` entirely. It is only taken when no
    # inference is requested and the FISTA solver uses one of the two
    # supported stopping rules.
    fast_refit_enabled = (
        (not bool(self.compute_inference))
        and str(self.solver).lower() == "fista"
        and str(self.stopping).lower() in ("coef_delta", "kkt")
    )

    if fast_refit_enabled:
        fast = _fit_lasso_single_alpha_fast(
            X,
            y,
            alpha=float(self.alpha_),
            fit_intercept=bool(self.fit_intercept),
            max_iter=int(self.max_iter),
            tol=float(self.tol),
            stopping=str(self.stopping),
            device=str(device_name),
            cpu_solver=str(effective_cpu_solver),
            cd_kkt_check_every=int(effective_cd_kkt_check_every),
            sample_weight=sample_weight,
        )

        # Populate the estimator's fitted state by hand from the fast result.
        # NOTE(review): these private attributes must stay in sync with what
        # ``Lasso.fit`` stores; confirm whenever that class changes.
        estimator.coef_ = np.asarray(fast["coef"], dtype=np.float64)
        estimator.intercept_ = float(fast["intercept"])
        estimator.n_iter_ = int(fast["n_iter"])
        estimator._nobs = int(fast["n_samples"])
        # NOTE(review): counts every feature (not the Lasso active set) plus
        # the intercept as consumed degrees of freedom.
        estimator._df_resid = int(fast["n_samples"]) - (
            int(fast["n_features"]) + (1 if bool(self.fit_intercept) else 0)
        )

        # Parameter vector is [intercept, coef...] when an intercept is fitted.
        if bool(self.fit_intercept):
            estimator._params = np.concatenate(
                [[estimator.intercept_], estimator.coef_]
            )
        else:
            estimator._params = estimator.coef_.copy()

        # No residuals, design matrix or scale estimate are kept on this path.
        estimator._scale = np.nan
        estimator._resid = None
        estimator._X_design = None
        estimator._fitted = True
    else:
        estimator.fit(X, y, sample_weight=sample_weight)

    # Mirror the refit estimator's results on the CV object.
    self.estimator_ = estimator
    self.coef_ = np.asarray(estimator.coef_)
    self.intercept_ = estimator.intercept_
    self.n_iter_ = int(estimator.n_iter_)

    self._fitted = True
    return self
|
|
4588
|
+
|
|
4589
|
+
def predict(self, X):
    """Return predictions for ``X`` from the estimator refit at ``alpha_``.

    The fitted-state check runs first; whatever it raises propagates before
    ``estimator_`` is touched.
    """
    self._check_is_fitted()
    refit_model = self.estimator_
    predictions = refit_model.predict(X)
    return predictions
|
|
4592
|
+
|
|
4593
|
+
def score(self, X, y):
    """Delegate scoring of ``(X, y)`` to the estimator refit at ``alpha_``.

    The fitted-state check runs before the refit estimator is accessed.
    """
    self._check_is_fitted()
    refit_model = self.estimator_
    result = refit_model.score(X, y)
    return result
|
|
4596
|
+
|
|
4597
|
+
|
|
4598
|
+
# =============================================================================
|
|
4599
|
+
# Torch FISTA Solvers
|
|
4600
|
+
# =============================================================================
|
|
4601
|
+
|
|
4602
|
+
def _solve_lasso_path_gpu_fista_batched_from_gram_torch(
|
|
4603
|
+
XtX,
|
|
4604
|
+
Xty,
|
|
4605
|
+
*,
|
|
4606
|
+
n_samples: int,
|
|
4607
|
+
alphas_desc: np.ndarray,
|
|
4608
|
+
max_iter: int,
|
|
4609
|
+
tol: float,
|
|
4610
|
+
stopping: str,
|
|
4611
|
+
lipschitz_L: Optional[float] = None,
|
|
4612
|
+
check_every: int = 8,
|
|
4613
|
+
):
|
|
4614
|
+
"""Solve descending-alpha Lasso path with a batched Torch FISTA update."""
|
|
4615
|
+
import torch
|
|
4616
|
+
|
|
4617
|
+
n_features = int(XtX.shape[0])
|
|
4618
|
+
n_alphas = int(alphas_desc.shape[0])
|
|
4619
|
+
|
|
4620
|
+
coefs = torch.zeros((n_features, n_alphas), dtype=XtX.dtype, device=XtX.device)
|
|
4621
|
+
yk = coefs.clone()
|
|
4622
|
+
tk = torch.ones((n_alphas,), dtype=XtX.dtype, device=XtX.device)
|
|
4623
|
+
n_iters_gpu = torch.zeros((n_alphas,), dtype=torch.int32, device=XtX.device)
|
|
4624
|
+
|
|
4625
|
+
if lipschitz_L is not None:
|
|
4626
|
+
L = torch.tensor(float(lipschitz_L), dtype=XtX.dtype, device=XtX.device)
|
|
4627
|
+
else:
|
|
4628
|
+
try:
|
|
4629
|
+
eigvals = torch.linalg.eigvalsh(XtX)
|
|
4630
|
+
L = eigvals[-1] / float(max(1, n_samples))
|
|
4631
|
+
except Exception:
|
|
4632
|
+
row_sum_bound = torch.max(torch.sum(torch.abs(XtX), dim=1)) / float(max(1, n_samples))
|
|
4633
|
+
L = torch.maximum(row_sum_bound, torch.tensor(1e-12, dtype=XtX.dtype, device=XtX.device))
|
|
4634
|
+
|
|
4635
|
+
L_scalar = float(L.item())
|
|
4636
|
+
if L_scalar <= 0.0:
|
|
4637
|
+
return coefs.T, torch.zeros((n_alphas,), dtype=torch.int32, device=XtX.device).cpu().numpy()
|
|
4638
|
+
|
|
4639
|
+
n_samp = float(max(1, n_samples))
|
|
4640
|
+
step = 1.0 / L
|
|
4641
|
+
alphas_desc = np.asarray(alphas_desc, dtype=np.float64)
|
|
4642
|
+
alpha_gpu = torch.from_numpy(alphas_desc).to(XtX.device).to(XtX.dtype)
|
|
4643
|
+
thresholds = alpha_gpu * step
|
|
4644
|
+
stopping_name = str(stopping).lower()
|
|
4645
|
+
check_every = max(1, int(check_every))
|
|
4646
|
+
|
|
4647
|
+
active_gpu = torch.arange(n_alphas, dtype=torch.int64, device=XtX.device)
|
|
4648
|
+
|
|
4649
|
+
for iteration in range(int(max_iter)):
|
|
4650
|
+
if int(active_gpu.numel()) == 0:
|
|
4651
|
+
break
|
|
4652
|
+
|
|
4653
|
+
y_active = yk[:, active_gpu]
|
|
4654
|
+
coef_old = coefs[:, active_gpu]
|
|
4655
|
+
|
|
4656
|
+
grad = (XtX @ y_active - Xty.reshape(-1, 1)) / n_samp
|
|
4657
|
+
thresh = thresholds[active_gpu].reshape(1, -1)
|
|
4658
|
+
coef_new = torch.sign(y_active - step * grad) * torch.maximum(torch.abs(y_active - step * grad) - thresh, torch.tensor(0.0, dtype=XtX.dtype, device=XtX.device))
|
|
4659
|
+
|
|
4660
|
+
t_old = tk[active_gpu]
|
|
4661
|
+
t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
|
|
4662
|
+
beta = (t_old - 1.0) / t_new
|
|
4663
|
+
y_new = coef_new + beta.reshape(1, -1) * (coef_new - coef_old)
|
|
4664
|
+
|
|
4665
|
+
coefs[:, active_gpu] = coef_new
|
|
4666
|
+
yk[:, active_gpu] = y_new
|
|
4667
|
+
tk[active_gpu] = t_new
|
|
4668
|
+
|
|
4669
|
+
active_ratio = float(int(active_gpu.numel())) / float(max(1, n_alphas))
|
|
4670
|
+
check_every_eff = _adaptive_gpu_check_every(
|
|
4671
|
+
base_check_every=check_every,
|
|
4672
|
+
iteration=iteration,
|
|
4673
|
+
max_iter=int(max_iter),
|
|
4674
|
+
active_ratio=active_ratio,
|
|
4675
|
+
)
|
|
4676
|
+
should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == int(max_iter))
|
|
4677
|
+
if not should_check:
|
|
4678
|
+
continue
|
|
4679
|
+
|
|
4680
|
+
if stopping_name == "kkt":
|
|
4681
|
+
grad_sse = (XtX @ coef_new - Xty.reshape(-1, 1)) / n_samp
|
|
4682
|
+
viol = torch.max(
|
|
4683
|
+
torch.maximum(
|
|
4684
|
+
torch.abs(grad_sse) - alpha_gpu[active_gpu].reshape(1, -1),
|
|
4685
|
+
torch.tensor(0.0, dtype=XtX.dtype, device=XtX.device),
|
|
4686
|
+
),
|
|
4687
|
+
dim=0,
|
|
4688
|
+
).values
|
|
4689
|
+
converged_local_gpu = viol < float(tol)
|
|
4690
|
+
else:
|
|
4691
|
+
delta = torch.sum(torch.abs(coef_new - coef_old), dim=0)
|
|
4692
|
+
converged_local_gpu = delta < float(tol)
|
|
4693
|
+
|
|
4694
|
+
done_gpu = active_gpu[converged_local_gpu]
|
|
4695
|
+
if int(done_gpu.numel()) == 0:
|
|
4696
|
+
continue
|
|
4697
|
+
|
|
4698
|
+
n_iters_gpu[done_gpu] = int(iteration) + 1
|
|
4699
|
+
yk[:, done_gpu] = coefs[:, done_gpu]
|
|
4700
|
+
active_gpu = active_gpu[~converged_local_gpu]
|
|
4701
|
+
|
|
4702
|
+
if int(active_gpu.numel()) > 0:
|
|
4703
|
+
n_iters_gpu[active_gpu] = int(max_iter)
|
|
4704
|
+
|
|
4705
|
+
return coefs.T, n_iters_gpu.cpu().numpy()
|
|
4706
|
+
|
|
4707
|
+
|
|
4708
|
+
def _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
|
|
4709
|
+
XtX_batch,
|
|
4710
|
+
Xty_batch,
|
|
4711
|
+
*,
|
|
4712
|
+
n_samples_vec: np.ndarray,
|
|
4713
|
+
alphas_desc: np.ndarray,
|
|
4714
|
+
max_iter: int,
|
|
4715
|
+
tol: float,
|
|
4716
|
+
stopping: str,
|
|
4717
|
+
lipschitz_L: Optional[float] = None,
|
|
4718
|
+
check_every: int = 8,
|
|
4719
|
+
):
|
|
4720
|
+
"""Solve descending-alpha Lasso paths for all folds together on Torch GPU."""
|
|
4721
|
+
import torch
|
|
4722
|
+
|
|
4723
|
+
n_folds = int(XtX_batch.shape[0])
|
|
4724
|
+
n_features = int(XtX_batch.shape[1])
|
|
4725
|
+
n_alphas = int(alphas_desc.shape[0])
|
|
4726
|
+
|
|
4727
|
+
coefs = torch.zeros((n_folds, n_features, n_alphas), dtype=XtX_batch.dtype, device=XtX_batch.device)
|
|
4728
|
+
yk = coefs.clone()
|
|
4729
|
+
tk = torch.ones((n_folds, n_alphas), dtype=XtX_batch.dtype, device=XtX_batch.device)
|
|
4730
|
+
n_iters_gpu = torch.zeros((n_folds, n_alphas), dtype=torch.int32, device=XtX_batch.device)
|
|
4731
|
+
|
|
4732
|
+
n_vec_cpu = n_samples_vec.cpu().numpy().astype(np.float64).reshape(-1)
|
|
4733
|
+
if n_vec_cpu.size != n_folds:
|
|
4734
|
+
raise ValueError("n_samples_vec must have one entry per fold")
|
|
4735
|
+
n_vec = torch.from_numpy(n_vec_cpu).to(XtX_batch.device).to(XtX_batch.dtype)
|
|
4736
|
+
|
|
4737
|
+
if lipschitz_L is not None:
|
|
4738
|
+
L = torch.full((n_folds,), float(lipschitz_L), dtype=XtX_batch.dtype, device=XtX_batch.device)
|
|
4739
|
+
else:
|
|
4740
|
+
try:
|
|
4741
|
+
eigvals = torch.linalg.eigvalsh(XtX_batch)
|
|
4742
|
+
L = eigvals[:, -1] / n_vec
|
|
4743
|
+
except Exception:
|
|
4744
|
+
row_sum_bound = torch.max(torch.sum(torch.abs(XtX_batch), dim=2), dim=1).values / n_vec
|
|
4745
|
+
L = torch.maximum(row_sum_bound, torch.tensor(1e-12, dtype=XtX_batch.dtype, device=XtX_batch.device))
|
|
4746
|
+
|
|
4747
|
+
step = 1.0 / L.reshape(n_folds, 1, 1)
|
|
4748
|
+
alpha_gpu = torch.from_numpy(np.asarray(alphas_desc, dtype=np.float64)).to(XtX_batch.device).to(XtX_batch.dtype).reshape(1, 1, n_alphas)
|
|
4749
|
+
thresholds = alpha_gpu * step
|
|
4750
|
+
|
|
4751
|
+
Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
|
|
4752
|
+
n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
|
|
4753
|
+
stopping_name = str(stopping).lower()
|
|
4754
|
+
check_every = max(1, int(check_every))
|
|
4755
|
+
|
|
4756
|
+
active_gpu = torch.ones((n_folds, n_alphas), dtype=torch.bool, device=XtX_batch.device)
|
|
4757
|
+
active_count = int(n_folds * n_alphas)
|
|
4758
|
+
|
|
4759
|
+
for iteration in range(int(max_iter)):
|
|
4760
|
+
if active_count == 0:
|
|
4761
|
+
break
|
|
4762
|
+
|
|
4763
|
+
active_expanded = active_gpu.unsqueeze(1)
|
|
4764
|
+
|
|
4765
|
+
coef_old = coefs.clone()
|
|
4766
|
+
grad = (torch.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded
|
|
4767
|
+
coef_candidate = torch.sign(yk - step * grad) * torch.maximum(torch.abs(yk - step * grad) - thresholds, torch.tensor(0.0, dtype=XtX_batch.dtype, device=XtX_batch.device))
|
|
4768
|
+
coefs = torch.where(active_expanded, coef_candidate, coefs)
|
|
4769
|
+
|
|
4770
|
+
t_old = tk
|
|
4771
|
+
t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
|
|
4772
|
+
beta = (t_old - 1.0) / t_new
|
|
4773
|
+
y_candidate = coefs + beta.unsqueeze(1) * (coefs - coef_old)
|
|
4774
|
+
yk = torch.where(active_expanded, y_candidate, yk)
|
|
4775
|
+
tk = torch.where(active_gpu, t_new, tk)
|
|
4776
|
+
|
|
4777
|
+
active_ratio = float(active_count) / float(max(1, n_folds * n_alphas))
|
|
4778
|
+
check_every_eff = _adaptive_gpu_check_every(
|
|
4779
|
+
base_check_every=check_every,
|
|
4780
|
+
iteration=iteration,
|
|
4781
|
+
max_iter=int(max_iter),
|
|
4782
|
+
active_ratio=active_ratio,
|
|
4783
|
+
)
|
|
4784
|
+
should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == int(max_iter))
|
|
4785
|
+
if not should_check:
|
|
4786
|
+
continue
|
|
4787
|
+
|
|
4788
|
+
if stopping_name == "kkt":
|
|
4789
|
+
grad_sse = (torch.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
|
|
4790
|
+
violation = torch.max(torch.maximum(torch.abs(grad_sse) - alpha_gpu, torch.tensor(0.0, dtype=XtX_batch.dtype, device=XtX_batch.device)), dim=1).values
|
|
4791
|
+
converged_local_gpu = violation < float(tol)
|
|
4792
|
+
else:
|
|
4793
|
+
delta = torch.sum(torch.abs(coefs - coef_old), dim=1)
|
|
4794
|
+
converged_local_gpu = delta < float(tol)
|
|
4795
|
+
|
|
4796
|
+
newly_done_gpu = active_gpu & converged_local_gpu
|
|
4797
|
+
done_count = int(torch.count_nonzero(newly_done_gpu).item())
|
|
4798
|
+
if done_count == 0:
|
|
4799
|
+
continue
|
|
4800
|
+
|
|
4801
|
+
n_iters_gpu[newly_done_gpu] = int(iteration) + 1
|
|
4802
|
+
yk = torch.where(newly_done_gpu.unsqueeze(1), coefs, yk)
|
|
4803
|
+
active_gpu = active_gpu & (~converged_local_gpu)
|
|
4804
|
+
active_count -= done_count
|
|
4805
|
+
|
|
4806
|
+
n_iters_gpu[active_gpu] = int(max_iter)
|
|
4807
|
+
|
|
4808
|
+
return coefs.permute(0, 2, 1), n_iters_gpu.cpu().numpy()
|
|
4809
|
+
|
|
4810
|
+
def summary(self):
    """Return the ``summary()`` of the estimator refit at the selected alpha.

    The fitted-state check runs before ``estimator_`` is accessed.

    NOTE(review): this definition appears after the module-level Torch FISTA
    solvers. Confirm it is still attached to the intended CV estimator class.
    """
    self._check_is_fitted()
    refit_model = self.estimator_
    report = refit_model.summary()
    return report
|
|
4813
|
+
|
|
4814
|
+
|
|
4815
|
+
# =============================================================================
|
|
4816
|
+
# V9 thin wrapper
|
|
4817
|
+
# =============================================================================
|
|
4818
|
+
|
|
4819
|
+
from ._penalized import PenalizedLinearRegression as _PenalizedLinearRegression
|
|
4820
|
+
|
|
4821
|
+
|
|
4822
|
+
class Lasso(_PenalizedLinearRegression):
|
|
4823
|
+
"""Thin sklearn-style wrapper over ``PenalizedLinearRegression`` with L1 penalty."""
|
|
4824
|
+
|
|
4825
|
+
def __init__(
|
|
4826
|
+
self,
|
|
4827
|
+
alpha: float = 1.0,
|
|
4828
|
+
fit_intercept: bool = True,
|
|
4829
|
+
max_iter: int = 1000,
|
|
4830
|
+
tol: float = 1e-4,
|
|
4831
|
+
stopping: str = "coef_delta",
|
|
4832
|
+
inference_method: str = "cpu_ols_inference",
|
|
4833
|
+
n_bootstrap: int = 200,
|
|
4834
|
+
bootstrap_random_state: Optional[int] = None,
|
|
4835
|
+
enable_simultaneous_inference: bool = False,
|
|
4836
|
+
simultaneous_method: str = "maxz_bootstrap",
|
|
4837
|
+
simultaneous_alpha: float = 0.05,
|
|
4838
|
+
simultaneous_n_bootstrap: int = 1000,
|
|
4839
|
+
simultaneous_random_state: Optional[int] = None,
|
|
4840
|
+
simultaneous_include_intercept: bool = False,
|
|
4841
|
+
device: Union[str, Device] = Device.AUTO,
|
|
4842
|
+
n_jobs: Optional[int] = None,
|
|
4843
|
+
compute_inference: bool = True,
|
|
4844
|
+
solver: str = "fista",
|
|
4845
|
+
cpu_solver: str = "coordinate_descent",
|
|
4846
|
+
lipschitz_L: Optional[float] = None,
|
|
4847
|
+
admm_rho: float = 1.0,
|
|
4848
|
+
gpu_memory_cleanup: bool = False,
|
|
4849
|
+
**kwargs,
|
|
4850
|
+
):
|
|
4851
|
+
self.stopping = str(stopping).lower()
|
|
4852
|
+
self.inference_method = str(inference_method).lower()
|
|
4853
|
+
self.n_bootstrap = int(n_bootstrap)
|
|
4854
|
+
self.bootstrap_random_state = bootstrap_random_state
|
|
4855
|
+
self.enable_simultaneous_inference = bool(enable_simultaneous_inference)
|
|
4856
|
+
self.simultaneous_method = str(simultaneous_method).lower()
|
|
4857
|
+
self.simultaneous_alpha = float(simultaneous_alpha)
|
|
4858
|
+
self.simultaneous_n_bootstrap = int(simultaneous_n_bootstrap)
|
|
4859
|
+
self.simultaneous_random_state = simultaneous_random_state
|
|
4860
|
+
self.simultaneous_include_intercept = bool(simultaneous_include_intercept)
|
|
4861
|
+
self.compute_inference = bool(compute_inference)
|
|
4862
|
+
self.admm_rho = float(admm_rho)
|
|
4863
|
+
self._ignored_kwargs = dict(kwargs)
|
|
4864
|
+
super().__init__(
|
|
4865
|
+
penalty="l1",
|
|
4866
|
+
alpha=alpha,
|
|
4867
|
+
fit_intercept=fit_intercept,
|
|
4868
|
+
max_iter=max_iter,
|
|
4869
|
+
tol=tol,
|
|
4870
|
+
device=device,
|
|
4871
|
+
n_jobs=n_jobs,
|
|
4872
|
+
cpu_solver=cpu_solver,
|
|
4873
|
+
solver=solver,
|
|
4874
|
+
lipschitz_L=lipschitz_L,
|
|
4875
|
+
gpu_memory_cleanup=gpu_memory_cleanup,
|
|
4876
|
+
)
|