statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,936 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Elastic Net regression with GPU acceleration and full statistical inference.
|
|
3
|
+
|
|
4
|
+
Elastic Net combines L1 and L2 regularization:
|
|
5
|
+
minimize (1/(2n)) * ||y - Xw||²₂ + α * l1_ratio * ||w||₁ + 0.5 * α * (1 - l1_ratio) * ||w||²₂
|
|
6
|
+
|
|
7
|
+
where:
|
|
8
|
+
- α (alpha) controls the overall regularization strength
|
|
9
|
+
- l1_ratio controls the mix: 1.0 = Lasso, 0.0 = Ridge, 0.5 = balanced Elastic Net
|
|
10
|
+
|
|
11
|
+
Optimized implementations:
|
|
12
|
+
- CPU: FISTA with pre-computed Gram matrix
|
|
13
|
+
- GPU (CuPy): Fused kernel operations with @cp.fuse()
|
|
14
|
+
- GPU (Torch): torch.compile() with warm-up strategy
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from typing import Optional, Union
|
|
18
|
+
import warnings
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from statgpu._base import BaseEstimator
|
|
22
|
+
from statgpu._config import Device
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# ============================================================================
|
|
26
|
+
# CuPy Fused Kernels for Elastic Net
|
|
27
|
+
# ============================================================================
|
|
28
|
+
|
|
29
|
+
def _get_cupy_fused_kernels():
|
|
30
|
+
"""Lazy load CuPy fused kernels."""
|
|
31
|
+
try:
|
|
32
|
+
import cupy as cp
|
|
33
|
+
except ImportError:
|
|
34
|
+
return None, None, None, None
|
|
35
|
+
|
|
36
|
+
@cp.fuse()
|
|
37
|
+
def _elastic_net_proximal(x, thresh, l2_scale):
|
|
38
|
+
"""Fused soft thresholding with L2 scaling."""
|
|
39
|
+
return cp.sign(x) * cp.maximum(cp.abs(x) - thresh, 0) / l2_scale
|
|
40
|
+
|
|
41
|
+
@cp.fuse()
|
|
42
|
+
def _fista_momentum_update(coef, coef_old, t_old, t_new):
|
|
43
|
+
"""Fused FISTA momentum update."""
|
|
44
|
+
beta = (t_old - 1) / t_new
|
|
45
|
+
return coef + beta * (coef - coef_old)
|
|
46
|
+
|
|
47
|
+
@cp.fuse()
|
|
48
|
+
def _compute_coef_delta(coef, coef_old):
|
|
49
|
+
"""Compute absolute coefficient change."""
|
|
50
|
+
return cp.abs(coef - coef_old)
|
|
51
|
+
|
|
52
|
+
ELASTIC_NET_PROXIMAL_KERNEL = cp.ElementwiseKernel(
|
|
53
|
+
'float64 w_tilde, float64 thresh, float64 l2_scale',
|
|
54
|
+
'float64 coef',
|
|
55
|
+
'''
|
|
56
|
+
double abs_w = abs(w_tilde);
|
|
57
|
+
if (abs_w > thresh) {
|
|
58
|
+
coef = (w_tilde > 0 ? 1.0 : -1.0) * (abs_w - thresh) / l2_scale;
|
|
59
|
+
} else {
|
|
60
|
+
coef = 0.0;
|
|
61
|
+
}
|
|
62
|
+
''',
|
|
63
|
+
'elastic_net_proximal'
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
return _elastic_net_proximal, _fista_momentum_update, _compute_coef_delta, ELASTIC_NET_PROXIMAL_KERNEL
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _fit_elasticnet_cupy_optimized(X, y, alpha, l1_ratio, n_samples, n_features,
                                   max_iter=1000, tol=1e-4, lipschitz_L=None,
                                   stopping='coef_delta', warmup=True):
    """
    Fit Elastic Net using optimized CuPy operations with fused kernels.

    Minimizes ``(1/(2n)) * ||y - Xw||^2 + alpha*l1_ratio*||w||_1
    + 0.5*alpha*(1-l1_ratio)*||w||^2`` with FISTA on the precomputed Gram
    matrix ``X.T @ X``. X and y are used as given; no centering or weighting
    is done here.

    Parameters
    ----------
    X : cupy.ndarray of shape (n_samples, n_features)
    y : cupy.ndarray of shape (n_samples,)
    alpha : float
        Overall penalty strength.
    l1_ratio : float
        L1/L2 mixing parameter.
    n_samples, n_features : int
        Dimensions of X.
    max_iter : int, default=1000
        Maximum number of FISTA iterations.
    tol : float, default=1e-4
        Convergence threshold for the chosen stopping quantity.
    lipschitz_L : float, optional
        Lipschitz constant of the smooth term. When omitted it is the largest
        eigenvalue of ``X.T @ X`` divided by ``n_samples``.
    stopping : str, default='coef_delta'
        ``'kkt'`` stops on the maximum KKT violation. Any other value stops on
        the maximum absolute coefficient change between iterations.
    warmup : bool, default=True
        Call the fused proximal kernel once before the loop to trigger its
        JIT compilation.

    Returns
    -------
    coef : cupy.ndarray of shape (n_features,)
    n_iter : int
        Number of iterations performed (0 if the Lipschitz constant is <= 0).

    Raises
    ------
    ImportError
        If CuPy is not available.
    """
    import math

    import cupy as cp

    elastic_net_proximal, _, _, _ = _get_cupy_fused_kernels()
    if elastic_net_proximal is None:
        raise ImportError("CuPy not available")

    # Precompute Gram matrix and cross product once; each iteration is O(p^2).
    XtX = X.T @ X
    Xty = X.T @ y

    l1_pen = alpha * l1_ratio
    l2_pen = alpha * (1.0 - l1_ratio)

    # Lipschitz constant of the least-squares gradient: lambda_max(XtX) / n.
    if lipschitz_L is not None:
        L = float(lipschitz_L)
    else:
        L = float(cp.linalg.eigvalsh(XtX)[-1]) / n_samples

    if L <= 0:
        # Degenerate design (e.g. all-zero X): the zero vector is optimal.
        return cp.zeros(n_features, dtype=X.dtype), 0

    step = 1.0 / L
    thresh = l1_pen * step
    l2_scale = 1.0 + l2_pen * step
    inv_n_samples = 1.0 / n_samples

    coef = cp.zeros(n_features, dtype=X.dtype)
    y_k = cp.zeros(n_features, dtype=X.dtype)
    coef_old = cp.zeros(n_features, dtype=X.dtype)

    # FISTA state. t_k stays a host float so the step-size sequence costs no
    # device work.
    t_k = 1.0
    n_iter = 0

    if warmup:
        elastic_net_proximal(cp.empty(n_features, dtype=X.dtype), thresh, l2_scale)

    for iteration in range(max_iter):
        coef_old[:] = coef

        # Gradient of the RSS term only; the L2 part lives in the prox step.
        grad = XtX @ y_k
        grad -= Xty
        grad *= inv_n_samples

        w_tilde = y_k - step * grad
        coef = elastic_net_proximal(w_tilde, thresh, l2_scale)

        # FISTA momentum update (host scalar math; cp.sqrt would create
        # device arrays every iteration).
        t_new = (1.0 + math.sqrt(1.0 + 4.0 * t_k * t_k)) * 0.5
        beta = (t_k - 1.0) / t_new
        y_k = coef + beta * (coef - coef_old)
        t_k = t_new

        n_iter = iteration + 1

        if stopping == 'kkt':
            # Gradient of the smooth part (RSS + L2) at the current iterate.
            smooth_grad = XtX @ coef
            smooth_grad -= Xty
            smooth_grad *= inv_n_samples
            smooth_grad += l2_pen * coef
            # KKT: for w_j != 0, g_j + l1_pen*sign(w_j) = 0.
            # For w_j == 0, |g_j| <= l1_pen.
            # Each coordinate must be checked against its own condition; a
            # zero coefficient with |g_j| <= l1_pen is optimal.
            kkt_violation = cp.where(
                coef != 0,
                cp.abs(smooth_grad + l1_pen * cp.sign(coef)),
                cp.maximum(cp.abs(smooth_grad) - l1_pen, 0),
            )
            violation = float(cp.max(kkt_violation))
        else:
            violation = float(cp.max(cp.abs(coef - coef_old)))

        if violation < tol:
            break

    return coef, n_iter
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ============================================================================
|
|
172
|
+
# Torch Compiled Kernels for Elastic Net
|
|
173
|
+
# ============================================================================
|
|
174
|
+
|
|
175
|
+
def _get_torch_compiled_proximal():
|
|
176
|
+
"""Lazy load torch.compile proximal operator."""
|
|
177
|
+
try:
|
|
178
|
+
import torch
|
|
179
|
+
except ImportError:
|
|
180
|
+
return None
|
|
181
|
+
|
|
182
|
+
def _elastic_net_proximal_torch(w_tilde, thresh, l2_scale):
|
|
183
|
+
"""Soft thresholding with L2 scaling for Elastic Net."""
|
|
184
|
+
return torch.sign(w_tilde) * torch.maximum(
|
|
185
|
+
torch.abs(w_tilde) - thresh,
|
|
186
|
+
torch.tensor(0.0, device=w_tilde.device, dtype=w_tilde.dtype)
|
|
187
|
+
) / l2_scale
|
|
188
|
+
|
|
189
|
+
# Compile the proximal operator
|
|
190
|
+
try:
|
|
191
|
+
torch._dynamo.config.suppress_errors = True
|
|
192
|
+
torch._dynamo.config.guard_immutable_object = False
|
|
193
|
+
_elastic_net_proximal_compiled = torch.compile(
|
|
194
|
+
_elastic_net_proximal_torch, mode='reduce-overhead'
|
|
195
|
+
)
|
|
196
|
+
except (AttributeError, RuntimeError):
|
|
197
|
+
_elastic_net_proximal_compiled = _elastic_net_proximal_torch
|
|
198
|
+
|
|
199
|
+
return _elastic_net_proximal_compiled
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _fit_elasticnet_torch_optimized(X, y, alpha, l1_ratio, n_samples, n_features,
                                    max_iter=1000, tol=1e-4, lipschitz_L=None,
                                    stopping='coef_delta', warmup=True):
    """
    Fit Elastic Net using optimized PyTorch operations with torch.compile().

    Minimizes ``(1/(2n)) * ||y - Xw||^2 + alpha*l1_ratio*||w||_1
    + 0.5*alpha*(1-l1_ratio)*||w||^2`` with FISTA on the precomputed Gram
    matrix ``X.T @ X``. X and y are used as given; no centering or weighting
    is done here.

    Parameters
    ----------
    X : torch.Tensor of shape (n_samples, n_features)
    y : torch.Tensor of shape (n_samples,)
    alpha : float
        Overall penalty strength.
    l1_ratio : float
        L1/L2 mixing parameter.
    n_samples, n_features : int
        Dimensions of X.
    max_iter : int, default=1000
        Maximum number of FISTA iterations.
    tol : float, default=1e-4
        Convergence threshold for the chosen stopping quantity.
    lipschitz_L : float, optional
        Lipschitz constant of the smooth term. When omitted it is the largest
        eigenvalue of ``X.T @ X`` divided by ``n_samples``.
    stopping : str, default='coef_delta'
        ``'kkt'`` stops on the maximum KKT violation. Any other value stops on
        the maximum absolute coefficient change between iterations.
    warmup : bool, default=True
        Call the compiled proximal operator once before the loop to trigger
        compilation.

    Returns
    -------
    coef : torch.Tensor of shape (n_features,)
        NOTE(review): with ``mode='reduce-overhead'`` this tensor may alias a
        CUDA-graph output buffer. Confirm that callers copy it before invoking
        the operator again.
    n_iter : int
        Number of iterations performed (0 if the Lipschitz constant is <= 0).

    Raises
    ------
    ImportError
        If torch is not available.
    """
    import math

    import torch

    elastic_net_proximal = _get_torch_compiled_proximal()
    if elastic_net_proximal is None:
        raise ImportError("Torch not available")

    # Precompute Gram matrix and cross product once; each iteration is O(p^2).
    XtX = X.T @ X
    Xty = X.T @ y

    l1_pen = alpha * l1_ratio
    l2_pen = alpha * (1.0 - l1_ratio)

    # Lipschitz constant of the least-squares gradient: lambda_max(XtX) / n.
    if lipschitz_L is not None:
        L = float(lipschitz_L)
    else:
        L = float(torch.linalg.eigvalsh(XtX)[-1]) / n_samples

    if L <= 0:
        # Degenerate design (e.g. all-zero X): the zero vector is optimal.
        return torch.zeros(n_features, device=X.device, dtype=X.dtype), 0

    step = 1.0 / L
    thresh = l1_pen * step
    l2_scale = 1.0 + l2_pen * step
    inv_n_samples = 1.0 / n_samples

    # Reusable buffers: grad and w_tilde are written with out= each iteration.
    coef = torch.zeros(n_features, dtype=X.dtype, device=X.device)
    y_k = torch.zeros(n_features, dtype=X.dtype, device=X.device)
    coef_old = torch.zeros(n_features, dtype=X.dtype, device=X.device)
    grad = torch.empty(n_features, dtype=X.dtype, device=X.device)
    w_tilde = torch.empty(n_features, dtype=X.dtype, device=X.device)

    # FISTA state. t_k is a plain Python float: torch.sqrt only accepts
    # Tensors, so the step-size sequence must use math.sqrt.
    t_k = 1.0
    n_iter = 0

    if warmup:
        elastic_net_proximal(w_tilde, thresh, l2_scale)

    for iteration in range(max_iter):
        coef_old.copy_(coef)

        # Gradient of the RSS term only; the L2 part lives in the prox step.
        torch.matmul(XtX, y_k, out=grad)
        grad -= Xty
        grad *= inv_n_samples

        # w_tilde = y_k - step * grad
        torch.subtract(y_k, grad, alpha=step, out=w_tilde)
        coef = elastic_net_proximal(w_tilde, thresh, l2_scale)

        t_new = (1.0 + math.sqrt(1.0 + 4.0 * t_k * t_k)) * 0.5
        beta = (t_k - 1.0) / t_new
        y_k = coef + beta * (coef - coef_old)
        t_k = t_new

        n_iter = iteration + 1

        if stopping == 'kkt':
            # Gradient of the smooth part (RSS + L2); reuses the grad buffer,
            # which is overwritten at the start of the next iteration anyway.
            smooth_grad = torch.matmul(XtX, coef, out=grad)
            smooth_grad -= Xty
            smooth_grad *= inv_n_samples
            smooth_grad += l2_pen * coef
            # KKT: for w_j != 0, g_j + l1_pen*sign(w_j) = 0.
            # For w_j == 0, |g_j| <= l1_pen.
            # Each coordinate is checked only against its own condition.
            kkt_violation = torch.where(
                coef != 0,
                torch.abs(smooth_grad + l1_pen * torch.sign(coef)),
                torch.clamp(torch.abs(smooth_grad) - l1_pen, min=0.0),
            )
            violation = float(torch.max(kkt_violation).item())
        else:
            violation = float(torch.max(torch.abs(coef - coef_old)).item())

        if violation < tol:
            break

    return coef, n_iter
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# ============================================================================
|
|
308
|
+
# Elastic Net Estimator Class
|
|
309
|
+
# ============================================================================
|
|
310
|
+
|
|
311
|
+
class ElasticNet(BaseEstimator):
|
|
312
|
+
"""
|
|
313
|
+
Elastic Net regression with GPU acceleration.
|
|
314
|
+
|
|
315
|
+
Elastic Net combines L1 (Lasso) and L2 (Ridge) regularization, controlled by
|
|
316
|
+
the `l1_ratio` parameter. This provides:
|
|
317
|
+
- Feature selection from L1 (sparse solutions)
|
|
318
|
+
- Grouping effect from L2 (handles correlated features)
|
|
319
|
+
|
|
320
|
+
Parameters
|
|
321
|
+
----------
|
|
322
|
+
alpha : float, default=1.0
|
|
323
|
+
Regularization strength. Larger values specify stronger regularization.
|
|
324
|
+
Must be non-negative.
|
|
325
|
+
l1_ratio : float, default=0.5
|
|
326
|
+
Elastic Net mixing parameter, between 0 and 1 inclusive.
|
|
327
|
+
- l1_ratio = 1: L1 penalty only (Lasso)
|
|
328
|
+
- l1_ratio = 0: L2 penalty only (Ridge)
|
|
329
|
+
- 0 < l1_ratio < 1: Combination of L1 and L2 penalties
|
|
330
|
+
fit_intercept : bool, default=True
|
|
331
|
+
Whether to calculate the intercept.
|
|
332
|
+
max_iter : int, default=1000
|
|
333
|
+
Maximum number of iterations for the solver.
|
|
334
|
+
tol : float, default=1e-4
|
|
335
|
+
Tolerance for convergence.
|
|
336
|
+
stopping : str, default='coef_delta'
|
|
337
|
+
Stopping criterion: 'coef_delta' or 'kkt'.
|
|
338
|
+
device : str or Device, default='auto'
|
|
339
|
+
Computation device: 'cpu', 'cuda', or 'auto'.
|
|
340
|
+
solver : str, default='fista'
|
|
341
|
+
GPU optimization algorithm: 'fista' or 'admm'.
|
|
342
|
+
Note: ADMM not yet implemented for Elastic Net.
|
|
343
|
+
cpu_solver : str, default='fista'
|
|
344
|
+
CPU optimization algorithm: 'fista' or 'coordinate_descent'.
|
|
345
|
+
Note: coordinate_descent not yet implemented for Elastic Net.
|
|
346
|
+
lipschitz_L : float, optional
|
|
347
|
+
Pre-computed Lipschitz constant. If not provided, will be estimated.
|
|
348
|
+
gpu_memory_cleanup : bool, default=False
|
|
349
|
+
If True, free GPU memory pool after fitting.
|
|
350
|
+
|
|
351
|
+
Attributes
|
|
352
|
+
----------
|
|
353
|
+
coef_ : ndarray of shape (n_features,)
|
|
354
|
+
Estimated coefficients.
|
|
355
|
+
intercept_ : float
|
|
356
|
+
Independent term.
|
|
357
|
+
n_iter_ : int
|
|
358
|
+
Number of iterations run.
|
|
359
|
+
|
|
360
|
+
See Also
|
|
361
|
+
--------
|
|
362
|
+
Lasso : Lasso regression with L1 regularization.
|
|
363
|
+
Ridge : Ridge regression with L2 regularization.
|
|
364
|
+
|
|
365
|
+
Notes
|
|
366
|
+
-----
|
|
367
|
+
The objective function is:
|
|
368
|
+
|
|
369
|
+
(1 / (2 * n_samples)) * ||y - Xw||²₂ + α * l1_ratio * ||w||₁ + 0.5 * α * (1 - l1_ratio) * ||w||²₂
|
|
370
|
+
|
|
371
|
+
References
|
|
372
|
+
----------
|
|
373
|
+
.. [1] Zou, H., & Hastie, T. (2005). Regularization and variable selection
|
|
374
|
+
via the elastic net. Journal of the Royal Statistical Society:
|
|
375
|
+
Series B, 67(2), 301-320.
|
|
376
|
+
.. [2] Beck, A., & Teboulle, M. (2009). A fast iterative shrinkage-thresholding
|
|
377
|
+
algorithm for linear inverse problems. SIAM Journal on Imaging Sciences,
|
|
378
|
+
2(1), 183-202.
|
|
379
|
+
|
|
380
|
+
Examples
|
|
381
|
+
--------
|
|
382
|
+
>>> import numpy as np
|
|
383
|
+
>>> from statgpu.linear_model import ElasticNet
|
|
384
|
+
>>> X = np.random.randn(100, 10)
|
|
385
|
+
>>> y = X @ np.random.randn(10) + np.random.randn(100)
|
|
386
|
+
>>> model = ElasticNet(alpha=1.0, l1_ratio=0.5)
|
|
387
|
+
>>> model.fit(X, y)
|
|
388
|
+
>>> print(model.coef_)
|
|
389
|
+
"""
|
|
390
|
+
|
|
391
|
+
def __init__(
    self,
    alpha: float = 1.0,
    l1_ratio: float = 0.5,
    fit_intercept: bool = True,
    max_iter: int = 1000,
    tol: float = 1e-4,
    stopping: str = "coef_delta",
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    solver: str = "fista",
    cpu_solver: str = "fista",
    lipschitz_L: Optional[float] = None,
    gpu_memory_cleanup: bool = False,
):
    """Store hyper-parameters and reset fitted and internal state.

    ``device`` and ``n_jobs`` are forwarded to the base estimator.
    ``stopping``, ``solver`` and ``cpu_solver`` are lower-cased, and
    ``gpu_memory_cleanup`` is coerced to ``bool``. No value validation
    happens here.
    """
    super().__init__(device=device, n_jobs=n_jobs)

    # Model / solver hyper-parameters.
    self.alpha = alpha
    self.l1_ratio = l1_ratio
    self.fit_intercept = fit_intercept
    self.max_iter = max_iter
    self.tol = tol
    # String options are normalized so later comparisons are case-insensitive.
    self.stopping = stopping.lower()
    self.solver = solver.lower()
    self.cpu_solver = cpu_solver.lower()
    self.lipschitz_L = lipschitz_L
    self.gpu_memory_cleanup = bool(gpu_memory_cleanup)

    # Public fitted attributes (populated by fit()).
    self.coef_ = None
    self.intercept_ = None
    self.n_iter_ = 0

    # Internal storage, all cleared until a fit populates it.
    for attr_name in (
        "_params",
        "_scale",
        "_df_resid",
        "_nobs",
        "_X_design",
        "_resid",
    ):
        setattr(self, attr_name, None)
|
|
429
|
+
|
|
430
|
+
def fit(self, X, y, sample_weight=None, initial_coef=None):
    """
    Fit Elastic Net model.

    Routing: the torch backend uses ``_fit_torch``. Otherwise a CUDA compute
    device uses ``_fit_gpu``, and everything else uses ``_fit_cpu``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training data.
    y : array-like of shape (n_samples,)
        Target values.
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights.
    initial_coef : array-like of shape (n_features,), optional
        Initial coefficient vector for warm-start. When fitting along a
        regularization path (alphas from large to small), passing the
        previous solution can significantly reduce iterations. Only the CPU
        solver uses it; a ``UserWarning`` is emitted when it is supplied but
        the torch or CUDA route is taken.

    Returns
    -------
    self : ElasticNet
        Fitted estimator.

    Raises
    ------
    ValueError
        If ``alpha`` is negative or ``l1_ratio`` is outside ``[0, 1]``.
    """
    # Enforce the documented parameter ranges. With l1_ratio > 1 the L2 weight
    # alpha*(1 - l1_ratio) turns negative and the solvers silently misbehave.
    if self.alpha < 0:
        raise ValueError(f"alpha must be non-negative, got {self.alpha!r}")
    if not 0.0 <= self.l1_ratio <= 1.0:
        raise ValueError(f"l1_ratio must be in [0, 1], got {self.l1_ratio!r}")

    device = self._get_compute_device()
    backend = self._get_backend(backend="auto")
    backend_name = backend.name

    X_arr = self._to_array(X, backend=backend_name)
    y_arr = self._to_array(y, backend=backend_name)

    use_torch = backend_name == "torch"
    use_cuda = not use_torch and device == Device.CUDA

    if initial_coef is not None and (use_torch or use_cuda):
        # These routes do not accept a warm start; tell the caller instead of
        # silently dropping it.
        warnings.warn(
            "initial_coef is only used by the CPU solver and is ignored "
            f"for backend {backend_name!r}.",
            UserWarning,
            stacklevel=2,
        )

    if use_torch:
        self._fit_torch(X_arr, y_arr, sample_weight)
    elif use_cuda:
        self._fit_gpu(X_arr, y_arr, sample_weight)
    else:
        self._fit_cpu(X_arr, y_arr, sample_weight, initial_coef=initial_coef)

    self._fitted = True
    return self
|
|
469
|
+
|
|
470
|
+
def predict(self, X):
    """
    Predict using Elastic Net model.

    The result type follows ``self._get_compute_device()``: a CuPy array for
    ``Device.CUDA``, a float64 torch tensor for ``Device.TORCH``, otherwise a
    NumPy result.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Test data.

    Returns
    -------
    y_pred : ndarray of shape (n_samples,)
        Predicted values.

    Raises
    ------
    RuntimeError
        If the model has not been fitted (``coef_`` is None).
    """
    if self.coef_ is None:
        raise RuntimeError("Model has not been fitted yet.")

    device = self._get_compute_device()

    if device == Device.CUDA:
        import cupy as cp

        features = cp.asarray(self._to_array(X, Device.CUDA))
        weights = cp.asarray(self.coef_)
        preds = features @ weights
        if self.fit_intercept:
            preds += cp.asarray(self.intercept_, dtype=weights.dtype)
        return preds

    if device == Device.TORCH:
        import torch

        features = self._to_array(X, Device.TORCH, backend="torch").to(torch.float64)
        weights = torch.as_tensor(self.coef_, dtype=features.dtype, device=features.device)
        preds = features @ weights
        if self.fit_intercept:
            offset = torch.as_tensor(
                self.intercept_, dtype=preds.dtype, device=preds.device
            )
            preds = preds + offset
        return preds

    # CPU fallback: plain NumPy matrix-vector product.
    features = np.asarray(X)
    preds = features @ self.coef_
    if self.fit_intercept:
        preds += self.intercept_
    return preds
|
|
512
|
+
|
|
513
|
+
def score(self, X, y):
    """
    Return the coefficient of determination R².

    Computed as ``1 - SS_res / SS_tot`` on the device chosen by
    ``self._get_compute_device()``. Returns ``0.0`` when ``SS_tot`` is not
    strictly positive (e.g. constant ``y``).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Test data.
    y : array-like of shape (n_samples,)
        True values.

    Returns
    -------
    r2 : float
        R² score.
    """
    y_pred = self.predict(X)
    device = self._get_compute_device()

    if device == Device.CUDA:
        import cupy as cp

        y_true = cp.asarray(self._to_array(y, Device.CUDA))
        residual_ss = cp.sum((y_true - y_pred) ** 2)
        total_ss = cp.sum((y_true - cp.mean(y_true)) ** 2)
        if not float(total_ss.item()) > 0:
            return 0.0
        return float((1 - residual_ss / total_ss).item())

    if device == Device.TORCH:
        import torch

        y_true = self._to_array(y, Device.TORCH, backend="torch").to(y_pred.dtype)
        residual_ss = torch.sum((y_true - y_pred) ** 2)
        total_ss = torch.sum((y_true - torch.mean(y_true)) ** 2)
        if not float(total_ss.item()) > 0:
            return 0.0
        return float((1 - residual_ss / total_ss).item())

    y_pred = np.asarray(y_pred)
    y_true = self._to_numpy(y)
    residual_ss = np.sum((y_true - y_pred) ** 2)
    total_ss = np.sum((y_true - np.mean(y_true)) ** 2)
    if not total_ss > 0:
        return 0.0
    return 1 - residual_ss / total_ss
|
|
550
|
+
|
|
551
|
+
def _soft_threshold(self, x, gamma):
|
|
552
|
+
"""Standard soft thresholding operator for Lasso."""
|
|
553
|
+
return np.sign(x) * np.maximum(np.abs(x) - gamma, 0)
|
|
554
|
+
|
|
555
|
+
def _soft_threshold_elastic(self, x, gamma, l2_scale):
|
|
556
|
+
"""
|
|
557
|
+
Elastic Net soft thresholding operator.
|
|
558
|
+
|
|
559
|
+
Applies soft thresholding then divides by L2 scaling factor.
|
|
560
|
+
This is the proximal operator for L1 + L2 regularization.
|
|
561
|
+
|
|
562
|
+
Parameters
|
|
563
|
+
----------
|
|
564
|
+
x : ndarray
|
|
565
|
+
Input array
|
|
566
|
+
gamma : float
|
|
567
|
+
Threshold parameter (alpha * l1_ratio * step)
|
|
568
|
+
l2_scale : float
|
|
569
|
+
L2 scaling factor (1 + alpha * (1 - l1_ratio) * step)
|
|
570
|
+
|
|
571
|
+
Returns
|
|
572
|
+
-------
|
|
573
|
+
ndarray
|
|
574
|
+
Soft thresholded and scaled result
|
|
575
|
+
"""
|
|
576
|
+
return self._soft_threshold(x, gamma) / l2_scale
|
|
577
|
+
|
|
578
|
+
def _fit_cpu(self, X, y, sample_weight=None, initial_coef=None):
    """
    Fit using the CPU FISTA solver on precomputed Gram statistics.

    Elastic Net proximal gradient update::

        grad = (XtX @ w - Xty) / n   # gradient of the RSS term only
        w = soft_threshold(w - step*grad, alpha*l1_ratio*step) / (1 + alpha*(1-l1_ratio)*step)

    Note: L2 regularization is handled in the proximal step, NOT in the gradient.

    With ``fit_intercept`` the problem is centered using the (weighted) means
    of the ORIGINAL data, and the intercept is recovered as
    ``y_mean - X_mean @ coef``. With ``sample_weight``, centering uses
    ``sum(w * x) / sum(w)``. The intercept therefore minimizes
    ``sum_i w_i * (y_i - b - x_i @ coef) ** 2`` (plus the penalty) in the
    original, unscaled space.

    Parameters
    ----------
    X : array_like
        Training data (n_samples, n_features).
    y : array_like
        Target values (n_samples,). A column vector (n_samples, 1) is flattened.
    sample_weight : array_like, optional
        Per-sample weights (n_samples,).
    initial_coef : array_like, optional
        Initial coefficient vector for warm-start. It is used only when it has
        exactly n_features entries; otherwise the solver starts from zero.
        A good warm start can speed up convergence along a regularization path.

    Notes
    -----
    Sets ``coef_``, ``intercept_``, ``_params``, ``n_iter_``, ``_nobs`` and
    ``_df_resid`` on ``self``.
    """
    X = np.asarray(X)
    # Flatten y first so the weighting below is element-wise. An (n, 1) y
    # multiplied by an (n,) weight vector would otherwise broadcast to (n, n).
    y = np.asarray(y).reshape(-1)

    n_samples, n_features = X.shape
    self._nobs = n_samples

    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight).reshape(-1)
        sw_sum = float(np.sum(sample_weight))
    else:
        sw_sum = float(n_samples)

    # Means are taken on the ORIGINAL data (weighted if weights are given).
    # Centering the sqrt(w)-rescaled data instead would fit the intercept in
    # the rescaled space and produce a wrong intercept for non-unit weights.
    if self.fit_intercept:
        if sample_weight is None:
            X_mean = np.mean(X, axis=0)
            y_mean = np.mean(y)
        else:
            X_mean = (sample_weight @ X) / sw_sum
            y_mean = (sample_weight @ y) / sw_sum
    else:
        y_mean = 0.0

    if sample_weight is not None:
        sqrt_sw = np.sqrt(sample_weight)
        X = X * sqrt_sw[:, np.newaxis]
        y = y * sqrt_sw

    XtX = X.T @ X
    Xty = X.T @ y
    if self.fit_intercept:
        # Memory-efficient centering without forming X - X_mean (n x p):
        #   sum_i w_i (x_i - m)(x_i - m)^T   = Xw^T Xw - sum(w) * m m^T
        #   sum_i w_i (x_i - m)(y_i - ybar)  = Xw^T yw - sum(w) * m * ybar
        # With unit weights, sum(w) == n_samples.
        XtX = XtX - sw_sum * np.outer(X_mean, X_mean)
        Xty = Xty - sw_sum * X_mean * y_mean

    Xty_flat = np.asarray(Xty).reshape(-1)

    # Elastic Net parameters
    alpha = float(self.alpha)
    l1_ratio = float(self.l1_ratio)
    l2_ratio = 1.0 - l1_ratio

    # Lipschitz constant: L = lambda_max(XtX) / n (RSS only; L2 is in the prox step).
    if self.lipschitz_L is not None:
        L = float(self.lipschitz_L)
    else:
        try:
            # eigvalsh returns eigenvalues in ascending order.
            eig_max = np.linalg.eigvalsh(XtX)[-1]
            L = float(eig_max / n_samples)
        except Exception:
            # Fallback upper bound: trace(XtX) >= lambda_max for a PSD XtX.
            L = float(np.trace(XtX) / n_samples)

    if L <= 0:
        # Degenerate case: apply the pure proximal operator to zero.
        thresh = alpha * l1_ratio
        l2_scale = 1.0 + alpha * l2_ratio
        coef = self._soft_threshold_elastic(np.zeros(n_features), thresh, l2_scale)
        self.n_iter_ = 0
    else:
        step = 1.0 / L

        # Elastic Net proximal parameters
        thresh = alpha * l1_ratio * step
        l2_scale = 1.0 + alpha * l2_ratio * step
        inv_l2_scale = 1.0 / l2_scale
        inv_n_samples = 1.0 / n_samples

        # FISTA variables: warm-start from a correctly sized initial_coef.
        coef = None
        if initial_coef is not None:
            init = np.asarray(initial_coef, dtype=np.float64).reshape(-1)
            if init.size == n_features:
                coef = init.copy()
        if coef is None:
            coef = np.zeros(n_features)
        y_k = coef.copy()
        t_k = 1.0

        # Pre-allocated buffers; the loop below works in place.
        coef_old = np.empty_like(coef)
        grad = np.empty_like(coef)
        w_tilde = np.empty_like(coef)
        delta = np.empty_like(coef)

        for iteration in range(self.max_iter):
            coef_old[:] = coef

            # Gradient of the RSS only: grad = (XtX @ y_k - Xty) / n_samples
            np.matmul(XtX, y_k, out=grad)
            grad -= Xty_flat
            grad *= inv_n_samples

            # w_tilde = y_k - step * grad, computed without temporaries
            np.multiply(grad, step, out=w_tilde)
            np.subtract(y_k, w_tilde, out=w_tilde)

            # coef = soft_threshold(w_tilde, thresh) / l2_scale
            np.abs(w_tilde, out=delta)
            delta -= thresh
            np.maximum(delta, 0, out=delta)
            coef[:] = np.sign(w_tilde) * delta * inv_l2_scale

            # FISTA momentum: y_k = coef + beta * (coef - coef_old)
            t_new = (1.0 + np.sqrt(1.0 + 4.0 * t_k * t_k)) * 0.5
            beta = (t_k - 1.0) / t_new
            np.subtract(coef, coef_old, out=y_k)
            y_k *= beta
            y_k += coef
            t_k = t_new

            # Convergence test: L-infinity norm of the coefficient change
            np.subtract(coef, coef_old, out=delta)
            np.abs(delta, out=delta)
            if float(np.max(delta)) < self.tol:
                self.n_iter_ = iteration + 1
                break
        else:
            self.n_iter_ = self.max_iter

    # Recover the intercept in the original (unscaled, uncentered) space.
    if self.fit_intercept:
        self.intercept_ = float(y_mean - X_mean @ coef)
        self.coef_ = coef
        self._params = np.concatenate([[self.intercept_], self.coef_])
    else:
        self.intercept_ = 0.0
        self.coef_ = coef
        self._params = coef.copy()

    self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
|
|
725
|
+
|
|
726
|
+
def _soft_threshold_cupy(self, x, gamma, l2_scale=None):
    """Apply the soft-thresholding operator to a CuPy array.

    Computes ``sign(x) * max(|x| - gamma, 0)``. When ``l2_scale`` is given,
    the result is additionally divided by it (Elastic Net form).
    """
    import cupy as cp

    shrunk = cp.sign(x) * cp.maximum(cp.abs(x) - gamma, 0)
    return shrunk if l2_scale is None else shrunk / l2_scale
|
|
732
|
+
|
|
733
|
+
def _cleanup_cuda_memory(self):
    """Release cached CuPy device and pinned-host memory blocks (best effort).

    Does nothing unless ``self.gpu_memory_cleanup`` is truthy. Any error,
    including CuPy not being importable, is deliberately ignored because this
    is optional housekeeping.
    """
    if self.gpu_memory_cleanup:
        try:
            import cupy as cp

            # Device pool first, then pinned-host pool (same order as before).
            for get_pool in (cp.get_default_memory_pool, cp.get_default_pinned_memory_pool):
                get_pool().free_all_blocks()
        except Exception:
            pass
|
|
743
|
+
|
|
744
|
+
def _fit_gpu(self, X, y, sample_weight=None):
    """
    Fit using GPU (CuPy) with the optimized FISTA solver and fused kernels.

    With ``fit_intercept`` the data are centered with the (sample-weighted)
    means of the ORIGINAL data and then rescaled by ``sqrt(sample_weight)``.
    The recovered intercept ``y_mean - X_mean @ coef`` therefore minimizes the
    weighted residual sum of squares in the original space. The centered data
    are passed to ``_fit_elasticnet_cupy_optimized``.

    Parameters
    ----------
    X : array_like
        Training data (n_samples, n_features); converted with ``cp.asarray``.
    y : array_like
        Target values (n_samples,); a column vector is flattened.
    sample_weight : array_like, optional
        Per-sample weights (n_samples,).

    Raises
    ------
    ValueError
        If ``self.solver`` is not ``"fista"``.
    """
    import cupy as cp

    if self.solver not in ("fista",):
        raise ValueError("Elastic Net currently only supports 'fista' solver")

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # Ensure CuPy arrays. Flatten y BEFORE weighting: an (n, 1) y times an
    # (n,) weight vector would otherwise broadcast to an (n, n) matrix.
    X = cp.asarray(X)
    y = cp.asarray(y).reshape(-1)

    if sample_weight is not None:
        sample_weight = cp.asarray(sample_weight).reshape(-1)

    # Center with means of the ORIGINAL data (weighted if weights are given);
    # centering the sqrt(w)-rescaled data would give a wrong intercept.
    if self.fit_intercept:
        if sample_weight is None:
            X_mean = cp.mean(X, axis=0)
            y_mean = cp.mean(y)
        else:
            sw_sum = cp.sum(sample_weight)
            X_mean = (sample_weight @ X) / sw_sum
            y_mean = (sample_weight @ y) / sw_sum
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_centered = y

    if sample_weight is not None:
        sqrt_sw = cp.sqrt(sample_weight)
        X_centered = X_centered * sqrt_sw[:, cp.newaxis]
        y_centered = y_centered * sqrt_sw

    # Use optimized implementation with fused kernels
    coef, self.n_iter_ = _fit_elasticnet_cupy_optimized(
        X=X_centered,
        y=y_centered,
        alpha=float(self.alpha),
        l1_ratio=float(self.l1_ratio),
        n_samples=n_samples,
        n_features=n_features,
        max_iter=self.max_iter,
        tol=self.tol,
        lipschitz_L=self.lipschitz_L,
        stopping=self.stopping,
        warmup=True  # Enable warm-up to avoid JIT overhead
    )

    # Build full coefficients (intercept first when fitted)
    if self.fit_intercept:
        intercept_gpu = y_mean - X_mean @ coef
        coef_full = cp.concatenate([intercept_gpu.reshape(1), coef])
    else:
        coef_full = coef

    # Transfer to CPU
    coef_full_np = coef_full.get()

    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
        self._params = coef_full_np
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np
        self._params = coef_full_np

    self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))

    # Cleanup
    self._cleanup_cuda_memory()
|
|
818
|
+
|
|
819
|
+
def _soft_threshold_elastic_cupy(self, x, gamma, l2_scale):
    """Return ``sign(x) * max(|x| - gamma, 0) / l2_scale`` for a CuPy array.

    This is the Elastic Net proximal operator: an L1 soft threshold followed
    by division by the L2 scaling factor.
    """
    import cupy as cp

    magnitude = cp.maximum(cp.abs(x) - gamma, 0)
    return cp.sign(x) * magnitude / l2_scale
|
|
823
|
+
|
|
824
|
+
def _soft_threshold_torch(self, x, gamma, l2_scale=None):
    """Apply the soft-thresholding operator to a Torch tensor.

    Computes ``sign(x) * max(|x| - gamma, 0)``. When ``l2_scale`` is given,
    the result is additionally divided by it (Elastic Net form). The zero
    floor is a 0-dim tensor with ``x``'s dtype and device.
    """
    import torch

    floor = x.new_zeros(())
    shrunk = torch.sign(x) * torch.maximum(torch.abs(x) - gamma, floor)
    if l2_scale is None:
        return shrunk
    return shrunk / l2_scale
|
|
831
|
+
|
|
832
|
+
def _soft_threshold_elastic_torch(self, x, gamma, l2_scale):
    """Return ``sign(x) * max(|x| - gamma, 0) / l2_scale`` for a Torch tensor.

    This is the Elastic Net proximal operator. The zero floor is created with
    ``x``'s dtype and device.
    """
    import torch

    floor = x.new_zeros(())
    kept = torch.maximum(torch.abs(x) - gamma, floor)
    return torch.sign(x) * kept / l2_scale
|
|
837
|
+
|
|
838
|
+
def _cleanup_torch_memory(self):
    """Release cached CUDA memory held by PyTorch (best effort).

    Does nothing unless ``self.gpu_memory_cleanup`` is truthy, and only calls
    ``torch.cuda.empty_cache()`` when CUDA is available. Any error, including
    Torch not being importable, is deliberately ignored.
    """
    if self.gpu_memory_cleanup:
        try:
            import torch

            cuda = torch.cuda
            if cuda.is_available():
                cuda.empty_cache()
        except Exception:
            pass
|
|
848
|
+
|
|
849
|
+
def _fit_torch(self, X, y, sample_weight=None):
    """
    Fit using Torch on the GPU with the optimized FISTA solver.

    Non-tensor inputs are converted with ``torch.from_numpy`` and moved to
    ``'cuda'``; existing tensors stay on their device. X and y are cast to
    float64.

    With ``fit_intercept`` the data are centered with the (sample-weighted)
    means of the ORIGINAL data and then rescaled by ``sqrt(sample_weight)``.
    The recovered intercept ``y_mean - X_mean @ coef`` therefore minimizes the
    weighted residual sum of squares in the original space. The centered data
    are passed to ``_fit_elasticnet_torch_optimized``.

    Parameters
    ----------
    X : numpy.ndarray or torch.Tensor
        Training data (n_samples, n_features).
    y : numpy.ndarray or torch.Tensor
        Target values (n_samples,); a column vector is flattened.
    sample_weight : array_like or torch.Tensor, optional
        Per-sample weights (n_samples,). They are moved to X's device and dtype.

    Raises
    ------
    ValueError
        If ``self.solver`` is not ``"fista"``.
    """
    import torch

    if self.solver not in ("fista",):
        raise ValueError("Torch backend currently only supports 'fista' solver")

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # Ensure Torch tensors on GPU
    if not isinstance(X, torch.Tensor):
        X = torch.from_numpy(X).to('cuda')
    if not isinstance(y, torch.Tensor):
        y = torch.from_numpy(y).to('cuda')
    if y.dtype != torch.float64:
        y = y.to(torch.float64)
    if X.dtype != torch.float64:
        X = X.to(torch.float64)

    # Flatten y BEFORE weighting: an (n, 1) y times an (n,) weight vector would
    # otherwise broadcast to an (n, n) matrix.
    y = y.reshape(-1)

    if sample_weight is not None:
        # Match X's device and dtype so a CPU, list, or float32 weight vector
        # cannot cause a device-mismatch error or precision loss.
        sample_weight = torch.as_tensor(
            sample_weight, dtype=X.dtype, device=X.device
        ).reshape(-1)

    # Center with means of the ORIGINAL data (weighted if weights are given);
    # centering the sqrt(w)-rescaled data would give a wrong intercept.
    if self.fit_intercept:
        if sample_weight is None:
            X_mean = torch.mean(X, dim=0)
            y_mean = torch.mean(y)
        else:
            sw_sum = torch.sum(sample_weight)
            X_mean = (sample_weight @ X) / sw_sum
            y_mean = (sample_weight @ y) / sw_sum
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_centered = y

    if sample_weight is not None:
        sqrt_sw = torch.sqrt(sample_weight)
        X_centered = X_centered * sqrt_sw[:, None]
        y_centered = y_centered * sqrt_sw

    # Use the optimized Torch FISTA implementation
    coef, self.n_iter_ = _fit_elasticnet_torch_optimized(
        X=X_centered,
        y=y_centered,
        alpha=float(self.alpha),
        l1_ratio=float(self.l1_ratio),
        n_samples=n_samples,
        n_features=n_features,
        max_iter=self.max_iter,
        tol=self.tol,
        lipschitz_L=self.lipschitz_L,
        stopping=self.stopping,
        warmup=True  # Enable warm-up to avoid JIT overhead
    )

    # Build full coefficients (intercept first when fitted)
    if self.fit_intercept:
        intercept_torch = y_mean - X_mean @ coef
        coef_full = torch.cat([intercept_torch.reshape(1), coef])
    else:
        coef_full = coef

    # Transfer to CPU
    coef_full_np = coef_full.cpu().numpy()

    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
        self._params = coef_full_np
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np
        self._params = coef_full_np

    self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))

    # Cleanup
    self._cleanup_torch_memory()
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
# =============================================================================
|
|
933
|
+
# V9 thin wrapper
|
|
934
|
+
# =============================================================================
|
|
935
|
+
|
|
936
|
+
from statgpu.linear_model.penalized._penalized_linear import PenalizedLinearRegression as _PenalizedLinearRegression
|