statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1877 @@
|
|
|
1
|
+
"""Fit mixin for PenalizedGeneralizedLinearModel."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from statgpu._config import Device
|
|
9
|
+
from statgpu.backends import get_backend, _get_torch_device_str, _to_numpy, _LINALG_ERRORS
|
|
10
|
+
from statgpu.solvers._utils import _nesterov_momentum, _nesterov_update
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from ._base import PenalizedGeneralizedLinearModel as _Self
|
|
14
|
+
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
# Solver dispatch table for solver='auto'
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
# Each entry is (solver, condition_fn). First match wins.
|
|
19
|
+
# condition_fn takes (loss, penalty, backend, l1_ratio, cv_mode, problem_size).
|
|
20
|
+
|
|
21
|
+
# Import shared penalty categories (single source of truth)
|
|
22
|
+
from statgpu.penalties._categories import (
|
|
23
|
+
NONCONVEX as _NONCONVEX_PENALTIES,
|
|
24
|
+
SPARSE as _SPARSE_PENALTIES,
|
|
25
|
+
)
|
|
26
|
+
# Penalty names treated as smooth (no proximal step needed). The empty string
# is included because _preferred_penalized_glm_solver maps a missing/None
# penalty name to "".
_SMOOTH_PENALTIES = frozenset({"l2", "none", "null", ""})

# Ordered (solver, condition) rules used by _preferred_penalized_glm_solver;
# the FIRST rule whose condition returns True wins, so order is behavior.
# condition = (loss, penalty, backend, l1_ratio, cv_mode, problem_size) -> bool
#   loss / penalty / backend : lower-cased names ("" when missing)
#   l1_ratio                 : forwarded, but no current rule reads it
#   cv_mode                  : rules starting with ``cv and`` only fire when it
#                              is truthy; _select_solver (plain fit) passes
#                              cv_mode=False, so those rules are CV-only
#   problem_size             : n_samples * n_features, or None when unknown
_SOLVER_DISPATCH_TABLE = [
    # -- Priority 1: Exact closed-form solutions (highest priority) --
    # Ridge + squared_error has an exact eigendecomposition solver.
    ("exact", lambda l, p, b, lr, cv, ps: l == "squared_error" and p == "l2"),

    # -- Priority 2: Nonconvex penalties always use FISTA+LLA wrapper --
    # SCAD/MCP/adaptive_l1 require iteratively reweighted L1 (LLA approximation).
    ("fista", lambda l, p, b, lr, cv, ps: p in _NONCONVEX_PENALTIES),

    # -- Priority 3: Squared error + sparse penalties -> FISTA --
    # Quadratic loss + L1/ElasticNet: FISTA with exact line search.
    ("fista", lambda l, p, b, lr, cv, ps: l == "squared_error" and p in _SPARSE_PENALTIES),

    # -- Priority 4: GLM + GPU + sparse penalties (size-gated, CV only) --
    # Poisson + GPU + L1: fista_bb for small/medium problems (< 2M elements,
    # or unknown size).
    ("fista_bb", lambda l, p, b, lr, cv, ps: cv and l == "poisson" and b in ("cupy", "torch") and p == "l1" and (ps is None or ps < 2_000_000)),
    # Poisson + GPU + ElasticNet: fista_bb (BB step adapts well to EN geometry).
    ("fista_bb", lambda l, p, b, lr, cv, ps: cv and l == "poisson" and b in ("cupy", "torch") and p in ("elasticnet", "en")),
    # Remaining Poisson + sparse cases: FISTA. This is not backend-gated, so it
    # covers CPU and also GPU + L1 at >= 2M elements.
    ("fista", lambda l, p, b, lr, cv, ps: cv and l == "poisson" and p in _SPARSE_PENALTIES),

    # -- Priority 5: NB + GPU + sparse penalties (CV only) --
    # NB + GPU + L1: fista_bb (NB gradient is well-behaved for BB steps).
    ("fista_bb", lambda l, p, b, lr, cv, ps: cv and l == "negative_binomial" and b in ("cupy", "torch") and p == "l1"),
    # NB + GPU + ElasticNet: FISTA for medium problems (200K <= size < 1M),
    # fista_bb otherwise (including unknown size).
    ("fista", lambda l, p, b, lr, cv, ps: cv and l == "negative_binomial" and b in ("cupy", "torch") and p in ("elasticnet", "en") and ps is not None and 200_000 <= ps < 1_000_000),
    ("fista_bb", lambda l, p, b, lr, cv, ps: cv and l == "negative_binomial" and b in ("cupy", "torch") and p in ("elasticnet", "en")),

    # -- Priority 6: Gamma/IG/Tweedie + sparse -> FISTA --
    # These families have steep loss landscapes; FISTA with backtracking is safer.
    # Gamma/IG match on any backend; Tweedie only on GPU backends (CPU Tweedie
    # + sparse falls through to the Priority 8 catch-all).
    ("fista", lambda l, p, b, lr, cv, ps: l in ("gamma", "inverse_gaussian") and p in _SPARSE_PENALTIES),
    ("fista", lambda l, p, b, lr, cv, ps: l == "tweedie" and b in ("cupy", "torch") and p in _SPARSE_PENALTIES),

    # -- Priority 7: Logistic + sparse -> FISTA (CV only) --
    # Logistic has iterate-dependent Lipschitz; FISTA with fixed global bound.
    ("fista", lambda l, p, b, lr, cv, ps: cv and l == "logistic" and p in _SPARSE_PENALTIES),

    # -- Priority 8: Default sparse -> fista_bb --
    # Catch-all for remaining sparse penalty cases (including every CV-only
    # sparse case above when cv_mode is False).
    ("fista_bb", lambda l, p, b, lr, cv, ps: p in _SPARSE_PENALTIES),

    # -- Priority 9: CV + L2: loss-specific smooth solvers (CV only) --
    # NB needs L-BFGS (non-canonical link issues with IRLS).
    ("lbfgs", lambda l, p, b, lr, cv, ps: cv and p == "l2" and l == "negative_binomial"),
    # Poisson/Tweedie: Newton (canonical link, well-conditioned).
    ("newton", lambda l, p, b, lr, cv, ps: cv and p == "l2" and l in ("poisson", "tweedie")),
    # Gamma/IG: L-BFGS (non-canonical link, better convergence).
    ("lbfgs", lambda l, p, b, lr, cv, ps: cv and p == "l2" and l in ("gamma", "inverse_gaussian")),

    # -- Priority 10: Smooth penalties (L2/none) with loss-specific solvers --
    # Neither rule is backend-gated. Anything matching no rule at all gets
    # "fista" from _preferred_penalized_glm_solver's fallback.
    ("newton", lambda l, p, b, lr, cv, ps: p in _SMOOTH_PENALTIES and l in ("gamma", "tweedie", "inverse_gaussian")),
    ("irls", lambda l, p, b, lr, cv, ps: p in _SMOOTH_PENALTIES and l in ("logistic", "poisson", "negative_binomial")),
]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _preferred_penalized_glm_solver(
    loss_name,
    penalty_name,
    backend_name=None,
    l1_ratio=0.5,
    cv_mode=False,
    problem_size=None,
):
    """Private benchmark-backed solver policy for solver='auto'.

    This helper only chooses an internal solver. It must never be used to
    override an explicitly requested solver or to change the selected device.

    The loss, penalty and backend names are lower-cased (``None``/empty
    becomes ``""``) and ``problem_size`` is coerced to ``int`` when given.
    The rules of ``_SOLVER_DISPATCH_TABLE`` are then tried in order and the
    solver of the first matching rule is returned; ``"fista"`` is returned
    when no rule matches.
    """
    loss_key, penalty_key, backend_key = (
        str(name or "").lower() for name in (loss_name, penalty_name, backend_name)
    )
    size = None if problem_size is None else int(problem_size)

    rule_args = (loss_key, penalty_key, backend_key, l1_ratio, cv_mode, size)
    return next(
        (solver for solver, matches in _SOLVER_DISPATCH_TABLE if matches(*rule_args)),
        "fista",
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _resolve_loss_name(loss_name, loss_kwargs=None):
    """Look up a GLM loss object by name.

    ``loss_kwargs`` (any falsy value means "no extra arguments") is expanded
    as keyword arguments into ``statgpu.glm_core._base.get_glm_loss``.
    """
    from statgpu.glm_core._base import get_glm_loss

    extra = loss_kwargs if loss_kwargs else {}
    return get_glm_loss(loss_name, **extra)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _irls_ridge_init(X, y, loss_name, alpha=0.01, max_iter=100, tol=1e-4, loss_kwargs=None):
    """Compute ridge-penalized GLM coefficients for adaptive_l1 init.

    For ``loss_name`` equal to ``"squared_error"`` or ``""`` the ridge
    problem is solved in closed form by ``_irls_ridge_init_cd`` on
    column-rescaled features; ``max_iter`` and ``tol`` are not used on that
    path. For GLM losses (logistic, poisson, etc.) FISTA with an L2 penalty
    is used, which has a line search and handles extreme y values robustly.

    Parameters
    ----------
    X : ndarray of shape (n, p)
        Feature matrix (no intercept column).
    y : ndarray of shape (n,)
        Response vector.
    loss_name : str
        GLM loss name: 'logistic', 'poisson', 'squared_error', etc.
    alpha : float
        Ridge penalty strength (lambda in R glmnet).
    max_iter : int
        Iteration limit forwarded to ``fista_solver`` (GLM losses only).
    tol : float
        Tolerance forwarded to ``fista_solver`` (GLM losses only).
    loss_kwargs : dict, optional
        Extra keyword arguments forwarded to the GLM loss lookup
        (GLM losses only).

    Returns
    -------
    coef : numpy.ndarray of dtype float64
        Ridge-penalized coefficient estimates (no intercept), always returned
        as a NumPy array regardless of the backend the inputs live on.
    """
    if loss_name in ("squared_error", ""):
        coef = _irls_ridge_init_cd(X, y, alpha, max_iter, tol)
    else:
        # For GLM losses, use FISTA with L2 penalty (robust line search).
        # Arrays are passed directly — the solver detects the backend itself.
        from statgpu.solvers import fista_solver
        from statgpu.penalties import get_penalty

        l2_pen = get_penalty("l2", alpha=alpha)
        loss_obj = _resolve_loss_name(loss_name, loss_kwargs=loss_kwargs)
        coef, _ = fista_solver(loss_obj, l2_pen, X, y, max_iter=max_iter, tol=tol)
    # Callers (penalty.set_weights) expect NumPy; ``_to_numpy`` is the
    # module-level import, so no local re-import is needed.
    return np.asarray(_to_numpy(coef), dtype=np.float64)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _irls_ridge_init_cd(X, y, alpha, max_iter, tol):
    """Ridge regression initialization for adaptive L1 weights.

    Despite the ``_cd`` suffix, no coordinate descent is performed. Each
    column of ``X`` is rescaled to Euclidean norm ``sqrt(n)`` (all-zero
    columns are guarded against division by zero). The ridge normal equations
    ``(Xs'Xs/n + alpha*I) beta_s = Xs'y/n`` are then solved in closed form
    with the backend's dense ``linalg.solve``, which is a single matmul +
    solve and parallelizes well on GPU. Finally the result is mapped back to
    the original feature scale.

    Parameters
    ----------
    X : numpy / cupy ndarray or torch tensor of shape (n, p)
        Feature matrix. Non-floating (e.g. integer) inputs are promoted to
        float64. Otherwise ``sqrt(n)`` would be truncated to an integer,
        which silently changes the effective penalty.
    y : array of shape (n,) or (n, 1)
        Response, moved onto X's backend (and, for torch, X's dtype/device).
        A single-column 2-D ``y`` is flattened so the result has shape (p,).
    alpha : float
        Ridge penalty strength on the rescaled problem.
    max_iter, tol
        Accepted for signature compatibility; ignored by the closed form.

    Returns
    -------
    beta : array of shape (p,) on the same backend as ``X``.
    """
    from statgpu.backends import _resolve_backend
    from statgpu.backends._utils import _get_xp

    backend = _resolve_backend("auto", X)
    xp = _get_xp(backend)

    n, p = X.shape

    if backend == "torch":
        import torch

        if not torch.is_floating_point(X):
            X = X.to(torch.float64)
        # torch matmul requires identical dtypes/devices for both operands.
        y = torch.as_tensor(y, dtype=X.dtype, device=X.device)
    else:
        if not np.issubdtype(X.dtype, np.floating):
            X = X.astype(np.float64)
        y = xp.asarray(y)

    # A column-vector y would make beta (p, 1), and the final
    # ``beta * scale`` would then broadcast to a (p, p) matrix.
    if y.ndim == 2 and y.shape[1] == 1:
        y = y.reshape(-1)

    # Rescale features to column norm sqrt(n).
    feat_norms = xp.sqrt(xp.sum(X ** 2, axis=0))
    if backend == "torch":
        feat_norms = xp.maximum(
            feat_norms,
            torch.tensor(1e-20, dtype=feat_norms.dtype, device=feat_norms.device),
        )
        scale = torch.tensor(float(n) ** 0.5, dtype=X.dtype, device=X.device) / feat_norms
    else:
        feat_norms = xp.maximum(feat_norms, 1e-20)
        scale = xp.asarray(float(n) ** 0.5, dtype=X.dtype) / feat_norms
    X_work = X * scale

    # Closed-form ridge on the rescaled design.
    XtX = X_work.T @ X_work / n
    Xty = X_work.T @ y / n

    if backend == "torch":
        I_mat = torch.eye(p, dtype=X.dtype, device=X.device)
        beta = torch.linalg.solve(XtX + alpha * I_mat, Xty)
    elif backend == "cupy":
        import cupy as cp

        I_mat = cp.eye(p, dtype=X.dtype)
        beta = cp.linalg.solve(XtX + alpha * I_mat, Xty)
    else:
        I_mat = np.eye(p, dtype=X.dtype)
        beta = np.linalg.solve(XtX + alpha * I_mat, Xty)

    # X_work @ beta == X @ (scale * beta): undo the column rescaling.
    return beta * scale
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class _PenalizedFitMixin:
|
|
207
|
+
|
|
208
|
+
def fit(self, X=None, y=None, sample_weight=None, formula=None, data=None):
    """
    Fit penalized GLM model.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features), optional
        Training data. Required when ``formula`` is None. Inputs without a
        ``shape`` attribute (e.g. nested lists) are converted with
        ``np.asarray`` before any shape-dependent routing.
    y : array-like of shape (n_samples,), optional
        Target values. Required when ``formula`` is None.
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights.
    formula : str, optional
        R-style formula string, e.g. ``"y ~ x1 + C(group)"``. When given,
        ``X``/``y`` are built from ``data`` and the formula decides whether
        an intercept is fitted.
    data : pandas.DataFrame, optional
        Data used to evaluate ``formula``.

    Returns
    -------
    self : PenalizedLinearRegression
        Fitted estimator.

    Raises
    ------
    ValueError
        If ``formula`` is given without ``data``, or if neither ``formula``
        nor both ``X`` and ``y`` are given.
    """
    if formula is not None:
        if data is None:
            raise ValueError(
                "formula was provided but data is None. "
                "Pass data=your_dataframe when using formula."
            )
        from statgpu.core.formula import FormulaParser

        parser = FormulaParser(formula)
        y, X, design_info = parser.eval(data)
        formula_column_names = list(design_info.column_names)
        self._design_info = design_info
        self._formula_has_intercept = "Intercept" in formula_column_names
        self._feature_names = [name for name in formula_column_names if name != "Intercept"]
        if self._formula_has_intercept:
            X = np.delete(X, formula_column_names.index("Intercept"), axis=1)
            self._use_intercept = True
        else:
            # Formula syntax owns intercept semantics, matching statsmodels/R.
            self._use_intercept = False
    else:
        if X is None or y is None:
            raise ValueError("Either formula+data or X+y must be provided.")
        self._feature_names = None
        self._design_info = None
        self._formula_has_intercept = None
        self._use_intercept = None

    # Convert shape-less array-likes once: backend routing, solver selection
    # and the nonconvex path below all read ``X.shape`` directly.
    if X is not None and not hasattr(X, "shape"):
        X = np.asarray(X)

    # Record number of features for sklearn compatibility
    if X is not None:
        self.n_features_in_ = X.shape[1] if X.ndim >= 2 else 1

    self._penalty = self._resolve_penalty()
    self._validate_solver_penalty()
    self._loss = self._resolve_loss()
    self._validate_inference_request()
    self._inference_precomputed = False
    self._precomputed_gaussian_state = None
    self._clear_inference_state()

    def _release_cv_cache():
        # Clean up CV cache unless a caller is intentionally reusing one
        # across repeated fits, as PenalizedGLM_CV does within a fold.
        if hasattr(self, '_cv_cache') and not getattr(self, '_preserve_cv_cache', False):
            del self._cv_cache

    # Resolve the actual backend before auto-selecting the solver so that
    # solver="auto" is device-aware: backend_name is passed to the dispatch
    # rules, several of which are gated on the backend.
    backend = self._get_backend(backend="auto")
    backend_name = backend.name

    # Auto-dispatch small problems to CPU only when device="auto".
    # Explicit CUDA/TORCH device selection must never silently fall back.
    if self.device == Device.AUTO and backend_name in ("cupy", "torch") and X is not None:
        _n, _p = X.shape
        if _n * _p < 200_000:
            backend_name = "numpy"

    backend_name = self._auto_backend_override(backend_name, X)
    selected_solver = self._select_solver(
        self._loss, backend_name=backend_name, X=X
    )
    self._selected_solver = selected_solver
    self._selected_backend_name = backend_name

    # Handle penalties requiring initialization (e.g., Adaptive Lasso)
    if self._penalty.requires_init:
        init_coef = self._fit_initial(X, y, backend_name=backend_name)
        self._penalty.set_weights(init_coef)

    # Non-convex penalties (SCAD, MCP) for squared_error take a fused
    # FISTA+LLA continuation path from lambda_max, similar in spirit to
    # R ncvreg for Gaussian regression. GLM+SCAD/MCP must NOT use this
    # path: the non-convex penalty makes features flip on/off between
    # IRLS iterations, so GLMs go through the general path below instead.
    _pen_name = str(getattr(self._penalty, 'name', '')).lower()
    # Fall back to the configured ``self.loss`` (as _select_solver and
    # _auto_backend_override do). Falling back to '' would misclassify a
    # nameless GLM loss as squared error and send it down the Gaussian path.
    _loss_name = str(getattr(self._loss, 'name', self.loss)).lower()
    _is_glm_loss = _loss_name not in ("squared_error", "")
    if _pen_name in ("scad", "mcp") and self._lla_enabled and not _is_glm_loss:
        # Use fused FISTA+LLA path for all backends (CPU/GPU).
        from statgpu.solvers import fista_lla_path

        self._nobs = X.shape[0]
        X_arr = self._to_array(X, backend=backend_name)
        y_arr = self._to_array(y, backend=backend_name)

        # lambda_max on column-rescaled X (norm sqrt(n)) and centered y.
        # Computed in NumPy: a one-time, negligible cost.
        _X_np = _to_numpy(X_arr)
        _y_np = _to_numpy(y_arr)
        _n = _X_np.shape[0]
        _col_norms = np.maximum(np.sqrt(np.sum(_X_np ** 2, axis=0)), 1e-20)
        _X_s = _X_np * (np.sqrt(_n) / _col_norms)
        _y_c = _y_np - np.mean(_y_np)
        _lam_max = float(np.max(np.abs(_X_s.T @ _y_c / _n)))

        _target_alpha = float(self._penalty.alpha)
        _n_cont = 20
        _alpha_start = max(_lam_max, _target_alpha * 1.1)
        if (not np.isfinite(_alpha_start)) or _alpha_start <= 0.0 or _target_alpha <= 0.0:
            # np.geomspace rejects zero / non-finite endpoints; use a linear path.
            _alpha_path = np.linspace(max(_lam_max, 0.0), _target_alpha, _n_cont)
        else:
            _alpha_path = np.geomspace(_alpha_start, _target_alpha, _n_cont)

        _max_lla_per_step = max(6, getattr(self, '_max_lla_iters', 50) // _n_cont)
        _saved_mi = self.max_iter
        # Intermediate continuation steps get a reduced iteration budget;
        # only the final (target alpha) step uses the full max_iter.
        _mi_path = [
            _saved_mi if _i == _n_cont - 1 else max(100, _saved_mi // 10)
            for _i in range(_n_cont)
        ]
        coef_np, intercept, n_iter = fista_lla_path(
            self._loss, self._penalty,
            X_arr, y_arr,
            alpha_path=_alpha_path,
            max_lla_per_step=_max_lla_per_step,
            lla_tol=getattr(self, '_lla_tol', 1e-6),
            max_iter=_mi_path,
            tol=self.tol,
            fit_intercept=self._effective_intercept,
            sample_weight=sample_weight,
        )
        self.coef_ = coef_np
        self.intercept_ = intercept
        self.n_iter_ = n_iter
        if self._effective_intercept:
            self._params = np.concatenate([[self.intercept_], np.asarray(self.coef_)])
        else:
            self._params = np.asarray(self.coef_).copy()
        self._df_resid = X.shape[0] - (X.shape[1] + (1 if self._effective_intercept else 0))
        self._compute_post_fit_gaussian_inference(X, y, sample_weight=sample_weight)
        if backend_name == "cupy":
            self._cleanup_cuda_memory()
        elif backend_name == "torch":
            self._cleanup_torch_memory()
        self._fitted = True
        # Same cache policy as the general path below.
        _release_cv_cache()
        return self

    X_arr = self._to_array(X, backend=backend_name)
    y_arr = self._to_array(y, backend=backend_name)

    if backend_name == "torch":
        self._fit_torch(X_arr, y_arr, sample_weight)
    elif backend_name == "cupy":
        self._fit_gpu(X_arr, y_arr, sample_weight)
    else:
        self._fit_cpu(X_arr, y_arr, sample_weight)

    self._compute_post_fit_gaussian_inference(X, y, sample_weight=sample_weight)
    self._fitted = True
    _release_cv_cache()
    return self
|
|
377
|
+
|
|
378
|
+
def _select_solver(self, loss, backend_name=None, X=None):
|
|
379
|
+
"""Auto-select solver based on loss, penalty, and backend."""
|
|
380
|
+
if self.solver != "auto":
|
|
381
|
+
return self.solver
|
|
382
|
+
return _preferred_penalized_glm_solver(
|
|
383
|
+
getattr(loss, "name", self.loss),
|
|
384
|
+
getattr(self._penalty, "name", self.penalty),
|
|
385
|
+
backend_name=backend_name,
|
|
386
|
+
l1_ratio=getattr(self._penalty, "l1_ratio", self.l1_ratio),
|
|
387
|
+
cv_mode=False,
|
|
388
|
+
problem_size=None if X is None else int(X.shape[0]) * int(X.shape[1]),
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
@staticmethod
|
|
392
|
+
def _torch_cuda_available():
|
|
393
|
+
try:
|
|
394
|
+
import torch
|
|
395
|
+
return torch.cuda.is_available()
|
|
396
|
+
except Exception:
|
|
397
|
+
return False
|
|
398
|
+
|
|
399
|
+
@staticmethod
|
|
400
|
+
def _cupy_available():
|
|
401
|
+
try:
|
|
402
|
+
import cupy as cp
|
|
403
|
+
return cp.cuda.runtime.getDeviceCount() > 0
|
|
404
|
+
except Exception:
|
|
405
|
+
return False
|
|
406
|
+
|
|
407
|
+
# Backend override rules applied by _auto_backend_override, only when
# device='auto', solver='auto' and problem_size (n_samples * n_features)
# is >= 1M.
# Each entry: (loss, penalties, target_backend, reason_template).
# First match wins, and the CPU table is checked before the CuPy table.
# target_backend="numpy" means always CPU;
# target_backend="torch" means prefer torch over cupy (CuPy table only,
# falling back to CPU when torch CUDA is unavailable).
# reason_template may use ``{penalty}``; CuPy templates may also use ``{target}``.
_AUTO_BACKEND_CPU_OVERRIDES = [
    ("squared_error", ("l2",), "numpy", "large squared-error exact solve is faster on CPU"),
    ("squared_error", ("l1", "elasticnet", "en"), "numpy", "large squared-error l1/elasticnet is faster on CPU"),
    ("negative_binomial", ("l1", "elasticnet", "en"), "numpy", "large negative-binomial l1/elasticnet is faster on CPU"),
    ("logistic", ("l1", "elasticnet", "en"), "numpy", "large logistic {penalty} is faster on CPU"),
    ("gamma", ("l2",), "numpy", "large gamma l2/newton is faster on CPU"),
    ("tweedie", ("l1", "elasticnet", "en"), "numpy", "large tweedie {penalty} is faster on CPU"),
]
# Consulted only when the current backend is cupy.
# NOTE(review): the logistic rule below can never fire, because the CPU
# table's logistic l1/elasticnet rule is checked first and always matches.
# Confirm which benchmark result is intended.
_AUTO_BACKEND_CUPY_OVERRIDES = [
    ("negative_binomial", ("l2",), "torch", "large negative-binomial l2 is faster on {target} than cupy"),
    ("logistic", ("l1", "elasticnet", "en"), "torch", "large logistic {penalty} is faster on {target} than cupy"),
    ("poisson", ("l1", "elasticnet", "en"), "torch", "large poisson {penalty} is faster on {target} than cupy"),
]
|
|
424
|
+
|
|
425
|
+
def _auto_backend_override(self, backend_name, X):
    """Benchmark-backed backend routing for device='auto' only.

    Parameters
    ----------
    backend_name : str
        Backend chosen so far ('numpy', 'cupy' or 'torch').
    X : array-like or None
        Design matrix; only ``X.shape`` is read.

    Returns
    -------
    str
        The backend to use. ``backend_name`` is returned unchanged unless
        all of the following hold:
          - device and solver are both 'auto';
          - ``n_samples * n_features >= 1_000_000``;
          - a rule in ``_AUTO_BACKEND_CPU_OVERRIDES`` matches (any backend),
            or a rule in ``_AUTO_BACKEND_CUPY_OVERRIDES`` matches while the
            current backend is 'cupy'.

    Side effects
    ------------
    Resets ``self._auto_backend_reason`` to None, then sets it to the
    formatted reason string when an override applies.
    """
    self._auto_backend_reason = None
    if self.device != Device.AUTO or self.solver != "auto" or X is None:
        return backend_name

    n_samples, n_features = X.shape
    problem_size = int(n_samples) * int(n_features)
    if problem_size < 1_000_000:
        return backend_name

    loss_name = str(getattr(self._loss, "name", self.loss)).lower()
    penalty_name = str(getattr(self._penalty, "name", self.penalty)).lower()

    # CPU overrides: route to the rule's target (always numpy). Passing
    # ``target`` keeps templates that mention {target} from raising KeyError.
    for loss, penalties, target, reason_tpl in self._AUTO_BACKEND_CPU_OVERRIDES:
        if loss_name == loss and penalty_name in penalties:
            self._auto_backend_reason = reason_tpl.format(
                penalty=penalty_name, target=target)
            return target

    # CuPy->Torch overrides: prefer the rule's torch target when torch CUDA
    # is available, else CPU. NOTE(review): the logistic rule in this table
    # is shadowed by the CPU logistic rule above.
    if backend_name == "cupy":
        for loss, penalties, target, reason_tpl in self._AUTO_BACKEND_CUPY_OVERRIDES:
            if loss_name == loss and penalty_name in penalties:
                # Probe torch only when a rule actually needs it; importing
                # torch is not free.
                if self._torch_cuda_available():
                    self._auto_backend_reason = reason_tpl.format(
                        penalty=penalty_name, target=target)
                    return target
                self._auto_backend_reason = reason_tpl.format(
                    penalty=penalty_name, target="CPU")
                return "numpy"

    return backend_name
|
|
459
|
+
|
|
460
|
+
def _fit_initial(self, X, y, backend_name="numpy"):
    """Fit initial model for penalties requiring initialization.

    Parameters
    ----------
    X : array
        Design matrix.
    y : array
        Target vector.
    backend_name : str
        Backend to use ('numpy', 'torch', 'cupy'). Default 'numpy'.

    Returns
    -------
    array
        Initial coefficient vector. The OLS path returns a NumPy array and
        the fallback returns ``Ridge(...).coef_``. The FISTA / IRLS paths
        return whatever ``fista_solver`` / ``_irls_ridge_init`` produce for
        ``backend_name`` arrays (NOTE(review): presumably backend arrays --
        confirm).

    Strategy, tried in order:

    1. Squared-error loss with a penalty that does not set
       ``requires_init``: OLS via ``np.linalg.lstsq`` when
       ``init_method == 'ols'``, or when ``init_method == 'auto'`` and
       n_samples > n_features. ``lstsq`` does not raise when p > n; it
       returns the minimum-norm least-squares solution. Any other
       ``init_method`` (e.g. 'ridge') skips this step.
    2. GLM loss with a non-convex penalty (SCAD, MCP): a dense, lightly
       L2-penalized (alpha=0.001) GLM fit with FISTA. OLS is avoided for
       GLM losses because it can produce extreme coefficients whose
       Lipschitz constant is enormous, making the inner FISTA solver take
       near-zero steps and exit without moving. A dense seed lets the LLA
       continuation push small coefficients through the region where SCAD
       and MCP differ.
    3. Penalties with ``requires_init=True`` (adaptive_l1, whose weights
       are 1/|coef|): IRLS ridge (alpha=0.01). A dense seed is needed
       because zero entries would freeze their weights permanently.
    4. Otherwise: a ``Ridge(alpha=0.1)`` fit on the estimator's device.
    """
    n_samples, n_features = X.shape
    init_method = getattr(self._penalty, "init_method", "auto")
    loss_name = getattr(self, "loss", "squared_error")
    _is_glm = loss_name != "squared_error"
    _is_nonconvex = not getattr(self._penalty, "is_convex", True)

    def _as_backend_float64():
        """Return (X, y) as float64 arrays on ``backend_name``'s device."""
        if backend_name in ("torch", "cupy"):
            backend = get_backend(backend=backend_name, device="cuda")
            return (
                backend.asarray(X, dtype=backend.float64),
                backend.asarray(y, dtype=backend.float64),
            )
        return (
            np.asarray(_to_numpy(X), dtype=np.float64),
            np.asarray(_to_numpy(y), dtype=np.float64),
        )

    use_ols = init_method == "ols" or (init_method == "auto" and n_samples > n_features)
    if not _is_glm and not self._penalty.requires_init and use_ols:
        # np.linalg.lstsq only accepts host arrays; X/y may live on the GPU.
        # NOTE(review): no intercept column is added here, unlike the Ridge
        # fallback below (fit_intercept=self._effective_intercept); confirm
        # callers pass centred data on this path.
        ols_coef, _, _, _ = np.linalg.lstsq(_to_numpy(X), _to_numpy(y), rcond=None)
        return ols_coef

    if _is_glm and _is_nonconvex:
        # Dense l2-penalized GLM init for non-convex penalties (SCAD, MCP).
        # With lla_weights = P'(|coef|) (not P'(|coef|)/|coef|), a dense
        # starting point lets the LLA continuation path push small
        # coefficients through the transition region where SCAD and MCP
        # differ, matching the path-based strategy used by R's ncvreg.
        from statgpu.penalties import get_penalty
        from statgpu.solvers import fista_solver

        l2_pen = get_penalty("l2", alpha=0.001)
        loss_obj = self._resolve_loss()
        X_b, y_b = _as_backend_float64()
        init_coef, _ = fista_solver(
            loss_obj, l2_pen, X_b, y_b,
            max_iter=500, tol=1e-4,
        )
        return init_coef

    if self._penalty.requires_init:
        # adaptive_l1: weights = 1/(|init_coef|+eps)^nu, so init must
        # produce well-scaled coefficients. Use IRLS with coordinate
        # descent (matching R glmnet's ridge solver) instead of FISTA,
        # which converges more tightly and gives larger coefficients
        # -> smaller weights -> too many features surviving.
        X_b, y_b = _as_backend_float64()
        return _irls_ridge_init(
            X_b, y_b,
            loss_name=loss_name,
            alpha=0.01,
            max_iter=100,
            tol=1e-4,
            loss_kwargs=getattr(self, "loss_kwargs", None),
        )

    from statgpu.linear_model.wrappers._ridge import Ridge

    init_model = Ridge(
        alpha=0.1,
        fit_intercept=self._effective_intercept,
        device=self.device,
    )
    init_model.fit(X, y)
    return init_model.coef_
|
|
561
|
+
|
|
562
|
+
def _fit_cpu(self, X, y, sample_weight=None):
    """Fit using CPU (FISTA or coordinate descent).

    Non-squared-error losses, the IRLS / Newton / L-BFGS / ADMM solvers,
    and the SCAD / MCP / adaptive-L1 / group-lasso penalties are delegated
    to ``_fit_irls_backend`` / ``_fit_loss_backend``. The remaining
    squared-error path uses one of three methods:

    - closed-form Ridge (``solver='exact'``);
    - FISTA on the precomputed Gram matrix;
    - coordinate descent for L1 / elastic net.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Design matrix.
    y : array-like
        Target values (flattened on the squared-error path).
    sample_weight : array-like of shape (n_samples,), optional
        Per-sample weights. With an intercept, X and y are centred on their
        *weighted* means before the sqrt(w) scaling. The fitted intercept is
        therefore that of the weighted least-squares problem; for example,
        uniform weights reproduce the unweighted fit.

    Raises
    ------
    ValueError
        If ``solver='exact'`` is used with a non-L2 penalty.
    NotImplementedError
        If coordinate descent is requested for an unsupported penalty.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # Route to loss-aware solver for non-squared_error loss.
    solver_name = self._selected_solver or self._select_solver(
        self._loss, backend_name="numpy"
    )
    if self.loss != "squared_error" or solver_name in ("irls", "newton", "lbfgs", "admm"):
        if solver_name == "irls":
            self._fit_irls_backend(X, y, sample_weight, "numpy")
        else:
            self._fit_loss_backend(X, y, sample_weight, solver_name, "numpy")
        return

    # Route squared_error + SCAD/MCP/adaptive_l1/group_lasso through
    # _fit_loss_backend so CPU and GPU paths produce identical results.
    # None of these penalties reaches the code below.
    _cd_penalties_for_sqerr = ("scad", "mcp", "adaptive_l1", "adaptive_lasso", "group_lasso")
    if getattr(self._penalty, "name", "") in _cd_penalties_for_sqerr:
        self._fit_loss_backend(X, y, sample_weight, solver_name, "numpy")
        return

    # Original squared-error path.
    pen = self._penalty
    fit_intercept = self._effective_intercept
    y_vec = y.ravel()
    weights = None if sample_weight is None else np.asarray(sample_weight)

    if fit_intercept:
        if weights is None:
            X_mean = np.mean(X, axis=0)
            y_mean = np.mean(y_vec)
        else:
            # The WLS intercept is ybar_w - xbar_w @ coef, so centre on the
            # weighted means *before* applying the sqrt(w) scaling.
            X_mean = np.average(X, axis=0, weights=weights)
            y_mean = np.average(y_vec, weights=weights)
        X_centered = X - X_mean
        y_centered = y_vec - y_mean
    else:
        y_mean = 0.0
        X_centered = X
        y_centered = y_vec

    if weights is not None:
        sqrt_sw = np.sqrt(weights)
        X_centered = X_centered * sqrt_sw[:, np.newaxis]
        y_centered = y_centered * sqrt_sw

    # Precompute for gradient (use CV cache if available).
    _cv = getattr(self, "_cv_cache", None)
    if _cv is not None and "XtX" in _cv:
        XtX = _cv["XtX"]
        Xty = _cv["Xty"]
    else:
        XtX = X_centered.T @ X_centered
        Xty = X_centered.T @ y_centered

    if solver_name == "exact":
        if pen.name != "l2":
            raise ValueError("solver='exact' is only supported for L2/Ridge penalty.")
        coef = self._solve_exact_numpy(XtX, Xty, n_samples)
        self.n_iter_ = 1
    else:
        # Lipschitz constant: L = lambda_max(XtX) / n
        if self.lipschitz_L is not None:
            L = float(self.lipschitz_L)
        else:
            from statgpu.backends._array_ops import _max_eigval_power
            L = _max_eigval_power(XtX) / n_samples

        init = getattr(self, "_init_coef", None)
        start = (
            np.asarray(init, dtype=np.float64).copy()
            if init is not None
            else np.zeros(n_features)
        )
        # Defined even when max_iter == 0 and the loops below never run.
        self.n_iter_ = 0

        if L <= 0:
            coef = np.zeros(n_features)
        elif solver_name in ("fista_bb", "fista"):
            # FISTA with XtX precomputation.
            # BB step (fista_bb) provides no benefit for quadratic losses
            # (BB1=BB2=1/R_H(dw)), so both use the fixed Lipschitz step.
            step = 1.0 / L
            coef = start
            y_k = coef.copy()
            t_k = 1.0
            for iteration in range(self.max_iter):
                coef_old = coef.copy()
                grad_at_y = (XtX @ y_k - Xty) / n_samples
                w_tilde = y_k - step * grad_at_y
                coef = pen.proximal(w_tilde, step, backend="numpy")

                # Scheduled momentum restart.
                if iteration > 0 and iteration % 50 == 0:
                    t_k = 1.0
                y_k, t_k = _nesterov_update(coef, coef_old, t_k)

                self.n_iter_ = iteration + 1
                if np.sum(np.abs(coef - coef_old)) < self.tol:
                    break
        else:
            # Coordinate descent for L1 / elastic net, matching sklearn and
            # R glmnet: beta_j = S(rho_j, alpha*l1_ratio*n)
            #                    / (X_j'X_j + alpha*(1-l1_ratio)*n).
            # Plain L1 is the l1_ratio = 1 case (zero ridge term).
            if pen.name == "l1":
                thresh = self.alpha * n_samples
                ridge = 0.0
            elif pen.name in ("elasticnet", "en"):
                thresh = self.alpha * self.l1_ratio * n_samples
                ridge = self.alpha * (1 - self.l1_ratio) * n_samples
            else:
                raise NotImplementedError(
                    f"Coordinate descent not implemented for "
                    f"penalty '{pen.name}'. Use solver='fista'."
                )

            X_sq_norms = np.diag(XtX)
            coef = start
            for iteration in range(self.max_iter):
                coef_old = coef.copy()
                for j in range(n_features):
                    # Partial residual correlation with coordinate j removed.
                    rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
                    if X_sq_norms[j] > 1e-10:
                        st = np.sign(rho_j) * np.maximum(np.abs(rho_j) - thresh, 0)
                        coef[j] = st / (X_sq_norms[j] + ridge)
                    else:
                        # Constant (after centring) column: no signal.
                        coef[j] = 0.0

                self.n_iter_ = iteration + 1
                if np.sum(np.abs(coef - coef_old)) < self.tol:
                    break

    # Compute intercept and store results.
    self.coef_ = coef
    if fit_intercept:
        self.intercept_ = float(y_mean - X_mean @ self.coef_)
        self._params = np.concatenate([[self.intercept_], self.coef_])
    else:
        self.intercept_ = 0.0
        self._params = self.coef_.copy()

    self._df_resid = n_samples - (n_features + (1 if fit_intercept else 0))
|
|
823
|
+
|
|
824
|
+
def _fit_gpu(self, X, y, sample_weight=None):
    """Fit on the GPU using CuPy arrays.

    Thin wrapper that forwards every argument to ``_fit_gpu_backend`` with
    ``backend_name="cupy"``; all solver logic lives there.
    """
    self._fit_gpu_backend(X, y, sample_weight=sample_weight, backend_name="cupy")
|
|
827
|
+
|
|
828
|
+
def _fit_torch(self, X, y, sample_weight=None):
    """Fit on the GPU using Torch tensors.

    Thin wrapper that forwards every argument to ``_fit_gpu_backend`` with
    ``backend_name="torch"``; all solver logic lives there.
    """
    self._fit_gpu_backend(X, y, sample_weight=sample_weight, backend_name="torch")
|
|
831
|
+
|
|
832
|
+
# ------------------------------------------------------------------
|
|
833
|
+
# Unified GPU backend (shared by _fit_gpu and _fit_torch)
|
|
834
|
+
# ------------------------------------------------------------------
|
|
835
|
+
|
|
836
|
+
@staticmethod
def _soft_threshold_gpu(w, thresh, xp):
    """Element-wise soft-thresholding: ``sign(w) * max(|w| - thresh, 0)``.

    ``xp`` is the array module. When ``xp.__name__ == "torch"`` the clamp is
    done with ``torch.relu``; any other module uses ``xp.maximum``.
    """
    if xp.__name__ != "torch":
        return xp.sign(w) * xp.maximum(xp.abs(w) - thresh, 0.0)
    import torch
    shrunk = torch.relu(torch.abs(w) - thresh)
    return torch.sign(w) * shrunk
|
|
843
|
+
|
|
844
|
+
def _fit_gpu_backend(self, X, y, sample_weight=None, backend_name="cupy"):
    """Unified GPU fit method for both CuPy and Torch backends.

    Handles the exact (closed-form L2) solver and an inline FISTA fast path
    for squared-error L1 / elastic-net fits, with XtX precomputation and
    fused element-wise kernels. The IRLS / Newton / L-BFGS solvers and every
    other loss/penalty combination are delegated to ``_fit_irls_backend`` /
    ``_fit_loss_backend``.

    Parameters
    ----------
    X : array-like
        Design matrix of shape (n_samples, n_features).
    y : array-like
        Target vector.
    sample_weight : array-like or None
        Per-sample weights. With an intercept, X and y are centred on their
        weighted means before the sqrt(w) scaling, so the intercept is that
        of the weighted least-squares problem.
    backend_name : {"cupy", "torch"}
        Array backend to run on.

    Raises
    ------
    ValueError
        If the selected solver is not supported on GPU, or
        ``solver='exact'`` is requested with a non-L2 penalty.
    """
    from statgpu.backends._utils import _get_xp, xp_asarray, xp_zeros, xp_copy, xp_ones
    from statgpu.backends import _to_numpy
    from statgpu.backends._array_ops import _abs_sum_dev

    xp = _get_xp(backend_name)
    is_torch = (backend_name == "torch")
    backend_tag = "torch" if is_torch else "cupy"

    solver_name = self._selected_solver or self._select_solver(
        self._loss, backend_name=backend_name
    )
    _backend_label = "Torch" if is_torch else "CuPy"
    if solver_name not in ("fista", "fista_bb", "admm", "auto", "exact", "irls", "newton", "lbfgs"):
        raise ValueError(
            f"{_backend_label} backend supports solver='fista', 'fista_bb', 'admm', "
            f"'exact', 'irls', 'newton', and 'lbfgs', got '{solver_name}'."
        )

    n_samples, n_features = X.shape
    self._nobs = n_samples
    fit_intercept = self._effective_intercept

    def _prepare_gram():
        """Move X/y to the device as float64, weight and centre them, form X'X and X'y.

        Returns (X_dev, y_dev, X_mean, y_mean, XtX, Xty). X_dev / y_dev are
        the sqrt(sample_weight)-scaled device arrays; X_mean is None when no
        intercept is fitted.
        """
        X_dev = xp_asarray(X, dtype=np.float64, xp=xp, ref_arr=X)
        y_dev = xp_asarray(y, dtype=np.float64, xp=xp, ref_arr=y)
        if is_torch:
            import torch
            if X_dev.dtype != torch.float64:
                X_dev = X_dev.to(torch.float64)
            # y must share X's dtype or the torch matmuls below fail.
            if y_dev.dtype != torch.float64:
                y_dev = y_dev.to(torch.float64)

        sw = None
        if sample_weight is not None:
            sw = xp_asarray(sample_weight, dtype=X_dev.dtype, xp=xp, ref_arr=X_dev)

        if fit_intercept:
            if sw is None:
                X_mean = xp.mean(X_dev, axis=0)
                y_mean = xp.mean(y_dev)
            else:
                # WLS intercept is ybar_w - xbar_w @ coef: centre on the
                # weighted means *before* applying the sqrt(w) scaling.
                sw_total = xp.sum(sw)
                X_mean = (sw @ X_dev) / sw_total
                y_mean = (sw @ y_dev) / sw_total
            X_centered = X_dev - X_mean
            y_centered = y_dev - y_mean
        else:
            X_mean = None
            y_mean = xp_zeros((), X_dev.dtype, xp, ref_arr=X_dev) if is_torch else xp.array(0.0, dtype=X_dev.dtype)
            X_centered = X_dev
            y_centered = y_dev

        if sw is not None:
            sqrt_sw = xp.sqrt(sw)
            X_dev = X_dev * sqrt_sw[:, None]
            y_dev = y_dev * sqrt_sw
            if fit_intercept:
                X_centered = X_centered * sqrt_sw[:, None]
                y_centered = y_centered * sqrt_sw
            else:
                X_centered, y_centered = X_dev, y_dev

        cv_cache = getattr(self, "_cv_cache", None)
        if cv_cache is not None and "XtX" in cv_cache:
            XtX, Xty = cv_cache["XtX"], cv_cache["Xty"]
        else:
            XtX = X_centered.T @ X_centered
            Xty = X_centered.T @ y_centered
        return X_dev, y_dev, X_mean, y_mean, XtX, Xty

    def _store_coef(coef_dev, X_mean, y_mean):
        """Copy coefficients to host and set coef_, intercept_, _params, _df_resid."""
        coef_np = _to_numpy(coef_dev)
        self.coef_ = coef_np
        if fit_intercept:
            self.intercept_ = float(_to_numpy(y_mean) - _to_numpy(X_mean) @ coef_np)
            self._params = np.concatenate([[self.intercept_], self.coef_])
        else:
            self.intercept_ = 0.0
            self._params = coef_np.copy()
        self._df_resid = n_samples - (n_features + (1 if fit_intercept else 0))

    def _release_device_memory():
        """Free cached device memory for the active backend."""
        if is_torch:
            self._cleanup_torch_memory()
        else:
            self._cleanup_cuda_memory()

    # --- Exact solver (closed-form Ridge) ---
    if solver_name == "exact":
        if self._penalty.name != "l2":
            raise ValueError("solver='exact' is only supported for L2/Ridge penalty.")
        X_dev, y_dev, X_mean, y_mean, XtX, Xty = _prepare_gram()

        # Dispatch to backend-specific exact solver.
        coef = getattr(self, f"_solve_exact_{backend_tag}")(XtX, Xty, n_samples)
        self.n_iter_ = 1
        if self.compute_inference:
            infer_fn = getattr(self, f"_precompute_exact_l2_inference_{backend_tag}")
            # NOTE(review): with sample weights, X_dev / y_dev are
            # sqrt(w)-scaled while X_mean is the weighted mean of the raw X;
            # confirm the inference routine expects this combination.
            if fit_intercept:
                intercept_dev = (y_mean.reshape(1) - X_mean.reshape(1, -1) @ coef.reshape(-1, 1)).reshape(-1)
                coef_full_dev = xp.concatenate([intercept_dev, coef.reshape(-1)])
                infer_fn(X_dev, y_dev, XtX, X_mean, coef_full_dev.reshape(-1), n_samples)
            else:
                infer_fn(X_dev, y_dev, XtX, None, coef.reshape(-1), n_samples)
        _store_coef(coef, X_mean, y_mean)
        _release_device_memory()
        return

    # Route IRLS/newton/lbfgs through their dedicated backends.
    if solver_name in ("irls", "newton", "lbfgs"):
        if solver_name == "irls":
            self._fit_irls_backend(X, y, sample_weight, backend_name)
        else:
            self._fit_loss_backend(X, y, sample_weight, solver_name, backend_name)
        return

    # Route non-L1 and non-squared-error through the generic loss backend.
    if (
        self.loss != "squared_error"
        or solver_name == "admm"
        or self._penalty.name not in ("l1", "elasticnet", "en")
    ):
        self._fit_loss_backend(X, y, sample_weight, solver_name, backend_name)
        return

    # --- Inline FISTA fast-path for L1 / elastic net + squared_error ---
    X_dev, y_dev, X_mean, y_mean, XtX, Xty = _prepare_gram()

    # Lipschitz constant: L = lambda_max(XtX) / n
    if self.lipschitz_L is not None:
        L = float(self.lipschitz_L)
    elif n_features < 1000:
        L = float(xp.linalg.eigvalsh(XtX)[-1]) / n_samples
    else:
        # Power iteration for large p.
        v = xp_ones(n_features, X_dev.dtype, xp, ref_arr=X_dev)
        v = v / xp.linalg.norm(v)
        for _ in range(50):
            v_new = XtX @ v
            v_norm = xp.linalg.norm(v_new)
            if v_norm < 1e-15:
                break
            v = v_new / v_norm
        L = float(_to_numpy(v @ (XtX @ v))) / n_samples

    init_coef = getattr(self, "_init_coef", None)

    def _starting_coef():
        """Warm start from ``_init_coef`` when set, otherwise zeros."""
        if init_coef is not None:
            return xp_asarray(init_coef, dtype=X_dev.dtype, xp=xp, ref_arr=X_dev)
        return xp_zeros(n_features, X_dev.dtype, xp, ref_arr=X_dev)

    # Defined even when max_iter == 0 and the loops below never run.
    self.n_iter_ = 0

    if L <= 0:
        coef = xp_zeros(n_features, X_dev.dtype, xp, ref_arr=X_dev)
    elif solver_name in ("fista_bb", "fista"):
        step = 1.0 / L
        step_over_n = step / n_samples
        step_over_n_Xty = step_over_n * Xty
        if self._penalty.name in ("elasticnet", "en"):
            thresh = self.alpha * self._penalty.l1_ratio * step
            l2_scale = 1.0 + self.alpha * (1.0 - self._penalty.l1_ratio) * step
        else:
            thresh = self.alpha * step
            l2_scale = 1.0
        _use_l2 = abs(l2_scale - 1.0) > 1e-12
        _st_fn = self._soft_threshold_gpu

        # Build fused element-wise kernel (backend-specific JIT).
        fused_kernel = None
        if is_torch:
            import torch

            if _use_l2:
                def _fista_elementwise(_y_k, _xtx_y, _step_over_n_Xty, _step_over_n,
                                       _thresh, _l2_scale, _coef_old, _beta):
                    w = _y_k - _step_over_n * _xtx_y + _step_over_n_Xty
                    c = _st_fn(w, _thresh, xp) / _l2_scale
                    y_next = c + _beta * (c - _coef_old)
                    return c, y_next
            else:
                def _fista_elementwise(_y_k, _xtx_y, _step_over_n_Xty, _step_over_n,
                                       _thresh, _coef_old, _beta):
                    w = _y_k - _step_over_n * _xtx_y + _step_over_n_Xty
                    c = _st_fn(w, _thresh, xp)
                    y_next = c + _beta * (c - _coef_old)
                    return c, y_next
            try:
                fused_kernel = torch.compile(_fista_elementwise, mode="reduce-overhead")
            except Exception:
                fused_kernel = None
        else:
            import cupy as cp

            try:
                _dummy = cp.zeros(1, dtype=X_dev.dtype)
                if _use_l2:
                    @cp.fuse()
                    def _fista_elementwise_l2(_y_k, _xtx_y, _step_over_n_Xty, _step_over_n,
                                              _thresh, _l2_scale, _coef_old, _beta):
                        w = _y_k - _step_over_n * _xtx_y + _step_over_n_Xty
                        c = (cp.sign(w) * cp.maximum(cp.abs(w) - _thresh, 0.0) / _l2_scale)
                        y_next = c + _beta * (c - _coef_old)
                        return c, y_next
                    # Warm-up call compiles the kernel so failures surface here.
                    _fista_elementwise_l2(_dummy, _dummy, _dummy, 0.0, 0.0, 1.0, _dummy, 0.0)
                    fused_kernel = _fista_elementwise_l2
                else:
                    @cp.fuse()
                    def _fista_elementwise_l1(_y_k, _xtx_y, _step_over_n_Xty, _step_over_n,
                                              _thresh, _coef_old, _beta):
                        w = _y_k - _step_over_n * _xtx_y + _step_over_n_Xty
                        c = (cp.sign(w) * cp.maximum(cp.abs(w) - _thresh, 0.0))
                        y_next = c + _beta * (c - _coef_old)
                        return c, y_next
                    _fista_elementwise_l1(_dummy, _dummy, _dummy, 0.0, 0.0, _dummy, 0.0)
                    fused_kernel = _fista_elementwise_l1
            except Exception:
                fused_kernel = None

        def _fista_step(y_cur, xtx_y, c_old, b):
            """One proximal-gradient + momentum step; returns (coef, y_next)."""
            nonlocal fused_kernel
            if fused_kernel is not None:
                try:
                    if _use_l2:
                        return fused_kernel(y_cur, xtx_y, step_over_n_Xty, step_over_n,
                                            thresh, l2_scale, c_old, b)
                    return fused_kernel(y_cur, xtx_y, step_over_n_Xty, step_over_n,
                                        thresh, c_old, b)
                except Exception:
                    # torch.compile compiles on the first real call, so a
                    # missing/broken compiler backend surfaces here rather
                    # than at construction: use the eager step from now on.
                    fused_kernel = None
            w_tilde = y_cur - step_over_n * xtx_y + step_over_n_Xty
            c = _st_fn(w_tilde, thresh, xp)
            if _use_l2:
                c = c / l2_scale
            return c, c + b * (c - c_old)

        coef = _starting_coef()
        y_k = xp_copy(coef)
        t_k = 1.0
        beta = 0.0
        for iteration in range(self.max_iter):
            coef_old = xp_copy(coef)
            coef, y_k = _fista_step(y_k, XtX @ y_k, coef_old, beta)

            # Scheduled momentum restart.
            if iteration > 0 and iteration % 50 == 0:
                t_k = 1.0
            beta, t_k = _nesterov_momentum(t_k)

            self.n_iter_ = iteration + 1
            # Check every 5 iterations to limit device->host syncs.
            if iteration % 5 == 4 and float(_to_numpy(_abs_sum_dev(coef - coef_old))) < self.tol:
                break
    else:
        # Generic proximal FISTA using the penalty's own proximal operator.
        step = 1.0 / L
        coef = _starting_coef()
        y_k = xp_copy(coef)
        t_k = 1.0
        for iteration in range(self.max_iter):
            coef_old = xp_copy(coef)
            grad = (XtX @ y_k - Xty) / n_samples
            w_tilde = y_k - step * grad
            coef = self._penalty.proximal(w_tilde, step, backend=backend_name)

            if iteration > 0 and iteration % 50 == 0:
                t_k = 1.0
            y_k, t_k = _nesterov_update(coef, coef_old, t_k)

            self.n_iter_ = iteration + 1
            if iteration % 5 == 4 and float(_to_numpy(_abs_sum_dev(coef - coef_old))) < self.tol:
                break

    # Transfer to CPU.
    _store_coef(coef, X_mean, y_mean)

    # Debiased inference on GPU (before cleanup).
    if self.compute_inference and "debiased" in str(getattr(self, "inference_method", "")).lower():
        penalty_name = str(getattr(self._penalty, "name", self.penalty)).lower()
        if penalty_name in ("l1", "elasticnet", "en"):
            infer_fn = getattr(self, f'_compute_inference_debiased_{"torch" if is_torch else "gpu"}')
            infer_fn(X_dev, y_dev, coef)

    _release_device_memory()
|
|
1161
|
+
|
|
1162
|
+
def _ridge_alpha_for_exact(self) -> float:
    """L2 strength used by the closed-form Ridge normal equations.

    Prefers the penalty object's ``alpha`` and falls back to the
    estimator's ``alpha`` when the penalty has none. The fallback attribute
    is read unconditionally.
    """
    fallback_alpha = self.alpha
    return float(getattr(self._penalty, "alpha", fallback_alpha))
|
|
1165
|
+
|
|
1166
|
+
def _solve_exact_numpy(self, XtX, Xty, n_samples):
    """Closed-form Ridge solve on NumPy: (X'X + n*alpha*I) w = X'y.

    ``XtX`` is the unnormalized Gram matrix X'X. The ridge term is therefore
    scaled by ``n_samples`` to match the per-sample (loss / n) convention
    used by the iterative paths. If the system is singular
    (``np.linalg.LinAlgError``), the pseudo-inverse is used instead.
    """
    ridge = float(n_samples) * self._ridge_alpha_for_exact()
    system = XtX + ridge * np.eye(XtX.shape[0], dtype=XtX.dtype)
    try:
        solution = np.linalg.solve(system, Xty)
    except np.linalg.LinAlgError:
        solution = np.linalg.pinv(system) @ Xty
    return solution
|
|
1176
|
+
|
|
1177
|
+
def _solve_exact_cupy(self, XtX, Xty, n_samples):
    """Closed-form Ridge solve on CuPy: (X'X + n*alpha*I) w = X'y.

    Tries, in order:

    1. Cholesky factorization followed by two triangular solves (faster
       than a general solve for the positive-definite Ridge system);
    2. ``cp.linalg.solve``;
    3. the pseudo-inverse.

    Each step moves to the next one when it raises ``_LINALG_ERRORS``.
    """
    import cupy as cp
    from cupyx.scipy.linalg import solve_triangular as cp_solve_triangular

    ridge = float(n_samples) * self._ridge_alpha_for_exact()
    system = XtX + ridge * cp.eye(XtX.shape[0], dtype=XtX.dtype)

    # NOTE(review): CuPy's linalg routines may not raise on a
    # non-positive-definite or singular matrix unless
    # ``cupyx.errstate(linalg='raise')`` is active; confirm the fallbacks
    # below are actually reachable.
    try:
        lower = cp.linalg.cholesky(system)
        z = cp_solve_triangular(lower, Xty, lower=True)
        return cp_solve_triangular(lower.T, z, lower=False)
    except _LINALG_ERRORS:
        pass

    try:
        return cp.linalg.solve(system, Xty)
    except _LINALG_ERRORS:
        return cp.linalg.pinv(system) @ Xty
|
|
1195
|
+
|
|
1196
|
+
def _solve_exact_torch(self, XtX, Xty, n_samples):
    """Solve the Ridge normal equations ``(X'X + n*alpha*I) w = X'y`` in PyTorch.

    The identity is created on ``XtX``'s dtype and device. The system is
    solved with ``torch.linalg.solve``, which the original authors measured
    as faster here than Cholesky + triangular solves because of kernel
    launch overhead on small matrices. On a ``RuntimeError`` (e.g. a
    singular system), the pseudo-inverse is used instead.
    """
    import torch

    shift = float(n_samples) * self._ridge_alpha_for_exact()
    identity = torch.eye(XtX.shape[0], dtype=XtX.dtype, device=XtX.device)
    system = XtX + shift * identity
    try:
        return torch.linalg.solve(system, Xty)
    except RuntimeError:
        return torch.linalg.pinv(system) @ Xty
|
|
1210
|
+
|
|
1211
|
+
def _block_cd_group_lasso(self, pen, X_work, y_arr, init):
    """Block coordinate descent for the group_lasso penalty (NumPy path).

    Follows R grpreg's block CD. Each sweep cycles over the groups, forms the
    group's partial-residual correlation, solves the group subproblem with
    the group's Gram block, and applies block soft-thresholding at
    ``self.alpha * sqrt_pg[g]``. When ``self._effective_intercept`` is set,
    the intercept (the LAST column of ``X_work``) is refit as the mean
    residual after every sweep. Iteration stops when the largest absolute
    coefficient change falls below ``self.tol`` or after ``self.max_iter``
    sweeps.

    NOTE(review): the solve-then-rescale update is the exact group minimizer
    only when the group's Gram block ``X_g'X_g / n`` is the identity (as
    after grpreg-style group orthonormalization). Otherwise it is an
    approximation.

    Parameters
    ----------
    pen : penalty object
        Fallback source of ``_group_indices`` / ``_sqrt_pg`` when ``self``
        has no ``_penalty`` attribute.
    X_work : ndarray, shape (n, pp)
        Design matrix, with the intercept column last when
        ``self._effective_intercept``.
    y_arr : ndarray
        Response with ``n`` elements, e.g. shape (n,) or (n, 1). It is
        flattened internally.
    init : array-like of length pp, or None
        Starting coefficients. Zeros are used when None.

    Returns
    -------
    beta : ndarray
        Feature coefficients ``coef[:p]``.
    intercept : float
        Fitted intercept, or 0.0 without an intercept column.
    n_iter : int
        Number of sweeps performed.

    Raises
    ------
    ValueError
        If the penalty has no group structure configured.
    """
    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha

    inner = getattr(self, '_penalty', pen)
    g_indices = getattr(inner, '_group_indices', None)
    sqrt_pg = getattr(inner, '_sqrt_pg', None)
    if g_indices is None or sqrt_pg is None:
        raise ValueError(
            "group_lasso penalty must have groups set. "
            "Pass groups=... in penalty_kwargs."
        )

    # Flatten once. A (n, 1) target would otherwise broadcast against the
    # (n,) fitted values in the intercept update and build an (n, n) temp.
    y_vec = y_arr.reshape(-1)

    XtX = X_work.T @ X_work / n
    Xty = (X_work.T @ y_vec) / n
    XtX_blocks = [XtX[np.ix_(g_idx, g_idx)] for g_idx in g_indices]

    if init is not None:
        coef = np.array(init, dtype=np.float64)  # copy: updated in place
    else:
        coef = np.zeros(pp, dtype=np.float64)

    X_feat = X_work[:, :p]
    iteration = -1  # keeps n_iter defined when max_iter == 0
    for iteration in range(self.max_iter):
        coef_old = coef.copy()

        for g, g_idx in enumerate(g_indices):
            G = XtX_blocks[g]
            # Correlation with the residual that excludes group g's own fit.
            rho_g = Xty[g_idx] - XtX[g_idx, :] @ coef + G @ coef[g_idx]
            try:
                w_g = np.linalg.solve(G, rho_g)
            except np.linalg.LinAlgError:
                w_g = np.zeros(len(g_idx))
            norm_w = np.linalg.norm(w_g)
            thresh_g = alpha * sqrt_pg[g]
            if norm_w > thresh_g:
                coef[g_idx] = w_g * (1.0 - thresh_g / norm_w)
            else:
                coef[g_idx] = 0.0

        if self._effective_intercept:
            coef[pp - 1] = np.mean(y_vec - X_feat @ coef[:p])

        if np.max(np.abs(coef - coef_old)) < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
|
|
1280
|
+
|
|
1281
|
+
def _block_cd_group_lasso_gpu(self, pen, X_work, y_arr, init, backend_name):
    """GPU-native block coordinate descent for the group_lasso penalty.

    This is the same algorithm as ``_block_cd_group_lasso``, but all arrays
    stay on the ``backend_name`` array module. Inputs are cast to float64 to
    avoid NaNs from float32 conditioning issues. Each group's Gram block
    also gets a 1e-10 diagonal ridge before the per-group solve. If a group
    solve raises or yields NaN/inf, that group's candidate is zeroed.

    Parameters
    ----------
    pen : penalty object
        Fallback source of ``_group_indices`` / ``_sqrt_pg`` when ``self``
        has no ``_penalty`` attribute.
    X_work : backend array, shape (n, pp)
        Design matrix, with the intercept column last when
        ``self._effective_intercept``.
    y_arr : backend array
        Response with ``n`` elements. It is flattened internally.
    init : ndarray, backend array, or None
        Starting coefficients of length ``pp``. It is copied and cast to
        float64; zeros are used when None.
    backend_name : str
        Backend used to resolve the array module.

    Returns
    -------
    beta : backend array
        Feature coefficients.
    intercept : float
        Fitted intercept, or 0.0 without an intercept column.
    n_iter : int
        Number of sweeps performed.

    Raises
    ------
    ValueError
        If the penalty has no group structure configured.
    """
    from statgpu.backends._array_ops import (
        _scalar_tensor,
        _xp_asarray,
        _xp_copy,
        _xp_eye,
        _xp_zeros,
    )
    from statgpu.backends._utils import _get_xp, xp_astype

    xp = _get_xp(backend_name)

    # Enforce float64 precision for numerical stability.
    X_work = xp_astype(X_work, xp.float64, xp)
    # Flatten once. A (n, 1) target would otherwise broadcast against the
    # (n,) fitted values in the intercept update and allocate an (n, n) temp.
    y_vec = xp_astype(y_arr, xp.float64, xp).reshape(-1)

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha

    inner = getattr(self, '_penalty', pen)
    g_indices = getattr(inner, '_group_indices', None)
    sqrt_pg_raw = getattr(inner, '_sqrt_pg', None)
    if g_indices is None or sqrt_pg_raw is None:
        raise ValueError(
            "group_lasso penalty must have groups set. "
            "Pass groups=... in penalty_kwargs."
        )
    sqrt_pg = [float(s) for s in sqrt_pg_raw]

    XtX = X_work.T @ X_work / n
    Xty = (X_work.T @ y_vec) / n

    # Pre-compute the group Gram blocks with a tiny diagonal ridge for
    # conditioning.
    ridge = _scalar_tensor(1e-10, X_work)
    XtX_blocks = []
    for g_idx in g_indices:
        block = XtX[g_idx][:, g_idx]
        XtX_blocks.append(block + ridge * _xp_eye(block.shape[0], block.dtype, block))

    if init is None:
        coef = _xp_zeros(pp, X_work.dtype, X_work)
    elif isinstance(init, np.ndarray):
        coef = _xp_asarray(init, X_work.dtype, X_work)
    else:
        # Copy so the in-place updates never touch the caller's array, and
        # match X_work's float64 dtype as promised above.
        coef = xp_astype(_xp_copy(init), xp.float64, xp)

    X_feat = X_work[:, :p]
    iteration = -1  # keeps n_iter defined when max_iter == 0
    for iteration in range(self.max_iter):
        coef_old = _xp_copy(coef)

        for g, g_idx in enumerate(g_indices):
            G = XtX_blocks[g]
            # Correlation with the residual that excludes group g's own fit.
            rho_g = Xty[g_idx] - XtX[g_idx, :] @ coef + G @ coef[g_idx]
            try:
                w_g = xp.linalg.solve(G, rho_g)
                if xp.any(xp.isnan(w_g)) or xp.any(xp.isinf(w_g)):
                    w_g = _xp_zeros(len(g_idx), X_work.dtype, X_work)
            except Exception:
                w_g = _xp_zeros(len(g_idx), X_work.dtype, X_work)
            norm_w = float(xp.linalg.norm(w_g))
            thresh_g = alpha * sqrt_pg[g]
            if norm_w > thresh_g:
                coef[g_idx] = w_g * (1.0 - thresh_g / norm_w)
            else:
                coef[g_idx] = 0.0

        if self._effective_intercept:
            coef[pp - 1] = float(xp.mean(y_vec - X_feat @ coef[:p]))

        max_change = float(xp.max(xp.abs(coef - coef_old)))
        if max_change < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
|
|
1367
|
+
|
|
1368
|
+
def _fit_loss_backend(self, X, y, sample_weight, solver_name, backend_name):
    """Fit GLMLoss + Penalty without changing the selected backend.

    ``X`` and ``y`` are converted to float64 arrays on ``backend_name``. When
    ``self._effective_intercept`` is set, an all-ones column is appended as
    the LAST column and the penalty is wrapped with
    ``self._selective_penalty(p, backend_name)``.

    Routing, using the penalty name (falling back to ``self._penalty`` when
    the wrapper has none):

    * ``adaptive_l1`` / ``adaptive_lasso`` -> ``fista_solver``.
    * ``scad`` / ``mcp`` with squared error -> ``fista_lla_path`` on a
      20-step geometric continuation path (lambda_max computed in NumPy).
    * ``scad`` / ``mcp`` with a GLM loss -> ``fista_lla_path``. When
      ``self._cv_alpha_path`` is set, it is used as the path and the per-alpha
      results are stored in ``self._cv_path_results``.
    * ``group_mcp`` / ``group_scad`` (or ``gmcp`` / ``gscad``) with a GLM
      loss -> ``fista_lla_path`` with an ``AdaptiveGroupLassoPenalty`` inner
      penalty.
    * ``group_lasso`` -> block coordinate descent (GPU-native off NumPy).
    * otherwise ``solver_name``: ``auto``, ``fista``, ``fista_bb``, ``admm``,
      ``newton`` or ``lbfgs``.

    Sets ``coef_``, ``intercept_``, ``_params`` (laid out as
    ``[intercept, coef...]``), ``n_iter_`` and ``_df_resid``, then releases
    cached GPU memory for the cupy/torch backends.

    Raises
    ------
    ValueError
        If ``solver_name`` is unsupported on the generic path, or a group
        penalty has no groups configured.
    """
    from statgpu.solvers import (
        admm_solver,
        fista_bb_solver,
        fista_solver,
        lbfgs_solver,
        newton_solver,
    )
    from statgpu.backends._array_ops import _xp_asarray
    from statgpu.backends._utils import _get_xp

    # Convert to the target backend with float64 precision for numerical
    # stability. A backend-native zeros(1) is the placement reference when
    # X is a host NumPy array.
    _xp = _get_xp(backend_name)
    _ref = X if not isinstance(X, np.ndarray) else _xp.zeros(1, dtype=_xp.float64)
    X_arr = _xp_asarray(X, _xp.float64, _ref)
    y_arr = _xp_asarray(y, _xp.float64, X_arr)
    p = X_arr.shape[1]
    _loss_name = getattr(self._loss, 'name', '')
    _log_link_losses = ("poisson", "gamma", "inverse_gaussian",
                        "negative_binomial", "tweedie")

    # ``_init_coef`` is passed through ``_to_numpy`` (as the LLA branch below
    # already does) so device-resident warm starts are accepted too.
    if self._effective_intercept:
        X_work = self._column_stack(
            [X_arr, self._ones(X_arr.shape[0], backend_name, X_arr)],
            backend_name,
        )
        pen = self._selective_penalty(p, backend_name)
        if self._init_coef is not None:
            init_intercept = float(getattr(self, '_init_intercept', 0.0) or 0.0)
            init = np.append(_to_numpy(self._init_coef), init_intercept)
        else:
            # Warm-start the intercept for GLM losses. This prevents the
            # unpenalized intercept from diverging toward -inf on zero-heavy
            # data.
            _y_mean = float(np.mean(_to_numpy(y_arr)))
            if _loss_name == "logistic":
                _y_mean_clipped = np.clip(_y_mean, 1e-3, 1.0 - 1e-3)
                _int_init = np.log(_y_mean_clipped / (1.0 - _y_mean_clipped))
            elif _loss_name in _log_link_losses:
                # Log link: intercept init = log(y_mean).
                _int_init = np.log(max(_y_mean, 1e-3))
            else:
                _int_init = _y_mean  # identity link (squared_error)
            init = np.zeros(p + 1)
            init[-1] = _int_init
        init = _xp_asarray(init, X_arr.dtype, X_arr)
    else:
        X_work = X_arr
        pen = self._penalty
        init = None
        if self._init_coef is not None:
            init = np.asarray(_to_numpy(self._init_coef), dtype=np.float64)
            init = _xp_asarray(init, X_arr.dtype, X_arr)

    # SelectivePenalty (intercept wrapper) has no name; fall back to the
    # original penalty's name so SCAD/MCP routing works.
    _pen_name = getattr(pen, 'name', '') or getattr(self._penalty, 'name', '')
    _is_glm_loss = _loss_name not in ("squared_error", "")
    _is_scad_mcp = _pen_name in ("scad", "mcp")
    _use_fista = _pen_name in ("adaptive_l1", "adaptive_lasso")
    _use_irls_cd = _is_scad_mcp and not _is_glm_loss
    _use_lla_fista = _is_scad_mcp and _is_glm_loss
    _use_lla_group = (
        _pen_name in ("group_mcp", "group_scad", "gmcp", "gscad") and _is_glm_loss
    )

    def _stack_params(coef_out, intercept_out):
        """Return float64 NumPy params laid out like X_work: [coef..., intercept].

        The result is converted to NumPy right after dispatch anyway, so
        building it on the host avoids a backend round trip. It also avoids
        ``torch.asarray([intercept])`` picking the default float32 dtype and
        rounding the intercept.
        """
        coef_vec = np.asarray(_to_numpy(coef_out), dtype=np.float64)
        if self._effective_intercept:
            return np.concatenate([coef_vec, [float(intercept_out)]])
        return coef_vec

    def _lam_max_backend(X_feat):
        """Backend-native max_j |x_j'(y - mean(y))| / n, columns scaled to norm sqrt(n)."""
        xp = get_backend(backend_name).xp
        n_rows = X_feat.shape[0]
        col_norms = xp.sqrt(xp.sum(X_feat ** 2, axis=0))
        if backend_name == "torch":
            import torch
            col_norms = torch.clamp(col_norms, min=1e-20)
        else:
            col_norms = xp.maximum(col_norms, 1e-20)
        X_s = X_feat * (float(n_rows) ** 0.5 / col_norms)
        y_c = y_arr - xp.mean(y_arr)
        return float(xp.max(xp.abs(X_s.T @ y_c / n_rows)))

    def _default_continuation(lam_max, target_alpha, n_cont=20):
        """Geometric alpha path from above lambda_max down to target_alpha.

        Returns ``(alpha_path, max_lla_per_step, mi_path)``. Only the final
        step gets the full ``self.max_iter`` budget.
        """
        alpha_path = np.geomspace(
            max(lam_max, target_alpha * 1.1), target_alpha, n_cont,
        )
        max_lla_per_step = max(6, getattr(self, '_max_lla_iters', 50) // n_cont)
        saved_mi = self.max_iter
        mi_path = [saved_mi if i == n_cont - 1 else max(100, saved_mi // 10)
                   for i in range(n_cont)]
        return alpha_path, max_lla_per_step, mi_path

    if _use_fista:
        # adaptive_l1 / adaptive_lasso: FISTA with the weighted-L1 prox works
        # on any backend and for both GLM and squared-error losses.
        params, n_iter = fista_solver(
            self._loss, pen, X_work, y_arr,
            max_iter=self.max_iter, tol=self.tol,
            init_coef=init, sample_weight=sample_weight,
        )
    elif _use_irls_cd:
        # squared_error + SCAD/MCP: fused FISTA+LLA on all backends. This
        # gives identical results across CPU/GPU and avoids slow sequential
        # coordinate descent on GPU.
        from statgpu.solvers import fista_lla_path

        X_orig = X_work[:, :p] if self._effective_intercept else X_work
        # Continuation path (lambda_max -> target alpha), computed in NumPy.
        _X_feat = _to_numpy(X_orig)
        _y_feat = _to_numpy(y_arr)
        _n = _X_feat.shape[0]
        _col_norms = np.maximum(np.sqrt(np.sum(_X_feat ** 2, axis=0)), 1e-20)
        _X_s = _X_feat * (np.sqrt(_n) / _col_norms)
        _y_c = _y_feat - np.mean(_y_feat)
        _lam_max = float(np.max(np.abs(_X_s.T @ _y_c / _n)))
        _target_alpha = float(getattr(self._penalty, 'alpha', self.alpha))
        _alpha_path, _max_lla_per_step, _mi_path = _default_continuation(
            _lam_max, _target_alpha,
        )

        coef_np, intercept, n_iter = fista_lla_path(
            self._loss, self._penalty,
            X_orig, y_arr,
            alpha_path=_alpha_path,
            max_lla_per_step=_max_lla_per_step,
            lla_tol=getattr(self, '_lla_tol', 1e-6),
            max_iter=_mi_path,
            tol=self.tol,
            fit_intercept=self._effective_intercept,
            sample_weight=sample_weight,
        )
        params = _stack_params(coef_np, intercept)
    elif _use_lla_fista:
        # GLM + SCAD/MCP: LLA outer loop + FISTA inner solve.
        from statgpu.solvers import fista_lla_path

        X_orig = X_work[:, :p] if self._effective_intercept else X_work
        # lambda_max with backend-native arrays (no CPU-GPU transfer).
        _lam_max = _lam_max_backend(X_orig)
        _cv_alpha_path = getattr(self, '_cv_alpha_path', None)
        _cv_return_path = _cv_alpha_path is not None
        if _cv_return_path:
            # CV mode: follow the caller's alpha grid (descending), prefixed
            # by a warm-up alpha above the largest target when needed.
            _targets = np.asarray(_cv_alpha_path, dtype=float).ravel()
            _targets = _targets[np.isfinite(_targets) & (_targets > 0.0)]
            if _targets.size == 0:
                _targets = np.asarray([float(getattr(self._penalty, 'alpha', self.alpha))])
            _targets = np.sort(_targets)[::-1]
            _alpha_start = max(_lam_max, float(_targets[0]) * 1.1)
            if _alpha_start > float(_targets[0]) * (1.0 + 1e-10):
                _alpha_path = np.concatenate([[_alpha_start], _targets])
            else:
                _alpha_path = _targets
            _n_cont = int(_alpha_path.size)
            _max_lla_per_step = max(
                6, getattr(self, '_max_lla_iters', 50) // max(_n_cont, 1)
            )
            _saved_mi = self.max_iter
            _mi_path = [max(200, _saved_mi // 2)] * max(_n_cont - 1, 0) + [_saved_mi]
        else:
            _alpha_path, _max_lla_per_step, _mi_path = _default_continuation(
                _lam_max, float(getattr(self._penalty, 'alpha', self.alpha)),
            )

        # Warm start: accept either [coef..., intercept] or coef-only vectors.
        _warm_coef = None
        _warm_intercept = None
        _init = getattr(self, '_init_coef', None)
        if _init is not None:
            _init_np = np.asarray(_to_numpy(_init), dtype=np.float64).ravel()
            if self._effective_intercept and _init_np.size == p + 1:
                _warm_coef = _init_np[:p]
                _warm_intercept = float(_init_np[p])
            elif _init_np.size == p:
                _warm_coef = _init_np
                if self._effective_intercept:
                    _warm_intercept = float(
                        getattr(self, '_init_intercept', 0.0) or 0.0
                    )

        _lla_result = fista_lla_path(
            self._loss, self._penalty,
            X_orig, y_arr,
            alpha_path=_alpha_path,
            max_lla_per_step=_max_lla_per_step,
            lla_tol=getattr(self, '_lla_tol', 1e-6),
            max_iter=_mi_path,
            tol=self.tol,
            fit_intercept=self._effective_intercept,
            sample_weight=sample_weight,
            init_coef=_warm_coef,
            init_intercept=_warm_intercept,
            return_path=_cv_return_path,
        )
        if _cv_return_path:
            coef_np, intercept, n_iter, _path_results = _lla_result
            self._cv_path_results = _path_results
        else:
            coef_np, intercept, n_iter = _lla_result
        params = _stack_params(coef_np, intercept)
    elif _use_lla_group:
        # GLM + group_mcp/group_scad: LLA outer loop + FISTA inner solve with
        # AdaptiveGroupLassoPenalty as the inner penalty.
        from statgpu.solvers import fista_lla_path
        from statgpu.penalties._group_lasso import AdaptiveGroupLassoPenalty

        X_orig = X_work[:, :p] if self._effective_intercept else X_work
        _lam_max = _lam_max_backend(X_orig)
        _target_alpha = float(getattr(self._penalty, 'alpha', self.alpha))
        _alpha_path, _max_lla_per_step, _mi_path = _default_continuation(
            _lam_max, _target_alpha,
        )

        _orig_pen = self._penalty  # unwrapped (not the SelectivePenalty)
        _groups = getattr(_orig_pen, '_group_indices', None)
        if _groups is None:
            raise ValueError(
                f"{_pen_name} penalty must have groups set. "
                "Pass groups=... in penalty_kwargs."
            )
        _pen_alpha = float(_orig_pen.alpha)

        # Create the penalty once and reuse it via set_weights(). This avoids
        # repeated _init_groups() and object creation overhead.
        _adaptive_pen = AdaptiveGroupLassoPenalty(
            groups=_groups, alpha=_pen_alpha,
        )

        def _group_lla_factory(weights_np):
            # lla_weights are per-coordinate; each group's weight is the
            # Euclidean norm of its coordinates' weights.
            _gw = np.array([
                float(np.sqrt(np.sum(weights_np[idx] ** 2))) if len(idx) > 0 else 0.0
                for idx in _groups
            ])
            _adaptive_pen.set_weights(_gw)
            return _adaptive_pen

        coef_np, intercept, n_iter = fista_lla_path(
            self._loss, self._penalty,
            X_orig, y_arr,
            alpha_path=_alpha_path,
            max_lla_per_step=_max_lla_per_step,
            lla_tol=getattr(self, '_lla_tol', 1e-6),
            max_iter=_mi_path,
            tol=self.tol,
            fit_intercept=self._effective_intercept,
            sample_weight=sample_weight,
            lla_penalty_factory=_group_lla_factory,
        )
        params = _stack_params(coef_np, intercept)
    elif _pen_name == "group_lasso":
        # Block CD for group_lasso, GPU-native on non-NumPy backends.
        if backend_name != "numpy":
            coef_out, intercept, n_iter = self._block_cd_group_lasso_gpu(
                pen, X_work, y_arr, init, backend_name,
            )
        else:
            coef_out, intercept, n_iter = self._block_cd_group_lasso(
                pen, X_work, y_arr, init,
            )
        params = _stack_params(coef_out, intercept)
    else:
        _solver_kwargs = dict(
            max_iter=self.max_iter, tol=self.tol,
            init_coef=init, sample_weight=sample_weight,
        )
        if solver_name == "auto":
            # For smooth penalties (l2, elasticnet with low l1_ratio),
            # fista_bb with BB step sizes converges much more reliably than
            # standard FISTA with Nesterov momentum + proximal l2. The
            # intercept wrapper may not expose l1_ratio, so fall back to the
            # original penalty, mirroring the _pen_name fallback.
            _is_smooth = _pen_name == "l2"
            if not _is_smooth and _pen_name == "elasticnet":
                _l1_ratio = getattr(pen, 'l1_ratio', None)
                if _l1_ratio is None:
                    _l1_ratio = getattr(self._penalty, 'l1_ratio', 1.0)
                _is_smooth = float(_l1_ratio) < 0.5
            _solver = fista_bb_solver if _is_smooth else fista_solver
            params, n_iter = _solver(
                self._loss, pen, X_work, y_arr, **_solver_kwargs,
            )
        elif solver_name == "admm":
            params, n_iter = admm_solver(
                self._loss, pen, X_work, y_arr,
                rho=1.0, adaptive_rho=True, **_solver_kwargs,
            )
        else:
            _solvers = {
                "fista": fista_solver,
                "fista_bb": fista_bb_solver,
                "newton": newton_solver,
                "lbfgs": lbfgs_solver,
            }
            if solver_name not in _solvers:
                raise ValueError(f"Unsupported solver: {solver_name}")
            params, n_iter = _solvers[solver_name](
                self._loss, pen, X_work, y_arr, **_solver_kwargs,
            )

    # Solver output is [coef..., intercept]; the public _params layout is
    # [intercept, coef...].
    params_np = _to_numpy(params)
    self.n_iter_ = n_iter
    if self._effective_intercept:
        self.coef_ = params_np[:p]
        self.intercept_ = float(params_np[p])
        self._params = np.concatenate([[self.intercept_], self.coef_])
    else:
        self.coef_ = params_np.copy()
        self.intercept_ = 0.0
        self._params = self.coef_.copy()
    self._df_resid = self._nobs - (
        X_arr.shape[1] + (1 if self._effective_intercept else 0)
    )
    if backend_name == "cupy":
        self._cleanup_cuda_memory()
    elif backend_name == "torch":
        self._cleanup_torch_memory()
|
|
1756
|
+
|
|
1757
|
+
def _fit_irls_backend(self, X, y, sample_weight=None, backend_name="numpy"):
    """Fit a smooth L2-penalized GLM via IRLS on the selected backend.

    When ``self._effective_intercept`` is set, the ones column is PREPENDED,
    so IRLS parameters are laid out as ``[intercept, coef...]``. The ridge
    strength handed to IRLS is ``n_samples * self.alpha``, and the intercept
    is excluded from the ridge when an intercept is fitted.

    Starting point:

    1. ``self._init_coef`` (plus ``self._init_intercept``), e.g. CV warm
       starts.
    2. Otherwise, for logistic and log-link losses with an intercept, the
       intercept starts at the logit / log of the mean response (clipped),
       because the default eta=0 can be far from that optimum.
    3. Otherwise, IRLS's own default.

    Sets ``coef_``, ``intercept_``, ``_params``, ``n_iter_`` and
    ``_df_resid``, then releases cached GPU memory for cupy/torch.

    Raises
    ------
    ValueError
        If the configured penalty is not L2.
    """
    from statgpu.glm_core._irls import IRLSSolver

    if str(getattr(self._penalty, "name", self.penalty)).lower() != "l2":
        raise ValueError("solver='irls' only supports L2 penalties.")

    from statgpu.backends._utils import _get_xp, xp_asarray

    _xp = _get_xp(backend_name)
    X_arr = xp_asarray(
        X, dtype=_xp.float64, xp=_xp,
        ref_arr=X if not isinstance(X, np.ndarray) else np.zeros(1),
    )
    y_arr = xp_asarray(y, dtype=_xp.float64, xp=_xp, ref_arr=X_arr)
    n_samples = X_arr.shape[0]
    if self._effective_intercept:
        X_work = self._column_stack(
            [self._ones(X_arr.shape[0], backend_name, X_arr), X_arr],
            backend_name,
        )
    else:
        X_work = X_arr

    def _to_backend(vec_np):
        """Place a float64 NumPy vector on the backend (and device) of X_work."""
        if backend_name == "cupy":
            import cupy as cp
            return cp.asarray(vec_np, dtype=cp.float64)
        if backend_name == "torch":
            import torch
            return torch.as_tensor(vec_np, dtype=torch.float64, device=X_work.device)
        return vec_np

    # Respect CV warm starts first. IRLS uses [intercept, coef...], while the
    # FISTA design stores the intercept as the final column.
    _loss_name = getattr(self._loss, 'name', '')
    init_coef = None
    init_features = getattr(self, '_init_coef', None)
    if init_features is not None:
        # _to_numpy lets device-resident warm starts through; np.asarray
        # alone rejects e.g. CuPy arrays.
        init_coef_np = np.asarray(_to_numpy(init_features), dtype=np.float64).ravel()
        if self._effective_intercept:
            init_intercept = float(getattr(self, '_init_intercept', 0.0) or 0.0)
            init_coef_np = np.concatenate([[init_intercept], init_coef_np])
        init_coef = _to_backend(init_coef_np)

    # Otherwise warm-start the intercept for GLM losses whose default eta=0
    # can be far from the intercept-only optimum.
    _log_link_losses = ("gamma", "poisson", "inverse_gaussian",
                        "negative_binomial", "tweedie")
    if init_coef is None and self._effective_intercept and (
        _loss_name in _log_link_losses or _loss_name == "logistic"
    ):
        _y_mean = float(np.mean(_to_numpy(y_arr)))
        if _loss_name == "logistic":
            _y_mean = float(np.clip(_y_mean, 1e-3, 1.0 - 1e-3))
            _int_init = np.log(_y_mean / (1.0 - _y_mean))
        else:
            _int_init = np.log(max(_y_mean, 1e-3))
        init_coef_np = np.zeros(X_work.shape[1])
        init_coef_np[0] = _int_init  # intercept is column 0 here
        init_coef = _to_backend(init_coef_np)

    solver = IRLSSolver(
        self._family_for_loss(), max_iter=self.max_iter, tol=self.tol
    )
    params, n_iter = solver.fit(
        X_work, y_arr,
        sample_weight=sample_weight,
        ridge_alpha=float(n_samples * self.alpha),
        ridge_penalize_intercept=False if self._effective_intercept else True,
        backend=backend_name,
        init_coef=init_coef,
    )

    params_np = _to_numpy(params)
    self.n_iter_ = n_iter
    if self._effective_intercept:
        self.intercept_ = float(params_np[0])
        self.coef_ = params_np[1:]
        self._params = np.concatenate([[self.intercept_], self.coef_])
    else:
        self.intercept_ = 0.0
        self.coef_ = params_np.copy()
        self._params = self.coef_.copy()
    self._df_resid = self._nobs - (
        X_arr.shape[1] + (1 if self._effective_intercept else 0)
    )
    if backend_name == "cupy":
        self._cleanup_cuda_memory()
    elif backend_name == "torch":
        self._cleanup_torch_memory()
|
|
1856
|
+
|
|
1857
|
+
def _cleanup_cuda_memory(self):
    """Release cached blocks from CuPy's default device and pinned memory pools.

    Does nothing when ``self.gpu_memory_cleanup`` is falsy. Cleanup is
    best-effort: any exception (CuPy missing, no device, ...) is ignored.
    """
    if not self.gpu_memory_cleanup:
        return
    try:
        import cupy as cp
        for get_pool in (cp.get_default_memory_pool,
                         cp.get_default_pinned_memory_pool):
            get_pool().free_all_blocks()
    except Exception:
        pass
|
|
1867
|
+
|
|
1868
|
+
def _cleanup_torch_memory(self):
    """Empty PyTorch's CUDA caching allocator, when CUDA is available.

    Does nothing when ``self.gpu_memory_cleanup`` is falsy. Cleanup is
    best-effort: any exception (torch missing, driver error, ...) is ignored.
    """
    if not self.gpu_memory_cleanup:
        return
    try:
        import torch
    except Exception:
        return
    try:
        cuda = torch.cuda
        if cuda.is_available():
            cuda.empty_cache()
    except Exception:
        pass
|