statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2124 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lasso regression with full statistical inference and GPU support.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
__all__ = ["Lasso"]
|
|
6
|
+
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
import hashlib
|
|
9
|
+
import threading
|
|
10
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
11
|
+
import os
|
|
12
|
+
import warnings
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from numba import njit
|
|
17
|
+
|
|
18
|
+
_NUMBA_AVAILABLE = True
|
|
19
|
+
except Exception:
|
|
20
|
+
njit = None
|
|
21
|
+
_NUMBA_AVAILABLE = False
|
|
22
|
+
|
|
23
|
+
from statgpu._base import BaseEstimator
|
|
24
|
+
from statgpu.backends import _to_numpy
|
|
25
|
+
from statgpu._config import Device
|
|
26
|
+
from statgpu.cross_validation._base import CVEstimatorBase, kfold_indices as _kfold_indices, batch_mse as _batch_mse_cv
|
|
27
|
+
from statgpu.backends import get_backend
|
|
28
|
+
from statgpu.inference._distributions_backend import (
|
|
29
|
+
norm,
|
|
30
|
+
t,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Opt-out switch for the numba coordinate-descent solver: the env var is
# trimmed, lower-cased and compared against a small set of truthy spellings.
_NUMBA_CD_DISABLED = os.getenv("STATGPU_DISABLE_NUMBA_CD", "0").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}

# Bounded LRU cache of LassoCV alpha-selection results (size from env).
_LASSO_CV_ALPHA_CACHE_MAXSIZE = int(os.getenv("STATGPU_LASSO_CV_CACHE_SIZE", "64"))
_LASSO_CV_ALPHA_CACHE: "OrderedDict[Tuple[Any, ...], Dict[str, Any]]" = OrderedDict()

# Bounded LRU cache of debiased-Lasso M matrices (size from env).
_LASSO_DEBIASED_M_CACHE_MAXSIZE = int(os.getenv("STATGPU_LASSO_DEBIASED_M_CACHE_SIZE", "16"))
_LASSO_DEBIASED_M_CACHE: "OrderedDict[Tuple[Any, ...], np.ndarray]" = OrderedDict()
_LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK = 1024

# One lock shared by the cache helpers for both caches above.
_cache_lock = threading.Lock()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ============================================================================
|
|
50
|
+
# CuPy Fused Kernels for Lasso - Now implemented as Lasso class methods
|
|
51
|
+
# See Lasso._get_cupy_fused_kernels() for details.
|
|
52
|
+
# ============================================================================
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _debiased_m_cache_get(key):
    """Return the cached debiased-Lasso M entry for ``key``, or None on a miss.

    A hit is moved to the most-recently-used end of the LRU ordering.
    """
    with _cache_lock:
        hit = _LASSO_DEBIASED_M_CACHE.get(key)
        if hit is None:
            return None
        _LASSO_DEBIASED_M_CACHE.move_to_end(key)
        return hit
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _debiased_m_cache_put(key, value):
    """Store ``value`` under ``key`` in the debiased-M LRU cache.

    The entry becomes most-recently-used. Least-recently-used entries are
    evicted until at most ``_LASSO_DEBIASED_M_CACHE_MAXSIZE`` remain.

    A non-positive max size disables caching, mirroring ``_lasso_cv_cache_put``.
    Without this guard, a negative size would keep the eviction loop running
    after the cache is empty, and ``popitem`` would raise ``KeyError``.
    """
    if _LASSO_DEBIASED_M_CACHE_MAXSIZE <= 0:
        return
    with _cache_lock:
        _LASSO_DEBIASED_M_CACHE[key] = value
        _LASSO_DEBIASED_M_CACHE.move_to_end(key)
        while len(_LASSO_DEBIASED_M_CACHE) > _LASSO_DEBIASED_M_CACHE_MAXSIZE:
            _LASSO_DEBIASED_M_CACHE.popitem(last=False)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _debiased_m_key_from_numpy_design(
    X: np.ndarray,
    *,
    n: int,
    p: int,
    lam_nw: float,
    tol: float,
) -> str:
    """Build a content-hash cache key for the debiased-Lasso M matrix of a NumPy design.

    The BLAKE2b-256 digest covers ``(n, p)`` as int64, the dtype name of ``X``,
    ``(lam_nw, tol)`` as float64, and every byte of ``X`` in C order. Any
    change to the design, or to the listed parameters, therefore yields a
    new key.

    Returns
    -------
    str
        64-character hexadecimal digest.
    """
    X_cache = np.asarray(X)
    if not X_cache.flags["C_CONTIGUOUS"]:
        X_cache = np.ascontiguousarray(X_cache)
    h = hashlib.blake2b(digest_size=32)
    h.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
    h.update(str(X_cache.dtype).encode("utf-8"))
    h.update(np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes())
    # Pass the contiguous buffer straight to the hasher: ``tobytes()`` would
    # materialise a full copy of the design just to hash it. The bytes fed to
    # the hash (and thus the digest) are unchanged.
    h.update(X_cache.view(np.uint8))
    return h.hexdigest()
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _debiased_m_key_from_sample(
    *,
    n: int,
    p: int,
    dtype_name: str,
    sample_block: np.ndarray,
    lam_nw: float,
    tol: float,
):
    """Hash key for the debiased M matrix computed from a sampled block of X.

    Only ``sample_block`` is hashed, never the full design, together with
    ``(n, p)``, ``dtype_name`` and ``(lam_nw, tol)``. The result is a
    BLAKE2b-256 hex digest. Intended for the Torch backend, where hashing
    the whole matrix is undesirable.
    """
    header = np.asarray([int(n), int(p)], dtype=np.int64).tobytes()
    dtype_tag = dtype_name.encode("utf-8")
    params = np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes()
    contiguous = (
        sample_block
        if sample_block.flags["C_CONTIGUOUS"]
        else np.ascontiguousarray(sample_block)
    )
    payload = contiguous.view(np.uint8).tobytes()

    digest = hashlib.blake2b(digest_size=32)
    for chunk in (header, dtype_tag, params, payload):
        digest.update(chunk)
    return digest.hexdigest()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _lasso_alpha_heuristic(y_centered: np.ndarray, n_features: int) -> float:
    """Penalty level ``sigma * sqrt(2 * log(p) / n)`` derived from the response.

    ``sigma`` is the standard deviation of ``y_centered`` (``ddof=1`` when
    there are at least two samples) floored at 1e-8. ``p`` is
    ``n_features`` floored at 2 and ``n`` is the sample count floored at 1.
    """
    n_obs = int(y_centered.shape[0])
    ddof = 1 if n_obs > 1 else 0
    noise_scale = max(float(np.std(y_centered, ddof=ddof)), 1e-8)
    log_p = np.log(max(2, int(n_features)))
    return float(noise_scale * np.sqrt(2.0 * log_p / max(1, n_obs)))
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _default_lasso_alpha_grid(
    X: np.ndarray,
    y: np.ndarray,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Descending, log-spaced Lasso alpha grid for NumPy inputs.

    The top of the grid is the largest of:

    * ``max_j |X_j^T y| / n``;
    * ``_lasso_alpha_heuristic(y, p)``;
    * ``1e-6``.

    With ``n_alphas <= 1`` only that top value is returned. Otherwise the
    grid runs geometrically down to ``max(alpha_min_ratio * top, 1e-6)``.
    ``y`` is presumably already centred; verify against callers.
    """
    n_obs = int(X.shape[0])
    score = np.abs(X.T @ y) / float(max(1, n_obs))
    top = float(np.max(score)) if score.size else 1.0
    top = max(top, _lasso_alpha_heuristic(y, n_features=int(X.shape[1])), 1e-6)

    if int(n_alphas) <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=int(n_alphas)).astype(np.float64)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _default_lasso_alpha_grid_backend(
    X,
    y,
    backend,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Descending, log-spaced Lasso alpha grid computed through ``backend``.

    Same rule as the NumPy version:

    * the top is the largest of ``max_j |X_j^T y| / n``,
      ``sigma * sqrt(2 log(max(2, p)) / n)`` (``sigma`` = sample std of
      ``y`` with ``ddof=1``, floored at 1e-8) and ``1e-6``;
    * the grid then decays geometrically to
      ``max(alpha_min_ratio * top, 1e-6)``.

    Only scalar results leave the backend. The returned grid is a NumPy
    float64 array.
    """
    X_b = backend.asarray(X, dtype=backend.float64)
    y_b = backend.asarray(y, dtype=backend.float64).reshape(-1)
    n_obs = int(X_b.shape[0])

    score = backend.abs(X_b.T @ y_b) / float(max(1, n_obs))
    # Length via ``shape`` so NumPy- and Torch-like arrays both work.
    n_score = int(score.shape[0]) if hasattr(score, 'shape') else len(score)
    top = float(backend.to_numpy(backend.max(score))) if n_score > 0 else 1.0

    sigma = 0.0
    if n_obs > 1:
        deviations = y_b - backend.mean(y_b)
        sample_var = backend.sum(deviations ** 2) / (n_obs - 1)
        sigma = float(backend.to_numpy(backend.sqrt(sample_var)))
    sigma = max(sigma, 1e-8)

    shrink = np.sqrt(2.0 * np.log(max(2, int(X_b.shape[1]))) / max(1, n_obs))
    top = max(top, float(sigma * shrink), 1e-6)

    if int(n_alphas) <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=int(n_alphas)).astype(np.float64)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _default_lasso_alpha_grid_cupy(
    X,
    y,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Descending, log-spaced Lasso alpha grid computed with CuPy on the GPU.

    The top of the grid is the largest of:

    * ``max_j |X_j^T y| / n``;
    * ``sigma * sqrt(2 log(max(2, p)) / n)``, where ``sigma`` is the std of
      ``y`` (``ddof=1`` when ``n > 1``), floored at 1e-8;
    * ``1e-6``.

    The returned grid is a NumPy float64 array.
    """
    import cupy as cp

    X_gpu = cp.asarray(X, dtype=cp.float64)
    y_gpu = cp.asarray(y, dtype=cp.float64).reshape(-1)
    n_obs = int(X_gpu.shape[0])

    score = cp.abs(X_gpu.T @ y_gpu) / float(max(1, n_obs))
    top = float(cp.max(score).item()) if int(score.size) > 0 else 1.0

    ddof = 1 if n_obs > 1 else 0
    sigma = max(float(cp.std(y_gpu, ddof=ddof).item()), 1e-8)

    shrink = np.sqrt(2.0 * np.log(max(2, int(X_gpu.shape[1]))) / max(1, n_obs))
    top = max(top, float(sigma * shrink), 1e-6)

    if int(n_alphas) <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=int(n_alphas)).astype(np.float64)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _normalize_cv_splits(cv_splits, n_samples: int):
    """Validate user-supplied CV splits and convert them to int64 index arrays.

    Parameters
    ----------
    cv_splits : iterable of (train, val) pairs, or None
        Each side may be an integer index array or a boolean mask of length
        ``n_samples``.
    n_samples : int
        Number of rows in the data; indices must lie in ``[0, n_samples)``.

    Returns
    -------
    list of (ndarray, ndarray) or None
        ``None`` when ``cv_splits`` is None. Otherwise a list of flattened
        int64 ``(train_idx, val_idx)`` pairs; pairs with an empty side are
        dropped.

    Raises
    ------
    ValueError
        Raised when:

        * an entry is not a 2-element tuple/list;
        * a boolean mask has the wrong length;
        * an index is out of range;
        * no non-empty split remains.
    """
    if cv_splits is None:
        return None

    n = int(n_samples)

    def _as_indices(side):
        # Convert one side of a split to a flat int64 index array.
        raw = np.asarray(side)
        if raw.dtype == np.bool_:
            # Casting a mask straight to int64 would turn it into 0/1 "indices";
            # select the positions of the True entries instead.
            mask = raw.reshape(-1)
            if int(mask.size) != n:
                raise ValueError("cv_splits boolean masks must have length n_samples")
            return np.flatnonzero(mask).astype(np.int64)
        return np.asarray(raw, dtype=np.int64).reshape(-1)

    folds = []
    for split in cv_splits:
        if not isinstance(split, (tuple, list)) or len(split) != 2:
            raise ValueError("Each cv_splits entry must be a (train_idx, val_idx) pair")

        train_idx = _as_indices(split[0])
        val_idx = _as_indices(split[1])

        if train_idx.size == 0 or val_idx.size == 0:
            continue

        if (
            bool(np.any(train_idx < 0))
            or bool(np.any(train_idx >= n))
            or bool(np.any(val_idx < 0))
            or bool(np.any(val_idx >= n))
        ):
            raise ValueError("cv_splits indices are out of range")

        folds.append((train_idx, val_idx))

    if len(folds) == 0:
        raise ValueError("cv_splits must contain at least one non-empty split")

    return folds
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _folds_are_complements(folds, n_samples: int) -> bool:
    """Return True when, in every fold, train is exactly the complement of validation.

    Each fold must partition ``range(n_samples)``: the two index sets are
    disjoint, have no duplicates, and together cover every sample.
    """
    total = int(n_samples)
    for fold_train, fold_val in folds:
        tr = np.asarray(fold_train, dtype=np.int64).reshape(-1)
        va = np.asarray(fold_val, dtype=np.int64).reshape(-1)

        # Cheap size check first: a partition must account for every sample once.
        if int(tr.size + va.size) != total:
            return False

        covered = np.zeros(total, dtype=bool)
        covered[tr] = True
        if covered[va].any():
            return False  # overlap between train and validation
        covered[va] = True
        if not covered.all():
            return False  # some sample appears in neither side
    return True
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _array_identity_token(x: Any) -> Tuple[Any, ...]:
    """Content-based identity token for array-valued cache keys.

    Parameters
    ----------
    x : array-like, cupy.ndarray, torch.Tensor or None
        Array to fingerprint.

    Returns
    -------
    tuple
        ``("none",)`` for ``None``. Otherwise
        ``(backend_tag, blake2b_hex_digest, shape, dtype_str)``.

    Notes
    -----
    * NumPy / array-like input is converted to float64 and *every* byte is
      hashed. Hashing only a row sample would let two arrays that differ in
      unsampled rows share a token, so the LassoCV cache could return a
      result computed on different data. Full hashing matches
      ``_debiased_m_key_from_numpy_design``, which also hashes the whole
      design. It also works for 0-d input, which has no row axis to sample.
    * CuPy and Torch inputs are hashed from at most 100 evenly spaced rows.
      Those rows are gathered on the device before transfer, so the full
      array is never copied to the host. These tokens can therefore still
      collide for arrays that differ only in unsampled rows.
    * If the CuPy/Torch branch fails for any reason (library missing, 0-d
      tensor, ...), the input falls through to the NumPy branch.
    """
    if x is None:
        return ("none",)

    def _digest(arr_np: np.ndarray) -> str:
        # Hash the contiguous buffer directly (no intermediate ``tobytes`` copy).
        hasher = hashlib.blake2b(digest_size=16)
        hasher.update(np.ascontiguousarray(arr_np))
        return hasher.hexdigest()

    try:
        import cupy as cp

        if isinstance(x, cp.ndarray):
            # Sample on the GPU first, then transfer only the sampled rows.
            n_rows = x.shape[0]
            if n_rows <= 100:
                arr_np = cp.asnumpy(x).astype(np.float64)
            else:
                idx = cp.linspace(0, n_rows - 1, 100, dtype=cp.int64)
                arr_np = cp.asnumpy(x[idx]).astype(np.float64)
            return ("cupy", _digest(arr_np), tuple(int(v) for v in x.shape), str(x.dtype))
    except Exception:
        pass  # best effort: fall through to the Torch / NumPy branches

    try:
        import torch

        if isinstance(x, torch.Tensor):
            # Sample on the device first, then transfer only the sampled rows.
            n_rows = x.shape[0]
            if n_rows <= 100:
                rows = x.detach()
            else:
                idx = torch.linspace(0, n_rows - 1, 100, dtype=torch.long, device=x.device)
                rows = x[idx].detach()
            # Cast before leaving torch so dtypes NumPy cannot represent
            # (e.g. bfloat16) can still be hashed.
            arr_np = rows.to(torch.float64).cpu().numpy()
            return ("torch", _digest(arr_np), tuple(int(v) for v in x.shape), str(x.dtype))
    except Exception:
        pass  # best effort: fall through to the NumPy branch

    arr = np.asarray(x, dtype=np.float64)
    return ("numpy", _digest(arr), tuple(int(v) for v in arr.shape), str(arr.dtype))
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _alphas_signature(alphas: np.ndarray) -> str:
    """Return a 128-bit BLAKE2b hex digest of ``alphas`` flattened to contiguous float64."""
    flat = np.asarray(alphas, dtype=np.float64).reshape(-1)
    hasher = hashlib.blake2b(digest_size=16)
    hasher.update(np.ascontiguousarray(flat).tobytes())
    return hasher.hexdigest()
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _folds_signature(folds) -> str:
    """Order-sensitive 128-bit BLAKE2b hex digest of a sequence of CV folds.

    Each ``(train_idx, val_idx)`` fold contributes
    ``train_bytes + b"|" + val_bytes + b";"``, where both index arrays are
    flattened to contiguous int64.
    """
    def _int64_bytes(indices) -> bytes:
        flat = np.asarray(indices, dtype=np.int64).reshape(-1)
        return np.ascontiguousarray(flat).tobytes()

    digest = hashlib.blake2b(digest_size=16)
    for fold_train, fold_val in folds:
        train_bytes = _int64_bytes(fold_train)
        val_bytes = _int64_bytes(fold_val)
        for part in (train_bytes, b"|", val_bytes, b";"):
            digest.update(part)
    return digest.hexdigest()
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _make_lasso_cv_auto_cache_key(
    *,
    X,
    y,
    sample_weight,
    alpha_grid: np.ndarray,
    folds,
    fit_intercept: bool,
    use_gpu: bool,
    max_iter: int,
    tol: float,
    cpu_solver: str,
    cv_method: str,
    cd_kkt_check_every: Optional[int],
    gpu_cv_mixed_precision: bool,
) -> Tuple[Any, ...]:
    """Assemble the hashable key under which an automatic LassoCV result is cached.

    The key is a flat tuple made of, in order:

    * a version tag;
    * content tokens for ``X``, ``y`` and ``sample_weight``;
    * digests of the alpha grid and the folds;
    * the normalised solver settings (solver/method names lower-cased).
    """
    data_tokens = (
        _array_identity_token(X),
        _array_identity_token(y),
        _array_identity_token(sample_weight),
    )
    grid_and_folds = (_alphas_signature(alpha_grid), _folds_signature(folds))
    settings = (
        bool(fit_intercept),
        bool(use_gpu),
        int(max_iter),
        float(tol),
        str(cpu_solver).lower(),
        str(cv_method).lower(),
        None if cd_kkt_check_every is None else int(cd_kkt_check_every),
        bool(gpu_cv_mixed_precision),
    )
    return ("lasso_cv_auto_v1",) + data_tokens + grid_and_folds + settings
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _clone_lasso_cv_cache_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return an independent copy of the four LassoCV cache fields.

    ``alpha`` becomes a Python float. ``alphas``, ``mse_path`` and
    ``mean_mse`` become freshly allocated float64 ndarrays, so cached arrays
    are never shared with callers. Any other keys are dropped.
    """
    cloned: Dict[str, Any] = {"alpha": float(payload["alpha"])}
    for field in ("alphas", "mse_path", "mean_mse"):
        cloned[field] = np.asarray(payload[field], dtype=np.float64).copy()
    return cloned
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def _lasso_cv_cache_get(cache_key: Optional[Tuple[Any, ...]]) -> Optional[Dict[str, Any]]:
    """Fetch a private copy of a cached LassoCV result.

    Returns None on a miss, when ``cache_key`` is None, or when the cache is
    disabled (non-positive max size). A hit is marked most-recently-used.
    """
    caching_disabled = cache_key is None or _LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0
    if caching_disabled:
        return None

    with _cache_lock:
        entry = _LASSO_CV_ALPHA_CACHE.get(cache_key)
        if entry is not None:
            _LASSO_CV_ALPHA_CACHE.move_to_end(cache_key)
            return _clone_lasso_cv_cache_payload(entry)
    return None
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def _lasso_cv_cache_put(cache_key: Optional[Tuple[Any, ...]], payload: Dict[str, Any]) -> None:
    """Insert a private copy of ``payload`` into the bounded LassoCV LRU cache.

    This is a no-op when ``cache_key`` is None or the configured max size is
    non-positive. The new entry becomes most-recently-used. The oldest
    entries are then evicted until at most ``_LASSO_CV_ALPHA_CACHE_MAXSIZE``
    remain.
    """
    if cache_key is None or _LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0:
        return

    # Copy the arrays before taking the lock to keep the critical section short.
    snapshot = _clone_lasso_cv_cache_payload(payload)
    capacity = int(_LASSO_CV_ALPHA_CACHE_MAXSIZE)

    with _cache_lock:
        cache = _LASSO_CV_ALPHA_CACHE
        cache[cache_key] = snapshot
        cache.move_to_end(cache_key)
        while len(cache) > capacity:
            cache.popitem(last=False)
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _adaptive_gpu_check_every(
    *,
    base_check_every: int,
    iteration: int,
    max_iter: int,
    active_ratio: float,
) -> int:
    """Choose how many iterations to wait between global GPU convergence checks.

    The interval depends on the active ratio, clipped to ``[0, 1]``:

    ===========  ==========================
    ratio        interval
    ===========  ==========================
    >= 0.75      ``max(base, 16)``
    >= 0.40      ``max(base, 12)``
    >= 0.15      ``max(base, 4)``
    below 0.15   ``max(2, base // 2)``
    ===========  ==========================

    Near the end of the iteration budget the interval is capped: at most 4
    once 75% of ``max_iter`` is used, and at most 2 after 90%.
    """
    base = max(1, int(base_check_every))
    ratio = float(max(0.0, min(1.0, active_ratio)))

    for min_ratio, floor_interval in ((0.75, 16), (0.40, 12), (0.15, 4)):
        if ratio >= min_ratio:
            interval = max(base, floor_interval)
            break
    else:
        interval = max(2, base // 2)

    progress = float(iteration + 1) / float(max(1, int(max_iter)))
    for cutoff, cap in ((0.90, 2), (0.75, 4)):
        if progress >= cutoff:
            interval = min(interval, cap)
            break

    return max(1, int(interval))
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def _soft_threshold_numpy(x: np.ndarray, gamma: float) -> np.ndarray:
    """Elementwise soft-thresholding ``sign(x) * max(|x| - gamma, 0)``.

    ``gamma`` is cast to float64 and may be a scalar or broadcastable array.
    """
    shrunk = np.abs(x) - np.asarray(gamma, dtype=np.float64)
    return np.sign(x) * np.maximum(shrunk, 0.0)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def _soft_threshold_scalar(x: float, gamma: float) -> float:
    """Scalar soft-threshold.

    Returns 0 when ``|x| <= gamma``; otherwise shrinks ``x`` toward zero by
    ``gamma``.
    """
    magnitude = abs(float(x))
    threshold = float(gamma)
    return 0.0 if magnitude <= threshold else float(np.sign(x) * (magnitude - threshold))
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
if _NUMBA_AVAILABLE:

    @njit(cache=True)
    def _soft_threshold_scalar_numba(x: float, gamma: float) -> float:
        """Scalar soft-threshold S(x, gamma) usable inside numba kernels."""
        ax = abs(x)
        if ax <= gamma:
            return 0.0
        if x >= 0.0:
            return ax - gamma
        return -(ax - gamma)

    @njit(cache=True)
    def _solve_lasso_path_cpu_cd_numba_impl(
        XtX: np.ndarray,
        Xty: np.ndarray,
        n_samples: int,
        alphas_desc: np.ndarray,
        max_iter: int,
        tol: float,
        stopping_is_kkt: bool,
        cd_kkt_check_every: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Coordinate-descent Lasso path computed from Gram quantities (numba kernel).

        Assuming ``XtX = X^T X`` and ``Xty = X^T y``, each alpha solves
        ``min_b (1/2n)||y - X b||^2 + alpha ||b||_1`` by cyclic coordinate
        descent over an active set.

        * Coefficients warm-start from the previous alpha, so ``alphas_desc``
          is expected in descending order.
        * ``grad = XtX b - Xty`` is maintained incrementally.
        * Every ``cd_kkt_check_every`` sweeps (and whenever the L1 coefficient
          change drops below ``tol``, or on the last sweep) a full KKT scan
          measures the optimality violation. The scan also adds violating
          inactive features to the active set.

        Stopping:

        * ``stopping_is_kkt`` True: stop once the scanned KKT violation is
          below ``tol``.
        * Otherwise: stop once the L1 coefficient change is below ``tol``
          and no inactive feature was added.

        Returns ``(coefs_path[n_alphas, n_features], n_iters[n_alphas])``.
        """
        n_features = XtX.shape[0]
        n_alphas = alphas_desc.shape[0]

        coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
        n_iters = np.zeros((n_alphas,), dtype=np.int32)

        coef = np.zeros((n_features,), dtype=np.float64)
        grad = -Xty.copy()

        X_sq_norms = np.empty((n_features,), dtype=np.float64)
        for j in range(n_features):
            X_sq_norms[j] = XtX[j, j]

        n_samp = float(max(1, n_samples))
        alpha_scaled_desc = np.empty((n_alphas,), dtype=np.float64)
        for idx in range(n_alphas):
            alpha_scaled_desc[idx] = alphas_desc[idx] * n_samp

        active_mask = np.zeros((n_features,), dtype=np.bool_)
        check_every = max(1, int(cd_kkt_check_every))

        for alpha_idx in range(n_alphas):
            alpha = float(alphas_desc[alpha_idx])
            alpha_scaled = float(alpha_scaled_desc[alpha_idx])
            if alpha_idx > 0:
                prev_alpha_scaled = float(alpha_scaled_desc[alpha_idx - 1])
            else:
                prev_alpha_scaled = alpha_scaled

            # Strong-rule style screening threshold 2*alpha_k - alpha_{k-1}.
            # NOTE(review): compares against |X^T y| (correlation at b = 0)
            # rather than the current gradient; the KKT scans below add back
            # any feature this screening misses.
            strong_thresh = 2.0 * alpha_scaled - prev_alpha_scaled
            if strong_thresh < 0.0:
                strong_thresh = 0.0

            any_active = False
            max_abs_xty = -1.0
            max_abs_xty_idx = 0
            for j in range(n_features):
                abs_xty = abs(Xty[j])
                if abs_xty >= strong_thresh:
                    active_mask[j] = True
                    any_active = True
                if abs_xty > max_abs_xty:
                    max_abs_xty = abs_xty
                    max_abs_xty_idx = j

            # Guarantee at least one active feature.
            if not any_active:
                active_mask[max_abs_xty_idx] = True

            converged = False

            for iteration in range(int(max_iter)):
                coef_delta_l1 = 0.0

                for j in range(n_features):
                    if not active_mask[j]:
                        continue

                    denom = float(X_sq_norms[j])
                    old_val = float(coef[j])

                    if denom > 1e-10:
                        rho_j = -float(grad[j]) + denom * old_val
                        new_val = _soft_threshold_scalar_numba(rho_j, alpha_scaled) / denom
                    else:
                        new_val = 0.0

                    delta = new_val - old_val
                    if delta != 0.0:
                        coef[j] = new_val
                        coef_delta_l1 += abs(delta)
                        # Rank-1 gradient update: grad += XtX[:, j] * delta.
                        for row_idx in range(n_features):
                            grad[row_idx] += XtX[row_idx, j] * delta

                should_kkt_scan = (
                    ((iteration + 1) % check_every == 0)
                    or (coef_delta_l1 < float(tol))
                    or (iteration + 1 == int(max_iter))
                )

                violation = 0.0
                has_inactive_violation = False

                if should_kkt_scan:
                    for j in range(n_features):
                        g = grad[j] / n_samp
                        cj = coef[j]
                        if cj > 0.0:
                            # Nonzero coefficient: stationarity needs g == -alpha.
                            # Previously this was clipped to 0, so "kkt" stopping
                            # could fire before active coefficients were optimal.
                            v = abs(g + alpha)
                        elif cj < 0.0:
                            v = abs(g - alpha)
                        else:
                            # Zero coefficient: subgradient condition |g| <= alpha.
                            v = abs(g) - alpha
                            if v < 0.0:
                                v = 0.0
                        if v > violation:
                            violation = v
                        # Nonzero coefficients are always already active, so only
                        # zero-coefficient features can be newly activated here.
                        if v > float(tol) and (not active_mask[j]):
                            active_mask[j] = True
                            has_inactive_violation = True

                if stopping_is_kkt:
                    if should_kkt_scan and violation < float(tol):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break
                else:
                    if coef_delta_l1 < float(tol) and (not has_inactive_violation):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break

            if not converged:
                n_iters[alpha_idx] = int(max_iter)

            # Record the solution and keep its support active for the next alpha.
            for j in range(n_features):
                coefs_path[alpha_idx, j] = coef[j]
                if abs(coef[j]) > 0.0:
                    active_mask[j] = True

        return coefs_path, n_iters
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def _solve_lasso_path_cpu_cd_numba(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cd_kkt_check_every: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Python entry point for the numba coordinate-descent Lasso path kernel.

    Coerces the Gram inputs and alpha grid to C-contiguous float64. It also
    maps ``stopping`` to the kernel's boolean flag: ``"kkt"``
    (case-insensitive) selects KKT-based stopping, and any other value uses
    coefficient-change stopping.

    ``_solve_lasso_path_cpu_cd_numba_impl`` is only defined when numba
    imported successfully (``_NUMBA_AVAILABLE``).
    """
    gram = np.ascontiguousarray(XtX, dtype=np.float64)
    rhs = np.ascontiguousarray(Xty, dtype=np.float64)
    grid = np.ascontiguousarray(np.asarray(alphas_desc, dtype=np.float64))
    use_kkt = str(stopping).lower() == "kkt"
    return _solve_lasso_path_cpu_cd_numba_impl(
        gram,
        rhs,
        int(n_samples),
        grid,
        int(max_iter),
        float(tol),
        bool(use_kkt),
        int(cd_kkt_check_every),
    )
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def _normalize_lassocv_method(method: str) -> str:
    """Map a CV optimization profile name (or alias) to its canonical form.

    Matching ignores case and surrounding whitespace.

    * ``"default"`` and ``"classic"`` map to ``"standard"``.
    * ``"glmnet_cv"`` and ``"glmnet.cv"`` map to ``"glmnet"``.

    Raises
    ------
    ValueError
        If the name is not ``"standard"``, ``"glmnet"`` or one of the aliases.
    """
    name = str(method).strip().lower()
    if name in ("default", "classic"):
        name = "standard"
    elif name in ("glmnet_cv", "glmnet.cv"):
        name = "glmnet"

    if name == "standard" or name == "glmnet":
        return name
    raise ValueError("method must be one of: 'standard', 'glmnet'")
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
def _normalize_cd_kkt_check_every(cd_kkt_check_every: Optional[int]) -> Optional[int]:
|
|
632
|
+
"""Validate optional coordinate-descent global KKT scan cadence."""
|
|
633
|
+
if cd_kkt_check_every is None:
|
|
634
|
+
return None
|
|
635
|
+
value = int(cd_kkt_check_every)
|
|
636
|
+
if value <= 0:
|
|
637
|
+
raise ValueError("cd_kkt_check_every must be a positive integer or None")
|
|
638
|
+
return value
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def _solve_lasso_path_cpu_fista_batched_from_gram(
|
|
642
|
+
XtX: np.ndarray,
|
|
643
|
+
Xty: np.ndarray,
|
|
644
|
+
*,
|
|
645
|
+
n_samples: int,
|
|
646
|
+
alphas_desc: np.ndarray,
|
|
647
|
+
max_iter: int,
|
|
648
|
+
tol: float,
|
|
649
|
+
stopping: str,
|
|
650
|
+
lipschitz_L: Optional[float] = None,
|
|
651
|
+
check_every: int = 2,
|
|
652
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
653
|
+
"""Solve descending-alpha Lasso path with a batched CPU FISTA update."""
|
|
654
|
+
n_features = int(XtX.shape[0])
|
|
655
|
+
n_alphas = int(alphas_desc.shape[0])
|
|
656
|
+
|
|
657
|
+
coefs = np.zeros((n_features, n_alphas), dtype=np.float64)
|
|
658
|
+
yk = coefs.copy()
|
|
659
|
+
tk = np.ones((n_alphas,), dtype=np.float64)
|
|
660
|
+
n_iters = np.zeros((n_alphas,), dtype=np.int32)
|
|
661
|
+
|
|
662
|
+
if lipschitz_L is not None:
|
|
663
|
+
L = float(lipschitz_L)
|
|
664
|
+
else:
|
|
665
|
+
try:
|
|
666
|
+
eigvals = np.linalg.eigvalsh(XtX)
|
|
667
|
+
L = float(eigvals[-1] / float(max(1, n_samples)))
|
|
668
|
+
except Exception:
|
|
669
|
+
row_sum_bound = float(np.max(np.sum(np.abs(XtX), axis=1)) / float(max(1, n_samples)))
|
|
670
|
+
L = max(row_sum_bound, 1e-12)
|
|
671
|
+
|
|
672
|
+
if L <= 0.0:
|
|
673
|
+
return coefs.T, n_iters
|
|
674
|
+
|
|
675
|
+
n_samp = float(max(1, n_samples))
|
|
676
|
+
step = 1.0 / L
|
|
677
|
+
alphas_desc = np.asarray(alphas_desc, dtype=np.float64)
|
|
678
|
+
thresholds = alphas_desc * step
|
|
679
|
+
stopping_name = str(stopping).lower()
|
|
680
|
+
check_every = max(1, int(check_every))
|
|
681
|
+
|
|
682
|
+
active = np.arange(n_alphas, dtype=np.int64)
|
|
683
|
+
|
|
684
|
+
for iteration in range(int(max_iter)):
|
|
685
|
+
if active.size == 0:
|
|
686
|
+
break
|
|
687
|
+
|
|
688
|
+
y_active = yk[:, active]
|
|
689
|
+
coef_old = coefs[:, active]
|
|
690
|
+
|
|
691
|
+
grad = (XtX @ y_active - Xty.reshape(-1, 1)) / n_samp
|
|
692
|
+
thresh = thresholds[active].reshape(1, -1)
|
|
693
|
+
coef_new = _soft_threshold_numpy(y_active - step * grad, thresh)
|
|
694
|
+
|
|
695
|
+
t_old = tk[active]
|
|
696
|
+
t_new = (1.0 + np.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
|
|
697
|
+
beta = (t_old - 1.0) / t_new
|
|
698
|
+
y_new = coef_new + beta.reshape(1, -1) * (coef_new - coef_old)
|
|
699
|
+
|
|
700
|
+
coefs[:, active] = coef_new
|
|
701
|
+
yk[:, active] = y_new
|
|
702
|
+
tk[active] = t_new
|
|
703
|
+
|
|
704
|
+
should_check = ((iteration + 1) % check_every == 0) or (iteration + 1 == int(max_iter))
|
|
705
|
+
if not should_check:
|
|
706
|
+
continue
|
|
707
|
+
|
|
708
|
+
if stopping_name == "kkt":
|
|
709
|
+
grad_sse = (XtX @ coef_new - Xty.reshape(-1, 1)) / n_samp
|
|
710
|
+
viol = np.max(
|
|
711
|
+
np.maximum(
|
|
712
|
+
np.abs(grad_sse) - alphas_desc[active].reshape(1, -1),
|
|
713
|
+
0.0,
|
|
714
|
+
),
|
|
715
|
+
axis=0,
|
|
716
|
+
)
|
|
717
|
+
converged_local = viol < float(tol)
|
|
718
|
+
else:
|
|
719
|
+
delta = np.sum(np.abs(coef_new - coef_old), axis=0)
|
|
720
|
+
converged_local = delta < float(tol)
|
|
721
|
+
|
|
722
|
+
if not np.any(converged_local):
|
|
723
|
+
continue
|
|
724
|
+
|
|
725
|
+
done = active[converged_local]
|
|
726
|
+
n_iters[done] = int(iteration) + 1
|
|
727
|
+
yk[:, done] = coefs[:, done]
|
|
728
|
+
active = active[~converged_local]
|
|
729
|
+
|
|
730
|
+
if active.size > 0:
|
|
731
|
+
n_iters[active] = int(max_iter)
|
|
732
|
+
|
|
733
|
+
return coefs.T, n_iters
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
def _solve_lasso_path_gpu_fista_batched_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso path with a batched GPU FISTA update.

    CuPy version of the batched FISTA path solver. Column ``k`` of the
    ``(n_features, n_alphas)`` working arrays holds the iterate for
    ``alphas_desc[k]``. Columns leave the active set once their convergence
    test passes. The test cadence is chosen by ``_adaptive_gpu_check_every``,
    starting from ``check_every``.

    Parameters
    ----------
    XtX, Xty : cupy.ndarray
        Gram matrix ``X.T @ X`` and cross-product ``X.T @ y``. ``XtX.dtype``
        sets the working dtype.
    n_samples : int
        Number of samples; the gradient is divided by ``max(1, n_samples)``.
    alphas_desc : array-like
        Regularization strengths. NumPy arrays, sequences and objects with a
        ``.get()`` method (CuPy arrays) are accepted.
    max_iter : int
        Maximum number of FISTA iterations.
    tol : float
        Convergence tolerance.
    stopping : str
        ``"kkt"`` (case-insensitive) for the KKT bound test, otherwise the L1
        change of the coefficients is tested.
    lipschitz_L : float, optional
        Gradient Lipschitz constant. By default it is the largest eigenvalue
        of ``XtX`` over ``n``, with a row-sum bound as fallback.
    check_every : int
        Base cadence of convergence tests (min 1).

    Returns
    -------
    coefs : cupy.ndarray of shape (n_alphas, n_features)
    n_iters : numpy int32 ndarray of shape (n_alphas,)
        Iteration at which each alpha converged, or ``max_iter``. All zeros
        when the Lipschitz constant is non-positive.
    """
    import cupy as cp

    # Convert the alpha grid on the host first so ``.shape`` works for any
    # array-like input (CuPy arrays are pulled back with ``.get()``).
    if hasattr(alphas_desc, "get"):
        alphas_desc = alphas_desc.get()
    alphas_cpu = np.asarray(alphas_desc, dtype=np.float64).reshape(-1)

    dtype = XtX.dtype
    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_cpu.shape[0])
    max_iter = int(max_iter)
    tol = float(tol)
    n_samp = float(max(1, n_samples))

    coefs = cp.zeros((n_features, n_alphas), dtype=dtype)
    yk = coefs.copy()
    tk = cp.ones((n_alphas,), dtype=dtype)
    n_iters_gpu = cp.zeros((n_alphas,), dtype=cp.int32)

    if lipschitz_L is not None:
        L = cp.array(float(lipschitz_L), dtype=dtype)
    else:
        try:
            L = cp.linalg.eigvalsh(XtX)[-1] / n_samp
        except Exception:
            # Best-effort fallback: max absolute row sum bounds the spectral norm.
            row_sum_bound = cp.max(cp.sum(cp.abs(XtX), axis=1)) / n_samp
            L = cp.maximum(row_sum_bound, cp.asarray(1e-12, dtype=dtype))

    if float(cp.asnumpy(L)) <= 0.0:
        return coefs.T, np.zeros((n_alphas,), dtype=np.int32)

    step = 1.0 / L
    alpha_gpu = cp.asarray(alphas_cpu, dtype=dtype)
    thresholds = alpha_gpu * step
    Xty_col = Xty.reshape(-1, 1)
    use_kkt = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))

    active_gpu = cp.arange(n_alphas, dtype=cp.int32)

    for iteration in range(max_iter):
        if active_gpu.size == 0:
            break

        y_active = yk[:, active_gpu]
        coef_old = coefs[:, active_gpu]

        grad = (XtX @ y_active - Xty_col) / n_samp
        # Compute the gradient step once and reuse it for sign and magnitude.
        z = y_active - step * grad
        coef_new = cp.sign(z) * cp.maximum(cp.abs(z) - thresholds[active_gpu].reshape(1, -1), 0.0)

        t_old = tk[active_gpu]
        t_new = (1.0 + cp.sqrt(1.0 + 4.0 * t_old ** 2)) / 2.0
        momentum = (t_old - 1.0) / t_new
        y_new = coef_new + momentum.reshape(1, -1) * (coef_new - coef_old)

        coefs[:, active_gpu] = coef_new
        yk[:, active_gpu] = y_new
        tk[active_gpu] = t_new

        active_ratio = float(active_gpu.size) / float(max(1, n_alphas))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=max_iter,
            active_ratio=active_ratio,
        )
        if (iteration + 1) % check_every_eff != 0 and iteration + 1 != max_iter:
            continue

        if use_kkt:
            # NOTE(review): bound-only KKT measure (|grad| <= alpha); it does
            # not test the equality condition for nonzero coefficients.
            grad_sse = (XtX @ coef_new - Xty_col) / n_samp
            viol = cp.max(
                cp.maximum(cp.abs(grad_sse) - alpha_gpu[active_gpu].reshape(1, -1), 0.0),
                axis=0,
            )
            converged_local_gpu = viol < tol
        else:
            converged_local_gpu = cp.sum(cp.abs(coef_new - coef_old), axis=0) < tol

        done_gpu = active_gpu[converged_local_gpu]
        if done_gpu.size == 0:
            continue

        n_iters_gpu[done_gpu] = iteration + 1
        yk[:, done_gpu] = coefs[:, done_gpu]
        active_gpu = active_gpu[~converged_local_gpu]

    if active_gpu.size > 0:
        n_iters_gpu[active_gpu] = max_iter

    return coefs.T, cp.asnumpy(n_iters_gpu)
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
def _solve_lasso_path_gpu_fista_multi_fold_from_gram(
    XtX_batch,
    Xty_batch,
    *,
    n_samples_vec,
    alphas_desc,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso paths for all folds together on GPU.

    Every (fold, alpha) pair is iterated with FISTA in one batched update.
    ``XtX_batch`` is indexed as ``(n_folds, n_features, n_features)``, and
    ``Xty_batch`` is reshaped to ``(n_folds, n_features, 1)``. Converged pairs
    are frozen with a boolean mask rather than being removed from the batch.

    Parameters
    ----------
    XtX_batch, Xty_batch : cupy.ndarray
        Stacked per-fold Gram matrices and cross-products. ``XtX_batch.dtype``
        sets the working dtype.
    n_samples_vec : array-like
        One sample count per fold (CuPy arrays are pulled back with ``.get()``).
        Counts below 1 are treated as 1, mirroring ``max(1, n_samples)`` in the
        single-fold solvers.
    alphas_desc : array-like
        Regularization strengths shared by all folds (``.get()`` supported).
    max_iter : int
        Maximum number of FISTA iterations.
    tol : float
        Convergence tolerance.
    stopping : str
        ``"kkt"`` (case-insensitive) for the KKT bound test, otherwise the L1
        change of the coefficients is tested.
    lipschitz_L : float, optional
        Lipschitz constant used for every fold. By default, per-fold largest
        eigenvalue over the sample count, with a row-sum fallback.
    check_every : int
        Base cadence of convergence tests (min 1), adapted by
        ``_adaptive_gpu_check_every``.

    Returns
    -------
    coefs : cupy.ndarray of shape (n_folds, n_alphas, n_features)
    n_iters : numpy int32 ndarray of shape (n_folds, n_alphas)

    Raises
    ------
    ValueError
        If ``n_samples_vec`` does not have one entry per fold.

    Note: Fused kernel optimization is disabled for multi-fold solver due to
    dtype complexity. The single-fold Lasso solver uses fused kernels.
    """
    import cupy as cp

    dtype = XtX_batch.dtype
    n_folds = int(XtX_batch.shape[0])
    n_features = int(XtX_batch.shape[1])

    n_vec_host = n_samples_vec.get() if hasattr(n_samples_vec, "get") else n_samples_vec
    n_vec_cpu = np.asarray(n_vec_host, dtype=np.float64).reshape(-1)
    if n_vec_cpu.size != n_folds:
        raise ValueError("n_samples_vec must have one entry per fold")
    # A zero count would divide by zero; clamp like ``max(1, n_samples)``.
    n_vec = cp.asarray(np.maximum(n_vec_cpu, 1.0), dtype=dtype)

    # Convert the alpha grid before reading ``.shape`` so any array-like works.
    alphas_host = alphas_desc.get() if hasattr(alphas_desc, "get") else alphas_desc
    alphas_cpu = np.asarray(alphas_host, dtype=np.float64).reshape(-1)
    n_alphas = int(alphas_cpu.shape[0])

    coefs = cp.zeros((n_folds, n_features, n_alphas), dtype=dtype)
    yk = coefs.copy()
    tk = cp.ones((n_folds, n_alphas), dtype=dtype)
    n_iters_gpu = cp.zeros((n_folds, n_alphas), dtype=cp.int32)

    if lipschitz_L is not None:
        L = cp.full((n_folds,), float(lipschitz_L), dtype=dtype)
    else:
        try:
            L = cp.linalg.eigvalsh(XtX_batch)[:, -1] / n_vec
        except Exception:
            # Best-effort fallback: max absolute row sum bounds the spectral norm.
            L = cp.max(cp.sum(cp.abs(XtX_batch), axis=2), axis=1) / n_vec
    # A non-positive bound (e.g. an all-zero Gram matrix for some fold) would
    # make the step infinite and turn 0 * inf into NaN. Clamp it the way the
    # row-sum fallback always did.
    L = cp.maximum(L, cp.asarray(1e-12, dtype=dtype))

    step = 1.0 / L.reshape(n_folds, 1, 1)
    alpha_gpu = cp.asarray(alphas_cpu, dtype=dtype).reshape(1, 1, n_alphas)
    thresholds = alpha_gpu * step

    Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
    n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
    use_kkt = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))
    max_iter = int(max_iter)
    tol = float(tol)
    total_pairs = n_folds * n_alphas

    # Boolean (fold, alpha) mask of paths that are still being iterated.
    active_gpu = cp.ones((n_folds, n_alphas), dtype=cp.bool_)
    active_count = int(total_pairs)

    for iteration in range(max_iter):
        if active_count == 0:
            break

        active_expanded = active_gpu[:, cp.newaxis, :]

        # ``coefs`` is only ever rebound (never written in place), so holding
        # a reference is enough to remember the previous iterate.
        coef_old = coefs
        grad = (cp.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded

        # Proximal step: soft thresholding
        yk_step = yk - step * grad
        coef_candidate = cp.sign(yk_step) * cp.maximum(cp.abs(yk_step) - thresholds, 0.0)
        coefs = cp.where(active_expanded, coef_candidate, coefs)

        t_new = (1.0 + cp.sqrt(1.0 + 4.0 * tk ** 2)) / 2.0
        momentum = (tk - 1.0) / t_new
        y_candidate = coefs + momentum[:, cp.newaxis, :] * (coefs - coef_old)
        yk = cp.where(active_expanded, y_candidate, yk)
        tk = cp.where(active_gpu, t_new, tk)

        active_ratio = float(active_count) / float(max(1, total_pairs))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=max_iter,
            active_ratio=active_ratio,
        )
        if (iteration + 1) % check_every_eff != 0 and iteration + 1 != max_iter:
            continue

        if use_kkt:
            grad_sse = (cp.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
            violation = cp.max(cp.maximum(cp.abs(grad_sse) - alpha_gpu, 0.0), axis=1)
            converged_local_gpu = violation < tol
        else:
            delta = cp.sum(cp.abs(coefs - coef_old), axis=1)
            converged_local_gpu = delta < tol

        newly_done_gpu = active_gpu & converged_local_gpu
        done_count = int(cp.count_nonzero(newly_done_gpu).item())
        if done_count == 0:
            continue

        n_iters_gpu[newly_done_gpu] = iteration + 1
        yk = cp.where(newly_done_gpu[:, cp.newaxis, :], coefs, yk)
        active_gpu = active_gpu & (~converged_local_gpu)
        active_count -= done_count

    n_iters_gpu[active_gpu] = max_iter

    return cp.transpose(coefs, (0, 2, 1)), cp.asnumpy(n_iters_gpu)
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
    XtX_batch,
    Xty_batch,
    *,
    n_samples_vec,
    alphas_desc,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso paths for all folds together on Torch GPU.

    Mirror of _solve_lasso_path_gpu_fista_multi_fold_from_gram for Torch backend.

    Every (fold, alpha) pair is iterated with FISTA in one batched update on
    ``XtX_batch.device``. Converged pairs are frozen with a boolean mask.

    Parameters
    ----------
    XtX_batch, Xty_batch : torch.Tensor
        Stacked per-fold Gram matrices, indexed ``(n_folds, n_features, ...)``,
        and cross-products, reshaped to ``(n_folds, n_features, 1)``.
    n_samples_vec : array-like
        One sample count per fold (converted with ``_to_numpy``). Counts below
        1 are treated as 1.
    alphas_desc : array-like
        Regularization strengths shared by all folds (converted with
        ``_to_numpy``).
    max_iter, tol, stopping, lipschitz_L, check_every
        As in the CuPy multi-fold solver.

    Returns
    -------
    coefs : torch.Tensor of shape (n_folds, n_alphas, n_features)
    n_iters : numpy int32 ndarray of shape (n_folds, n_alphas)

    Raises
    ------
    ValueError
        If ``n_samples_vec`` does not have one entry per fold.
    """
    import torch

    device = XtX_batch.device
    dtype = XtX_batch.dtype
    n_folds = int(XtX_batch.shape[0])
    n_features = int(XtX_batch.shape[1])

    n_vec_cpu = np.asarray(_to_numpy(n_samples_vec), dtype=np.float64).reshape(-1)
    if n_vec_cpu.size != n_folds:
        raise ValueError("n_samples_vec must have one entry per fold")
    # A zero count would divide by zero; clamp like ``max(1, n_samples)``.
    n_vec = torch.from_numpy(np.maximum(n_vec_cpu, 1.0)).to(dtype=dtype, device=device)

    # Convert the alpha grid before reading ``.shape`` so any array-like works.
    alphas_cpu = np.asarray(_to_numpy(alphas_desc), dtype=np.float64).reshape(-1)
    n_alphas = int(alphas_cpu.shape[0])

    coefs = torch.zeros((n_folds, n_features, n_alphas), dtype=dtype, device=device)
    yk = coefs.clone()
    tk = torch.ones((n_folds, n_alphas), dtype=dtype, device=device)
    n_iters_gpu = torch.zeros((n_folds, n_alphas), dtype=torch.int32, device=device)

    if lipschitz_L is not None:
        L = torch.full((n_folds,), float(lipschitz_L), dtype=dtype, device=device)
    else:
        try:
            L = torch.linalg.eigvalsh(XtX_batch)[:, -1] / n_vec
        except Exception:
            # Best-effort fallback: max absolute row sum bounds the spectral norm.
            L = torch.max(torch.sum(torch.abs(XtX_batch), dim=2), dim=1).values / n_vec
    # A non-positive bound (e.g. an all-zero Gram matrix for some fold) would
    # make the step infinite and turn 0 * inf into NaN.
    L = torch.clamp(L, min=1e-12)

    step = 1.0 / L.reshape(n_folds, 1, 1)
    alpha_gpu = torch.from_numpy(alphas_cpu).to(dtype=dtype, device=device).reshape(1, 1, n_alphas)
    thresholds = alpha_gpu * step

    Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
    n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
    use_kkt = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))
    max_iter = int(max_iter)
    tol = float(tol)
    total_pairs = n_folds * n_alphas

    active_gpu = torch.ones((n_folds, n_alphas), dtype=torch.bool, device=device)
    active_count = int(total_pairs)

    for iteration in range(max_iter):
        if active_count == 0:
            break

        active_expanded = active_gpu.unsqueeze(1)

        # ``coefs`` is only ever rebound via torch.where, never modified in
        # place, so a reference suffices to remember the previous iterate.
        coef_old = coefs
        grad = (torch.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded

        # Proximal step: soft thresholding. ``clamp`` with a scalar bound avoids
        # building a fresh 0-dim device tensor on every iteration.
        yk_step = yk - step * grad
        coef_candidate = torch.sign(yk_step) * torch.clamp(torch.abs(yk_step) - thresholds, min=0.0)
        coefs = torch.where(active_expanded, coef_candidate, coefs)

        t_new = (1.0 + torch.sqrt(1.0 + 4.0 * tk ** 2)) / 2.0
        momentum = (tk - 1.0) / t_new
        y_candidate = coefs + momentum.unsqueeze(1) * (coefs - coef_old)
        yk = torch.where(active_expanded, y_candidate, yk)
        tk = torch.where(active_gpu, t_new, tk)

        active_ratio = float(active_count) / float(max(1, total_pairs))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=max_iter,
            active_ratio=active_ratio,
        )
        if (iteration + 1) % check_every_eff != 0 and iteration + 1 != max_iter:
            continue

        if use_kkt:
            grad_sse = (torch.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
            violation = torch.max(torch.clamp(torch.abs(grad_sse) - alpha_gpu, min=0.0), dim=1).values
            converged_local_gpu = violation < tol
        else:
            delta = torch.sum(torch.abs(coefs - coef_old), dim=1)
            converged_local_gpu = delta < tol

        newly_done_gpu = active_gpu & converged_local_gpu
        done_count = int(torch.count_nonzero(newly_done_gpu).item())
        if done_count == 0:
            continue

        n_iters_gpu[newly_done_gpu] = iteration + 1
        yk = torch.where(newly_done_gpu.unsqueeze(1), coefs, yk)
        active_gpu = active_gpu & (~converged_local_gpu)
        active_count -= done_count

    n_iters_gpu[active_gpu] = max_iter

    return coefs.permute(0, 2, 1), n_iters_gpu.cpu().numpy()
|
|
1074
|
+
|
|
1075
|
+
|
|
1076
|
+
def _solve_lasso_path_cpu_from_gram(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cpu_solver: str,
    lipschitz_L: Optional[float] = None,
    cd_kkt_check_every: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """Solve a descending-alpha Lasso path on CPU using one precomputed Gram matrix.

    Solver selection (``cpu_solver`` is compared case-insensitively):

    * ``"fista"``: batched FISTA over all alphas via
      ``_solve_lasso_path_cpu_fista_batched_from_gram`` (``check_every=2``).
    * ``"coordinate_descent"`` when ``_NUMBA_AVAILABLE`` and Numba CD has not
      been disabled: the compiled kernel via ``_solve_lasso_path_cpu_cd_numba``.
      If that call raises, ``_NUMBA_CD_DISABLED`` is set for the rest of the
      process. A ``RuntimeWarning`` is emitted and the pure-Python solver
      below runs instead.
    * otherwise (including ``"coordinate_descent"`` without Numba):
      pure-Python coordinate descent with incremental gradient updates. The
      coefficient vector carries over from one alpha to the next (warm start),
      and the active feature set only grows along the path.

    Parameters
    ----------
    XtX : ndarray
        Gram matrix ``X.T @ X``.
    Xty : ndarray
        Cross-product ``X.T @ y``.
    n_samples : int
        Number of samples; alphas are scaled by ``max(1, n_samples)``.
    alphas_desc : array-like
        Regularization strengths in descending order.
    max_iter : int
        Maximum sweeps (CD) or iterations (FISTA) per alpha.
    tol : float
        Convergence tolerance.
    stopping : str
        ``"kkt"`` stops when the KKT bound violation is below ``tol``. Any
        other value stops when a sweep changes the coefficients by less than
        ``tol`` (L1) and no inactive feature violates the bound.
    cpu_solver : str
        Solver name, see above.
    lipschitz_L : float, optional
        Step-size constant, only used by the FISTA solver.
    cd_kkt_check_every : int
        Sweep cadence of the full KKT scan for coordinate descent. Values below
        1 are treated as 1 on both the Numba and pure-Python paths.

    Returns
    -------
    coefs_path : ndarray of shape (n_alphas, n_features)
    n_iters : int32 ndarray of shape (n_alphas,)
        Iterations used per alpha (``max_iter`` when not converged).
    """
    global _NUMBA_CD_DISABLED

    stopping_name = str(stopping).lower()
    solver_name = str(cpu_solver).lower()

    if solver_name == "fista":
        return _solve_lasso_path_cpu_fista_batched_from_gram(
            XtX,
            Xty,
            n_samples=n_samples,
            alphas_desc=alphas_desc,
            max_iter=max_iter,
            tol=tol,
            stopping=stopping,
            lipschitz_L=lipschitz_L,
            check_every=2,
        )

    # Clamp once here so the Numba kernel and the Python fallback always use
    # the same positive KKT-scan cadence.
    cd_kkt_check_every = max(1, int(cd_kkt_check_every))

    if _NUMBA_AVAILABLE and not _NUMBA_CD_DISABLED and solver_name == "coordinate_descent":
        try:
            return _solve_lasso_path_cpu_cd_numba(
                XtX,
                Xty,
                n_samples=n_samples,
                alphas_desc=alphas_desc,
                max_iter=max_iter,
                tol=tol,
                stopping=stopping,
                cd_kkt_check_every=cd_kkt_check_every,
            )
        except Exception as exc:
            # Numba is a best-effort accelerator: disable it for the rest of
            # the process and fall through to the Python solver, but make the
            # failure visible instead of swallowing it silently.
            _NUMBA_CD_DISABLED = True
            import warnings

            warnings.warn(
                "Numba coordinate-descent Lasso solver failed "
                f"({type(exc).__name__}: {exc}); using the pure-Python solver instead.",
                RuntimeWarning,
                stacklevel=2,
            )

    # ---- Pure-Python coordinate descent with incremental gradient updates ----
    alphas_desc = np.asarray(alphas_desc, dtype=np.float64).reshape(-1)
    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])
    max_iter = int(max_iter)
    tol = float(tol)
    n_samp = float(max(1, n_samples))

    coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
    n_iters = np.zeros(n_alphas, dtype=np.int32)
    coef = np.zeros(n_features, dtype=np.float64)

    X_sq_norms = np.diag(XtX).astype(np.float64, copy=False)
    grad = XtX @ coef - Xty  # unscaled gradient, kept in sync after every update
    abs_Xty = np.abs(Xty)
    alpha_scaled_desc = alphas_desc * n_samp
    active_mask = np.zeros((n_features,), dtype=bool)

    for alpha_idx in range(n_alphas):
        alpha = float(alphas_desc[alpha_idx])
        alpha_scaled = float(alpha_scaled_desc[alpha_idx])
        prev_alpha_scaled = float(alpha_scaled_desc[alpha_idx - 1]) if alpha_idx > 0 else alpha_scaled

        # Strong rule screening: expand active set before cyclic updates.
        # NOTE(review): this screens on |Xty| (the gradient at b = 0) rather
        # than the gradient at the previous solution. Any feature it misses is
        # added by the periodic KKT scan below before convergence is accepted.
        strong_thresh = max(0.0, 2.0 * alpha_scaled - prev_alpha_scaled)
        active_mask |= abs_Xty >= strong_thresh
        if not bool(np.any(active_mask)):
            active_mask[int(np.argmax(abs_Xty))] = True

        converged = False

        for iteration in range(max_iter):
            coef_delta_l1 = 0.0

            for j in np.flatnonzero(active_mask):
                denom = float(X_sq_norms[j])
                old_val = float(coef[j])

                if denom > 1e-10:
                    rho_j = -float(grad[j]) + denom * old_val
                    new_val = _soft_threshold_scalar(rho_j, alpha_scaled) / denom
                else:
                    new_val = 0.0

                delta = new_val - old_val
                if abs(delta) > 0.0:
                    coef[j] = new_val
                    grad += XtX[:, j] * delta
                    coef_delta_l1 += abs(delta)

            # glmnet-style optimization can skip full inactive KKT scans on
            # every pass, then force a check when updates become small.
            should_kkt_scan = (
                ((iteration + 1) % cd_kkt_check_every == 0)
                or (coef_delta_l1 < tol)
                or (iteration + 1 == max_iter)
            )
            violation = float("inf")
            has_inactive_violation = False

            if should_kkt_scan:
                # NOTE(review): bound-only KKT measure (|grad/n| <= alpha); it
                # does not test the equality condition for nonzero coefficients.
                violation_vec = np.maximum(np.abs(grad / n_samp) - alpha, 0.0)
                inactive_violators = np.flatnonzero((violation_vec > tol) & (~active_mask))
                if inactive_violators.size > 0:
                    active_mask[inactive_violators] = True
                    has_inactive_violation = True
                violation = float(np.max(violation_vec))

            if stopping_name == "kkt":
                done = should_kkt_scan and violation < tol
            else:
                done = coef_delta_l1 < tol and not has_inactive_violation

            if done:
                n_iters[alpha_idx] = iteration + 1
                converged = True
                break

        if not converged:
            n_iters[alpha_idx] = max_iter

        coefs_path[alpha_idx, :] = coef
        active_mask |= np.abs(coef) > 0.0

    return coefs_path, n_iters
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
def _solve_lasso_path_gpu_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve a descending-alpha Lasso path on GPU from one precomputed Gram matrix.

    This is a pass-through to ``_solve_lasso_path_gpu_fista_batched_from_gram``,
    which runs batched FISTA over all alphas. Every argument is forwarded
    unchanged, and the ``(coefs, n_iters)`` result is returned as-is.
    """
    solver_options = dict(
        n_samples=n_samples,
        alphas_desc=alphas_desc,
        max_iter=max_iter,
        tol=tol,
        stopping=stopping,
        lipschitz_L=lipschitz_L,
        check_every=check_every,
    )
    return _solve_lasso_path_gpu_fista_batched_from_gram(XtX, Xty, **solver_options)
|
|
1238
|
+
|
|
1239
|
+
|
|
1240
|
+
def _soft_threshold_torch(x, gamma):
|
|
1241
|
+
"""Soft thresholding operator for Torch tensors."""
|
|
1242
|
+
import torch
|
|
1243
|
+
return torch.sign(x) * torch.maximum(torch.abs(x) - gamma, torch.tensor(0.0, dtype=x.dtype, device=x.device))
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
def _fit_lasso_single_alpha_fast(
|
|
1247
|
+
X,
|
|
1248
|
+
y,
|
|
1249
|
+
*,
|
|
1250
|
+
alpha: float,
|
|
1251
|
+
fit_intercept: bool,
|
|
1252
|
+
max_iter: int,
|
|
1253
|
+
tol: float,
|
|
1254
|
+
stopping: str,
|
|
1255
|
+
device: str,
|
|
1256
|
+
cpu_solver: str,
|
|
1257
|
+
cd_kkt_check_every: int = 1,
|
|
1258
|
+
sample_weight=None,
|
|
1259
|
+
) -> Dict[str, object]:
|
|
1260
|
+
"""Fast single-alpha Lasso fit using optimized Gram-based path solvers."""
|
|
1261
|
+
device_name = str(device).lower()
|
|
1262
|
+
alpha_vec = np.asarray([float(alpha)], dtype=np.float64)
|
|
1263
|
+
|
|
1264
|
+
# Check if inputs are torch tensors on GPU
|
|
1265
|
+
is_torch_gpu = False
|
|
1266
|
+
try:
|
|
1267
|
+
import torch
|
|
1268
|
+
is_torch_gpu = device_name == Device.CUDA.value and isinstance(X, torch.Tensor)
|
|
1269
|
+
except Exception:
|
|
1270
|
+
pass
|
|
1271
|
+
|
|
1272
|
+
if device_name == Device.CUDA.value and not is_torch_gpu:
|
|
1273
|
+
# CuPy GPU path
|
|
1274
|
+
import cupy as cp
|
|
1275
|
+
|
|
1276
|
+
X_arr = cp.asarray(X)
|
|
1277
|
+
y_arr = cp.asarray(y).reshape(-1)
|
|
1278
|
+
sw = None
|
|
1279
|
+
|
|
1280
|
+
if sample_weight is not None:
|
|
1281
|
+
sw = cp.asarray(sample_weight)
|
|
1282
|
+
sqrt_sw = cp.sqrt(sw)
|
|
1283
|
+
X_arr = X_arr * sqrt_sw[:, cp.newaxis]
|
|
1284
|
+
y_arr = y_arr * sqrt_sw
|
|
1285
|
+
|
|
1286
|
+
if bool(fit_intercept):
|
|
1287
|
+
if sw is not None:
|
|
1288
|
+
# Weighted mean on original (pre-sqrt) data
|
|
1289
|
+
X_orig = X_arr / sqrt_sw[:, cp.newaxis]
|
|
1290
|
+
y_orig = y_arr / sqrt_sw
|
|
1291
|
+
w_sum = float(cp.sum(sw))
|
|
1292
|
+
X_mean = cp.sum(X_orig * sw[:, cp.newaxis], axis=0) / w_sum
|
|
1293
|
+
y_mean = float(cp.sum(y_orig * sw)) / w_sum
|
|
1294
|
+
X_centered = X_arr - sqrt_sw[:, cp.newaxis] * X_mean
|
|
1295
|
+
y_centered = y_arr - sqrt_sw * y_mean
|
|
1296
|
+
else:
|
|
1297
|
+
X_mean = cp.mean(X_arr, axis=0)
|
|
1298
|
+
y_mean = cp.mean(y_arr)
|
|
1299
|
+
X_centered = X_arr - X_mean
|
|
1300
|
+
y_centered = y_arr - y_mean
|
|
1301
|
+
else:
|
|
1302
|
+
X_mean = cp.zeros((X_arr.shape[1],), dtype=X_arr.dtype)
|
|
1303
|
+
y_mean = cp.array(0.0, dtype=X_arr.dtype)
|
|
1304
|
+
X_centered = X_arr
|
|
1305
|
+
y_centered = y_arr
|
|
1306
|
+
|
|
1307
|
+
XtX = X_centered.T @ X_centered
|
|
1308
|
+
Xty = X_centered.T @ y_centered
|
|
1309
|
+
|
|
1310
|
+
coefs_desc, n_iters = _solve_lasso_path_gpu_from_gram(
|
|
1311
|
+
XtX,
|
|
1312
|
+
Xty,
|
|
1313
|
+
n_samples=int(X_arr.shape[0]),
|
|
1314
|
+
alphas_desc=alpha_vec,
|
|
1315
|
+
max_iter=int(max_iter),
|
|
1316
|
+
tol=float(tol),
|
|
1317
|
+
stopping=str(stopping),
|
|
1318
|
+
lipschitz_L=None,
|
|
1319
|
+
check_every=8,
|
|
1320
|
+
)
|
|
1321
|
+
|
|
1322
|
+
coef_gpu = coefs_desc[0]
|
|
1323
|
+
if bool(fit_intercept):
|
|
1324
|
+
intercept_gpu = y_mean - X_mean @ coef_gpu
|
|
1325
|
+
intercept = float(cp.asnumpy(intercept_gpu))
|
|
1326
|
+
else:
|
|
1327
|
+
intercept = 0.0
|
|
1328
|
+
|
|
1329
|
+
coef = np.asarray(cp.asnumpy(coef_gpu), dtype=np.float64)
|
|
1330
|
+
return {
|
|
1331
|
+
"coef": coef,
|
|
1332
|
+
"intercept": float(intercept),
|
|
1333
|
+
"n_iter": int(n_iters[0]),
|
|
1334
|
+
"n_samples": int(X_arr.shape[0]),
|
|
1335
|
+
"n_features": int(X_arr.shape[1]),
|
|
1336
|
+
}
|
|
1337
|
+
|
|
1338
|
+
elif is_torch_gpu:
|
|
1339
|
+
# Torch GPU path - use FISTA solver directly on GPU tensors
|
|
1340
|
+
import torch
|
|
1341
|
+
|
|
1342
|
+
X_arr = X
|
|
1343
|
+
y_arr = y.reshape(-1) if isinstance(y, torch.Tensor) else torch.as_tensor(
|
|
1344
|
+
y, dtype=X_arr.dtype, device=X_arr.device
|
|
1345
|
+
).reshape(-1)
|
|
1346
|
+
sw = None
|
|
1347
|
+
|
|
1348
|
+
if sample_weight is not None:
|
|
1349
|
+
sw = sample_weight if isinstance(sample_weight, torch.Tensor) else torch.as_tensor(
|
|
1350
|
+
sample_weight, dtype=X_arr.dtype, device=X_arr.device
|
|
1351
|
+
)
|
|
1352
|
+
sqrt_sw = torch.sqrt(sw)
|
|
1353
|
+
X_arr = X_arr * sqrt_sw[:, None]
|
|
1354
|
+
y_arr = y_arr * sqrt_sw
|
|
1355
|
+
|
|
1356
|
+
if bool(fit_intercept):
|
|
1357
|
+
if sw is not None:
|
|
1358
|
+
# Weighted mean: sum(w*X)/sum(w) on original (pre-sqrt) data
|
|
1359
|
+
# But X_arr is already sqrt(w)*X, so mean of sqrt(w)*X is not
|
|
1360
|
+
# the weighted mean. Use the original data for centering.
|
|
1361
|
+
X_orig = X_arr / sqrt_sw[:, None]
|
|
1362
|
+
y_orig = y_arr / sqrt_sw
|
|
1363
|
+
w_sum = float(sw.sum())
|
|
1364
|
+
X_mean = torch.sum(X_orig * sw[:, None], dim=0) / w_sum
|
|
1365
|
+
y_mean = float(torch.sum(y_orig * sw)) / w_sum
|
|
1366
|
+
# Re-center the sqrt-weighted data using the weighted mean
|
|
1367
|
+
X_centered = X_arr - sqrt_sw[:, None] * X_mean
|
|
1368
|
+
y_centered = y_arr - sqrt_sw * y_mean
|
|
1369
|
+
else:
|
|
1370
|
+
X_mean = torch.mean(X_arr, dim=0)
|
|
1371
|
+
y_mean = torch.mean(y_arr)
|
|
1372
|
+
X_centered = X_arr - X_mean
|
|
1373
|
+
y_centered = y_arr - y_mean
|
|
1374
|
+
else:
|
|
1375
|
+
X_mean = torch.zeros((X_arr.shape[1],), dtype=X_arr.dtype, device=X_arr.device)
|
|
1376
|
+
y_mean = torch.tensor(0.0, dtype=X_arr.dtype, device=X_arr.device)
|
|
1377
|
+
X_centered = X_arr
|
|
1378
|
+
y_centered = y_arr
|
|
1379
|
+
|
|
1380
|
+
n_samples = int(X_arr.shape[0])
|
|
1381
|
+
n_features = int(X_arr.shape[1])
|
|
1382
|
+
|
|
1383
|
+
# Precompute Gram matrix and X'y for FISTA gradient
|
|
1384
|
+
XtX = X_centered.T @ X_centered
|
|
1385
|
+
Xty = X_centered.T @ y_centered
|
|
1386
|
+
|
|
1387
|
+
# Compute Lipschitz constant L = max eigenvalue of XtX / n
|
|
1388
|
+
try:
|
|
1389
|
+
eigvals = torch.linalg.eigvalsh(XtX)
|
|
1390
|
+
L = eigvals[-1] / n_samples
|
|
1391
|
+
except Exception:
|
|
1392
|
+
L = torch.sum(X_centered ** 2) / n_samples
|
|
1393
|
+
L = torch.clamp(L, min=1e-10)
|
|
1394
|
+
|
|
1395
|
+
step = 1.0 / L
|
|
1396
|
+
thresh = float(alpha) * step
|
|
1397
|
+
|
|
1398
|
+
# FISTA initialization
|
|
1399
|
+
coef = torch.zeros(n_features, dtype=X_arr.dtype, device=X_arr.device)
|
|
1400
|
+
z = coef.clone()
|
|
1401
|
+
t = torch.tensor(1.0, dtype=X_arr.dtype, device=X_arr.device)
|
|
1402
|
+
|
|
1403
|
+
# FISTA iterations
|
|
1404
|
+
for iteration in range(int(max_iter)):
|
|
1405
|
+
coef_old = coef.clone()
|
|
1406
|
+
|
|
1407
|
+
# Gradient step at z
|
|
1408
|
+
grad = (XtX @ z - Xty) / n_samples
|
|
1409
|
+
coef = _soft_threshold_torch(z - step * grad, thresh)
|
|
1410
|
+
|
|
1411
|
+
# Momentum update
|
|
1412
|
+
t_new = (1.0 + torch.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
|
|
1413
|
+
z = coef + ((t - 1.0) / t_new) * (coef - coef_old)
|
|
1414
|
+
t = t_new
|
|
1415
|
+
|
|
1416
|
+
# Convergence check
|
|
1417
|
+
if str(stopping).lower() == "kkt":
|
|
1418
|
+
grad_sse = (XtX @ coef - Xty) / n_samples
|
|
1419
|
+
violation = torch.max(torch.maximum(torch.abs(grad_sse) - float(alpha), torch.tensor(0.0, dtype=X_arr.dtype, device=X_arr.device)))
|
|
1420
|
+
if violation < float(tol):
|
|
1421
|
+
break
|
|
1422
|
+
else:
|
|
1423
|
+
if torch.sum(torch.abs(coef - coef_old)) < float(tol):
|
|
1424
|
+
break
|
|
1425
|
+
|
|
1426
|
+
# Build coefficients
|
|
1427
|
+
if bool(fit_intercept):
|
|
1428
|
+
intercept_torch = y_mean - X_mean @ coef
|
|
1429
|
+
intercept = float(intercept_torch.item())
|
|
1430
|
+
else:
|
|
1431
|
+
intercept = 0.0
|
|
1432
|
+
|
|
1433
|
+
coef_np = np.asarray(coef.detach().cpu().numpy(), dtype=np.float64)
|
|
1434
|
+
return {
|
|
1435
|
+
"coef": coef_np,
|
|
1436
|
+
"intercept": float(intercept),
|
|
1437
|
+
"n_iter": int(iteration + 1),
|
|
1438
|
+
"n_samples": n_samples,
|
|
1439
|
+
"n_features": n_features,
|
|
1440
|
+
}
|
|
1441
|
+
|
|
1442
|
+
X_arr = np.asarray(X)
|
|
1443
|
+
y_arr = np.asarray(y).reshape(-1)
|
|
1444
|
+
|
|
1445
|
+
if sample_weight is not None:
|
|
1446
|
+
sw = np.asarray(sample_weight)
|
|
1447
|
+
sqrt_sw = np.sqrt(sw)
|
|
1448
|
+
X_arr = X_arr * sqrt_sw[:, np.newaxis]
|
|
1449
|
+
y_arr = y_arr * sqrt_sw
|
|
1450
|
+
|
|
1451
|
+
if bool(fit_intercept):
|
|
1452
|
+
if sample_weight is not None:
|
|
1453
|
+
# Weighted mean on original (pre-sqrt) data
|
|
1454
|
+
sw = np.asarray(sample_weight)
|
|
1455
|
+
w_sum = float(np.sum(sw))
|
|
1456
|
+
X_orig = X_arr / sqrt_sw[:, np.newaxis]
|
|
1457
|
+
y_orig = y_arr / sqrt_sw
|
|
1458
|
+
X_mean = np.sum(X_orig * sw[:, np.newaxis], axis=0) / w_sum
|
|
1459
|
+
y_mean = float(np.sum(y_orig * sw)) / w_sum
|
|
1460
|
+
# Center the sqrt-weighted data using the weighted mean
|
|
1461
|
+
X_centered = X_arr - sqrt_sw[:, np.newaxis] * X_mean
|
|
1462
|
+
y_centered = y_arr - sqrt_sw * y_mean
|
|
1463
|
+
else:
|
|
1464
|
+
X_mean = np.mean(X_arr, axis=0)
|
|
1465
|
+
y_mean = float(np.mean(y_arr))
|
|
1466
|
+
X_centered = X_arr - X_mean
|
|
1467
|
+
y_centered = y_arr - y_mean
|
|
1468
|
+
else:
|
|
1469
|
+
X_mean = np.zeros((X_arr.shape[1],), dtype=np.float64)
|
|
1470
|
+
y_mean = 0.0
|
|
1471
|
+
X_centered = X_arr
|
|
1472
|
+
y_centered = y_arr
|
|
1473
|
+
|
|
1474
|
+
XtX = X_centered.T @ X_centered
|
|
1475
|
+
Xty = X_centered.T @ y_centered
|
|
1476
|
+
|
|
1477
|
+
coefs_desc, n_iters = _solve_lasso_path_cpu_from_gram(
|
|
1478
|
+
XtX,
|
|
1479
|
+
Xty,
|
|
1480
|
+
n_samples=int(X_arr.shape[0]),
|
|
1481
|
+
alphas_desc=alpha_vec,
|
|
1482
|
+
max_iter=int(max_iter),
|
|
1483
|
+
tol=float(tol),
|
|
1484
|
+
stopping=str(stopping),
|
|
1485
|
+
cpu_solver=str(cpu_solver),
|
|
1486
|
+
lipschitz_L=None,
|
|
1487
|
+
cd_kkt_check_every=int(cd_kkt_check_every),
|
|
1488
|
+
)
|
|
1489
|
+
|
|
1490
|
+
coef = np.asarray(coefs_desc[0], dtype=np.float64)
|
|
1491
|
+
if bool(fit_intercept):
|
|
1492
|
+
intercept = float(y_mean - X_mean @ coef)
|
|
1493
|
+
else:
|
|
1494
|
+
intercept = 0.0
|
|
1495
|
+
|
|
1496
|
+
return {
|
|
1497
|
+
"coef": coef,
|
|
1498
|
+
"intercept": float(intercept),
|
|
1499
|
+
"n_iter": int(n_iters[0]),
|
|
1500
|
+
"n_samples": int(X_arr.shape[0]),
|
|
1501
|
+
"n_features": int(X_arr.shape[1]),
|
|
1502
|
+
}
|
|
1503
|
+
|
|
1504
|
+
|
|
1505
|
+
def _select_lasso_alpha_cv(
    X,
    y,
    *,
    alphas=None,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
    cv_folds: int = 5,
    cv_splits=None,
    random_state: Optional[int] = None,
    sample_weight=None,
    fit_intercept: bool = False,
    device: Union[str, Device] = Device.CPU,
    max_iter: int = 3000,
    tol: float = 1e-4,
    cpu_solver: str = "coordinate_descent",
    method: str = "standard",
    cd_kkt_check_every: Optional[int] = None,
    gpu_cv_mixed_precision: bool = True,
    return_details: bool = False,
    cache_key: Optional[Tuple[Any, ...]] = None,
):
    """
    Select alpha via K-fold CV using statgpu's own Lasso implementation.

    Parameters
    ----------
    X : array-like, 2-D
        Design matrix. NumPy-convertible input, or a CuPy array / Torch
        tensor when a GPU ``device`` is requested.
    y : array-like
        Response; flattened to 1-D.
    alphas : array-like, optional
        Candidate penalty values. Non-finite and non-positive entries are
        dropped; if nothing remains (or ``alphas`` is None) a default grid of
        ``n_alphas`` values down to ``alpha_min_ratio`` is built.
    cv_folds, cv_splits, random_state
        Explicit ``cv_splits`` (normalized by ``_normalize_cv_splits``) take
        precedence; otherwise ``cv_folds`` folds are drawn by
        ``_kfold_indices`` with ``random_state``.
    sample_weight : array-like, optional
        Per-row weights. Training folds are centered with weighted means and
        rows are scaled by ``sqrt(w)``; the effective training size is the
        weight sum. Validation MSE is weighted via ``_batch_mse_cv``.
    fit_intercept : bool
        Center each training fold and derive per-alpha intercepts.
    device : str or Device
        ``Device.CUDA.value`` or ``Device.TORCH.value`` selects the GPU path.
        There is no CPU fallback once the GPU path is selected.
    max_iter, tol
        Passed to the path solvers (``stopping="coef_delta"``).
    cpu_solver : str
        CPU path solver; forced to ``"coordinate_descent"`` when ``method``
        normalizes to ``"glmnet"``.
    method : str
        CV profile, normalized by ``_normalize_lassocv_method``.
    cd_kkt_check_every : int, optional
        KKT scan period for the CPU coordinate-descent solver; defaults to 4
        for the glmnet profile and 1 otherwise.
    gpu_cv_mixed_precision : bool
        Use ``float32`` on the GPU path (``float64`` when False).
    return_details : bool
        Return a details dict instead of just the selected alpha.
    cache_key : tuple, optional
        Explicit cache key; when None and caching is enabled an automatic
        key is derived from the inputs and settings.

    Returns
    -------
    float or dict
        The selected alpha, or ``{"alpha", "alphas", "mse_path",
        "mean_mse"}`` when ``return_details`` is True.

    Raises
    ------
    ValueError
        If ``X`` is not 2-D or ``y`` / ``sample_weight`` row counts differ.
    RuntimeError
        If the GPU path raises for any reason.

    Notes
    -----
    - Does not depend on sklearn.
    - Supports GPU path by setting ``device='cuda'`` (or ``device='torch'``).
    """
    device_name = str(device).lower()
    use_gpu = device_name in (Device.CUDA.value, Device.TORCH.value)

    # Detect inputs already resident on the GPU. X, y and (when given)
    # sample_weight must all share the same array type for the native path.
    gpu_input_cupy = False
    gpu_input_torch = False
    if use_gpu:
        try:
            import cupy as cp

            gpu_input_cupy = isinstance(X, cp.ndarray) and isinstance(y, cp.ndarray)
            if sample_weight is not None and not isinstance(sample_weight, cp.ndarray):
                gpu_input_cupy = False
        except Exception:
            # CuPy is optional: failing to import it just means "not CuPy input".
            pass

        if not gpu_input_cupy:
            try:
                import torch

                gpu_input_torch = isinstance(X, torch.Tensor) and isinstance(y, torch.Tensor)
                if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
                    gpu_input_torch = False
            except Exception:
                # Torch is optional as well.
                pass
    gpu_input = gpu_input_cupy or gpu_input_torch

    def _input_backend():
        # Pick the backend matching the array library the caller supplied.
        if gpu_input_torch:
            return get_backend(backend='torch', device='cuda')
        if gpu_input_cupy:
            return get_backend(backend='cupy', device='cuda')
        return get_backend(backend='auto', device='cuda')

    X_np = None
    y_np = None
    sample_weight_np = None

    if gpu_input:
        # GPU inputs: validate shapes without copying to the host.
        backend = _input_backend()
        if len(tuple(X.shape)) != 2:
            raise ValueError("X must be a 2D array")
        n_samples = int(X.shape[0])
        if int(backend.asarray(y).reshape(-1).shape[0]) != n_samples:
            raise ValueError("y must have the same number of rows as X")
        if sample_weight is not None:
            if int(backend.asarray(sample_weight).reshape(-1).shape[0]) != n_samples:
                raise ValueError("sample_weight must have the same number of rows as X")
    else:
        X_np = np.asarray(X, dtype=np.float64)
        y_np = np.asarray(y, dtype=np.float64).reshape(-1)
        if sample_weight is not None:
            sample_weight_np = np.asarray(sample_weight, dtype=np.float64).reshape(-1)
        if X_np.ndim != 2:
            raise ValueError("X must be a 2D array")
        if y_np.shape[0] != X_np.shape[0]:
            raise ValueError("y must have the same number of rows as X")
        if sample_weight_np is not None and sample_weight_np.shape[0] != X_np.shape[0]:
            raise ValueError("sample_weight must have the same number of rows as X")
        n_samples = int(X_np.shape[0])

    cv_method = _normalize_lassocv_method(method)
    requested_cd_kkt_check_every = _normalize_cd_kkt_check_every(cd_kkt_check_every)

    def _default_alpha_grid():
        # Default grid computed on-device for GPU inputs, on NumPy otherwise.
        if gpu_input:
            return _default_lasso_alpha_grid_backend(
                X,
                y,
                _input_backend(),
                n_alphas=n_alphas,
                alpha_min_ratio=alpha_min_ratio,
            )
        return _default_lasso_alpha_grid(
            X_np,
            y_np,
            n_alphas=n_alphas,
            alpha_min_ratio=alpha_min_ratio,
        )

    if alphas is None:
        alpha_grid = _default_alpha_grid()
    else:
        alpha_grid = np.asarray(alphas, dtype=np.float64).reshape(-1)
        alpha_grid = alpha_grid[np.isfinite(alpha_grid)]
        alpha_grid = alpha_grid[alpha_grid > 0.0]
        if alpha_grid.size == 0:
            alpha_grid = _default_alpha_grid()

    user_folds = _normalize_cv_splits(cv_splits, n_samples=n_samples)
    effective_n_folds = int(len(user_folds)) if user_folds is not None else int(cv_folds)

    # Too little data / a single candidate / fewer than 2 folds: nothing to
    # cross-validate, so return the first grid value.
    if int(n_samples) < 4 or int(alpha_grid.size) == 1 or int(effective_n_folds) < 2:
        alpha0 = float(alpha_grid[0])
        if not return_details:
            return alpha0
        return {
            "alpha": alpha0,
            "alphas": alpha_grid.astype(np.float64, copy=False),
            "mse_path": np.full((int(alpha_grid.size), 1), np.nan, dtype=np.float64),
            "mean_mse": np.full(int(alpha_grid.size), np.nan, dtype=np.float64),
        }

    if user_folds is not None:
        folds = user_folds
    else:
        folds = _kfold_indices(
            n_samples=int(n_samples),
            n_splits=int(cv_folds),
            random_state=random_state,
        )

    folds_are_complements = _folds_are_complements(folds, n_samples=int(n_samples))

    alpha_grid = alpha_grid.astype(np.float64, copy=False)
    n_alpha = int(alpha_grid.size)
    n_folds = int(len(folds))

    cache_key_eff = cache_key
    if cache_key_eff is None and _LASSO_CV_ALPHA_CACHE_MAXSIZE > 0:
        cache_key_eff = _make_lasso_cv_auto_cache_key(
            X=X,
            y=y,
            sample_weight=sample_weight,
            alpha_grid=alpha_grid,
            folds=folds,
            fit_intercept=bool(fit_intercept),
            use_gpu=bool(use_gpu),
            max_iter=int(max_iter),
            tol=float(tol),
            cpu_solver=str(cpu_solver),
            cv_method=str(cv_method),
            cd_kkt_check_every=requested_cd_kkt_check_every,
            gpu_cv_mixed_precision=bool(gpu_cv_mixed_precision),
        )

    cached_details = _lasso_cv_cache_get(cache_key_eff)
    if cached_details is not None:
        if return_details:
            return cached_details
        return float(cached_details["alpha"])

    # Evaluate alpha path in descending order for warm-start efficiency.
    alpha_order_desc = np.argsort(-alpha_grid)
    alpha_desc = alpha_grid[alpha_order_desc]

    # mse_path[i, k]: validation MSE of alpha_grid[i] on fold k (NaN = unset).
    mse_path = np.full((n_alpha, n_folds), np.nan, dtype=np.float64)

    if use_gpu:
        try:
            backend = _input_backend()
            xp = backend.xp

            cv_dtype = backend.float32 if bool(gpu_cv_mixed_precision) else backend.float64

            if gpu_input:
                # Already on GPU (CuPy or Torch).
                X_full = backend.asarray(X, dtype=cv_dtype)
                y_full = backend.asarray(y, dtype=cv_dtype).reshape(-1)
                sw_full = (
                    None
                    if sample_weight is None
                    else backend.asarray(sample_weight, dtype=cv_dtype).reshape(-1)
                )
            else:
                X_full = backend.asarray(X_np, dtype=cv_dtype)
                y_full = backend.asarray(y_np, dtype=cv_dtype)
                sw_full = (
                    None
                    if sample_weight_np is None
                    else backend.asarray(sample_weight_np, dtype=cv_dtype)
                )

            XtX_folds = []
            Xty_folds = []
            n_train_folds = []
            X_mean_folds = []
            y_mean_folds = []
            fold_eval_payload = []

            # When folds partition the rows and there are no weights, each
            # training Gram is "full Gram minus validation Gram".
            fast_fold_stats = (sw_full is None) and bool(folds_are_complements)
            if fast_fold_stats:
                n_total = int(X_full.shape[0])
                XtX_full = X_full.T @ X_full
                Xty_full = X_full.T @ y_full
                if bool(fit_intercept):
                    X_sum_full = backend.sum(X_full, axis=0)
                    y_sum_full = backend.sum(y_full)
                else:
                    X_sum_full = None
                    y_sum_full = None

            for train_idx, val_idx in folds:
                val_idx_gpu = backend.asarray(val_idx)
                X_val = X_full[val_idx_gpu]
                y_val = y_full[val_idx_gpu]
                sw_val = None if sw_full is None else sw_full[val_idx_gpu]

                if fast_fold_stats:
                    n_val = int(val_idx_gpu.shape[0])
                    n_train = int(n_total - n_val)

                    XtX_raw = XtX_full - X_val.T @ X_val
                    Xty_raw = Xty_full - X_val.T @ y_val

                    if bool(fit_intercept):
                        X_sum_train = X_sum_full - backend.sum(X_val, axis=0)
                        y_sum_train = y_sum_full - backend.sum(y_val)

                        inv_n = backend.asarray(1.0 / float(max(1, n_train)), dtype=X_full.dtype)
                        X_mean = X_sum_train * inv_n
                        y_mean = y_sum_train * inv_n
                        # Centered Gram via Xc'Xc = X'X - s s' / n and
                        # Xc'yc = X'y - s * y_mean, with s the column sums.
                        XtX = XtX_raw - backend.outer(X_sum_train, X_sum_train) * inv_n
                        Xty = Xty_raw - X_sum_train * y_mean
                    else:
                        X_mean = backend.zeros((X_full.shape[1],), dtype=X_full.dtype)
                        y_mean = backend.array(0.0, dtype=X_full.dtype)
                        XtX = XtX_raw
                        Xty = Xty_raw
                else:
                    train_idx_gpu = backend.asarray(train_idx)
                    X_train = X_full[train_idx_gpu]
                    y_train = y_full[train_idx_gpu]
                    sw_train = None if sw_full is None else sw_full[train_idx_gpu]

                    if bool(fit_intercept):
                        if sw_train is not None:
                            # Weighted means on the ORIGINAL (unscaled) rows, so the
                            # intercept y_mean - X_mean @ coef lives on the same scale
                            # as X_val (matches the CPU path).
                            sw_sum = backend.sum(sw_train)
                            X_mean = backend.sum(X_train * sw_train[:, None], axis=0) / sw_sum
                            y_mean = backend.sum(y_train * sw_train) / sw_sum
                        else:
                            X_mean = backend.mean(X_train, axis=0)
                            y_mean = backend.mean(y_train)
                        X_centered = X_train - X_mean
                        y_centered = y_train - y_mean
                    else:
                        X_mean = backend.zeros((X_train.shape[1],), dtype=X_train.dtype)
                        y_mean = backend.array(0.0, dtype=X_train.dtype)
                        X_centered = X_train
                        y_centered = y_train

                    if sw_train is not None:
                        # sqrt(w) row scaling AFTER centering turns the weighted
                        # least-squares term into a plain Gram; the effective
                        # sample size becomes the weight sum.
                        sqrt_sw = backend.sqrt(sw_train)
                        X_centered = X_centered * sqrt_sw[:, None]
                        y_centered = y_centered * sqrt_sw
                        n_train = float(backend.sum(sw_train))
                    else:
                        n_train = int(X_train.shape[0])

                    XtX = X_centered.T @ X_centered
                    Xty = X_centered.T @ y_centered

                XtX_folds.append(XtX)
                Xty_folds.append(Xty)
                n_train_folds.append(n_train)
                X_mean_folds.append(X_mean)
                y_mean_folds.append(y_mean)
                fold_eval_payload.append((X_val, y_val, sw_val))

            XtX_batch = backend.stack(XtX_folds, axis=0)
            Xty_batch = backend.stack(Xty_folds, axis=0)

            if hasattr(xp, '__name__') and 'torch' in xp.__name__.lower():
                # Native Torch FISTA solver for the Torch backend.
                import torch

                # Build from float64: weighted folds carry fractional weight
                # sums that an int32 cast would truncate, and the resulting
                # tensor is floating-point anyway.
                n_samples_vec = torch.tensor(
                    np.asarray(n_train_folds, dtype=np.float64),
                    device=XtX_batch.device,
                    dtype=XtX_batch.dtype,
                )
                path_solver = _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch
            else:
                # CuPy backend.
                import cupy as cp

                # NOTE(review): int32 truncates fractional weight sums on weighted
                # folds; kept because the CuPy solver's expected dtype is not
                # visible here — confirm before switching to float.
                n_samples_vec = cp.asarray(np.asarray(n_train_folds, dtype=np.int32))
                path_solver = _solve_lasso_path_gpu_fista_multi_fold_from_gram

            coefs_batch_desc, _ = path_solver(
                XtX_batch,
                Xty_batch,
                n_samples_vec=n_samples_vec,
                alphas_desc=alpha_desc,
                max_iter=int(max_iter),
                tol=float(tol),
                stopping="coef_delta",
                lipschitz_L=None,
                check_every=8,
            )

            for fold_idx, (X_val, y_val, sw_val) in enumerate(fold_eval_payload):
                # Move the solver output (host or device) into backend space with
                # the CV dtype so intercepts and validation MSE never mix host
                # arrays with device tensors.
                coefs_desc = backend.asarray(coefs_batch_desc[fold_idx], dtype=cv_dtype)
                if bool(fit_intercept):
                    intercepts_desc = y_mean_folds[fold_idx] - coefs_desc @ X_mean_folds[fold_idx]
                else:
                    intercepts_desc = backend.zeros((coefs_desc.shape[0],), dtype=coefs_desc.dtype)

                mse_desc = _batch_mse_cv(X_val, y_val, coefs_desc, intercepts_desc, sample_weight=sw_val)
                mse_path[alpha_order_desc, fold_idx] = mse_desc

        except Exception as exc:
            raise RuntimeError(
                "GPU path failed in _select_lasso_alpha_cv with device='cuda'; "
                "CPU fallback is disabled for strict CUDA execution."
            ) from exc
    else:
        cpu_solver_name = str(cpu_solver).lower()

        if cv_method == "glmnet":
            # glmnet-like CV profile: coordinate-descent path with periodic full KKT scans.
            cpu_solver_name = "coordinate_descent"

        if requested_cd_kkt_check_every is None:
            cd_kkt_check_every_effective = 4 if cv_method == "glmnet" else 1
        else:
            cd_kkt_check_every_effective = int(requested_cd_kkt_check_every)

        fast_fold_stats = (sample_weight_np is None) and bool(folds_are_complements)
        if fast_fold_stats:
            n_total = int(X_np.shape[0])
            XtX_full = X_np.T @ X_np
            Xty_full = X_np.T @ y_np
            if bool(fit_intercept):
                X_sum_full = np.sum(X_np, axis=0)
                y_sum_full = float(np.sum(y_np))
            else:
                X_sum_full = None
                y_sum_full = None

        for fold_idx, (train_idx, val_idx) in enumerate(folds):
            X_val = X_np[val_idx]
            y_val = y_np[val_idx]
            sw_val = None if sample_weight_np is None else sample_weight_np[val_idx]

            if fast_fold_stats:
                n_val = int(np.asarray(val_idx, dtype=np.int64).reshape(-1).size)
                n_train = int(n_total - n_val)

                XtX_raw = XtX_full - X_val.T @ X_val
                Xty_raw = Xty_full - X_val.T @ y_val

                if bool(fit_intercept):
                    X_sum_train = X_sum_full - np.sum(X_val, axis=0)
                    y_sum_train = y_sum_full - float(np.sum(y_val))

                    inv_n = 1.0 / float(max(1, n_train))
                    X_mean = X_sum_train * inv_n
                    y_mean = y_sum_train * inv_n
                    XtX = XtX_raw - np.outer(X_sum_train, X_sum_train) * inv_n
                    Xty = Xty_raw - X_sum_train * y_mean
                else:
                    X_mean = np.zeros((X_np.shape[1],), dtype=np.float64)
                    y_mean = 0.0
                    XtX = XtX_raw
                    Xty = Xty_raw
            else:
                X_train = X_np[train_idx]
                y_train = y_np[train_idx]
                sw_train = None if sample_weight_np is None else sample_weight_np[train_idx]

                if bool(fit_intercept):
                    # Compute weighted means on ORIGINAL data (before sqrt-weighting)
                    if sw_train is not None:
                        sw_sum = float(np.sum(sw_train))
                        X_mean = np.sum(X_train * sw_train[:, np.newaxis], axis=0) / sw_sum
                        y_mean = float(np.sum(y_train * sw_train)) / sw_sum
                    else:
                        X_mean = np.mean(X_train, axis=0)
                        y_mean = float(np.mean(y_train))
                    X_centered = X_train - X_mean
                    y_centered = y_train - y_mean
                else:
                    X_mean = np.zeros((X_train.shape[1],), dtype=np.float64)
                    y_mean = 0.0
                    X_centered = X_train
                    y_centered = y_train

                # Apply sqrt-weighting after centering
                if sw_train is not None:
                    sqrt_sw = np.sqrt(sw_train)
                    X_centered = X_centered * sqrt_sw[:, np.newaxis]
                    y_centered = y_centered * sqrt_sw

                XtX = X_centered.T @ X_centered
                Xty = X_centered.T @ y_centered
                # Use weight sum as effective sample size for proper alpha scaling
                n_train = float(np.sum(sw_train)) if sw_train is not None else int(X_train.shape[0])

            # NOTE(review): int() truncates a fractional weight sum; confirm whether
            # the CPU path solver accepts a float n_samples before changing this.
            coefs_desc, _ = _solve_lasso_path_cpu_from_gram(
                XtX,
                Xty,
                n_samples=int(n_train),
                alphas_desc=alpha_desc,
                max_iter=int(max_iter),
                tol=float(tol),
                stopping="coef_delta",
                cpu_solver=cpu_solver_name,
                lipschitz_L=None,
                cd_kkt_check_every=cd_kkt_check_every_effective,
            )

            if bool(fit_intercept):
                intercepts_desc = y_mean - X_mean @ coefs_desc.T
            else:
                intercepts_desc = np.zeros((coefs_desc.shape[0],), dtype=np.float64)

            mse_desc = _batch_mse_cv(
                X_val,
                y_val,
                coefs_desc,
                intercepts_desc,
                sample_weight=sw_val,
            )

            mse_path[alpha_order_desc, fold_idx] = np.asarray(mse_desc, dtype=np.float64)

    # Mean validation MSE per alpha over the folds that produced a finite score.
    mean_mse_vec = np.full(n_alpha, np.nan, dtype=np.float64)
    for alpha_idx in range(n_alpha):
        valid = np.isfinite(mse_path[alpha_idx])
        if bool(np.any(valid)):
            mean_mse_vec[alpha_idx] = float(np.mean(mse_path[alpha_idx, valid]))

    # Smallest finite mean MSE wins; ties go to the earliest grid entry.
    # If no alpha has a finite score, fall back to the first grid value.
    finite_idx = np.flatnonzero(np.isfinite(mean_mse_vec))
    if finite_idx.size > 0:
        best_pos = int(finite_idx[int(np.argmin(mean_mse_vec[finite_idx]))])
        best_alpha = float(alpha_grid[best_pos])
    else:
        best_alpha = float(alpha_grid[0])

    details = {
        "alpha": float(best_alpha),
        "alphas": alpha_grid.astype(np.float64, copy=False),
        "mse_path": mse_path,
        "mean_mse": mean_mse_vec,
    }

    _lasso_cv_cache_put(cache_key_eff, details)

    if return_details:
        return details

    return float(details["alpha"])
|
|
2047
|
+
|
|
2048
|
+
|
|
2049
|
+
from statgpu.linear_model.penalized._penalized_linear import PenalizedLinearRegression as _PenalizedLinearRegression
|
|
2050
|
+
|
|
2051
|
+
|
|
2052
|
+
class Lasso(_PenalizedLinearRegression):
|
|
2053
|
+
"""Thin sklearn-style wrapper over ``PenalizedLinearRegression`` with L1 penalty."""
|
|
2054
|
+
|
|
2055
|
+
def __init__(
|
|
2056
|
+
self,
|
|
2057
|
+
alpha: float = 1.0,
|
|
2058
|
+
fit_intercept: bool = True,
|
|
2059
|
+
max_iter: int = 1000,
|
|
2060
|
+
tol: float = 1e-4,
|
|
2061
|
+
stopping: str = "coef_delta",
|
|
2062
|
+
inference_method: str = "debiased",
|
|
2063
|
+
n_bootstrap: int = 200,
|
|
2064
|
+
bootstrap_random_state: Optional[int] = None,
|
|
2065
|
+
enable_simultaneous_inference: bool = False,
|
|
2066
|
+
simultaneous_method: str = "maxz_bootstrap",
|
|
2067
|
+
simultaneous_alpha: float = 0.05,
|
|
2068
|
+
simultaneous_n_bootstrap: int = 1000,
|
|
2069
|
+
simultaneous_random_state: Optional[int] = None,
|
|
2070
|
+
simultaneous_include_intercept: bool = False,
|
|
2071
|
+
device: Union[str, Device] = Device.AUTO,
|
|
2072
|
+
n_jobs: Optional[int] = None,
|
|
2073
|
+
compute_inference: bool = True,
|
|
2074
|
+
solver: str = "fista",
|
|
2075
|
+
cpu_solver: str = "coordinate_descent",
|
|
2076
|
+
lipschitz_L: Optional[float] = None,
|
|
2077
|
+
admm_rho: float = 1.0,
|
|
2078
|
+
gpu_memory_cleanup: bool = False,
|
|
2079
|
+
):
|
|
2080
|
+
# Lasso-specific attributes (set before super().__init__ which doesn't know them)
|
|
2081
|
+
self.n_bootstrap = int(n_bootstrap)
|
|
2082
|
+
self.bootstrap_random_state = bootstrap_random_state
|
|
2083
|
+
self.enable_simultaneous_inference = bool(enable_simultaneous_inference)
|
|
2084
|
+
_sm = str(simultaneous_method).lower()
|
|
2085
|
+
self.simultaneous_method = simultaneous_method if simultaneous_method == _sm else _sm
|
|
2086
|
+
self.simultaneous_alpha = float(simultaneous_alpha)
|
|
2087
|
+
self.simultaneous_n_bootstrap = int(simultaneous_n_bootstrap)
|
|
2088
|
+
self.simultaneous_random_state = simultaneous_random_state
|
|
2089
|
+
self.simultaneous_include_intercept = bool(simultaneous_include_intercept)
|
|
2090
|
+
self.admm_rho = float(admm_rho)
|
|
2091
|
+
super().__init__(
|
|
2092
|
+
penalty="l1",
|
|
2093
|
+
alpha=alpha,
|
|
2094
|
+
fit_intercept=fit_intercept,
|
|
2095
|
+
max_iter=max_iter,
|
|
2096
|
+
tol=tol,
|
|
2097
|
+
device=device,
|
|
2098
|
+
n_jobs=n_jobs,
|
|
2099
|
+
cpu_solver=cpu_solver,
|
|
2100
|
+
solver=solver,
|
|
2101
|
+
lipschitz_L=lipschitz_L,
|
|
2102
|
+
gpu_memory_cleanup=gpu_memory_cleanup,
|
|
2103
|
+
compute_inference=compute_inference,
|
|
2104
|
+
stopping=stopping,
|
|
2105
|
+
)
|
|
2106
|
+
# Re-set after super().__init__() which overwrites with parent default
|
|
2107
|
+
_im = str(inference_method).lower()
|
|
2108
|
+
self.inference_method = inference_method if inference_method == _im else _im
|
|
2109
|
+
|
|
2110
|
+
# Validate simultaneous inference settings
|
|
2111
|
+
if self.enable_simultaneous_inference:
|
|
2112
|
+
if self.simultaneous_method != "maxz_bootstrap":
|
|
2113
|
+
raise ValueError(
|
|
2114
|
+
f"simultaneous_method must be 'maxz_bootstrap', "
|
|
2115
|
+
f"got '{self.simultaneous_method}'"
|
|
2116
|
+
)
|
|
2117
|
+
if "debiased" not in self.inference_method:
|
|
2118
|
+
raise ValueError(
|
|
2119
|
+
"Simultaneous inference requires inference_method='debiased'."
|
|
2120
|
+
)
|
|
2121
|
+
if not self.compute_inference:
|
|
2122
|
+
raise ValueError(
|
|
2123
|
+
"Simultaneous inference requires compute_inference=True."
|
|
2124
|
+
)
|