statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1174 @@
|
|
|
1
|
+
"""Legacy solver/CD methods from _penalized.py.
|
|
2
|
+
|
|
3
|
+
These methods were replaced by newer implementations (FISTA, _fit_loss_backend)
|
|
4
|
+
but are retained for reference and backward compatibility.
|
|
5
|
+
|
|
6
|
+
DO NOT import or use in production code.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
# Methods extracted from PenalizedGeneralizedLinearModel:
|
|
14
|
+
def _irls_cd_gpu(self, pen, X_work, y_arr, init, backend_name, _lla_continuation=False):
    """GPU-native IRLS with coordinate descent for GLM + non-smooth penalties.

    Same algorithm as ``_irls_cd`` but keeps all arrays on the backend device
    to avoid CPU-GPU transfer overhead. Supports cupy and torch backends.

    For ``squared_error`` loss a single outer pass runs up to
    ``min(max_iter, 200)`` vectorised CD sweeps. For any other loss each outer
    (IRLS) iteration builds a working response ``z`` and weights ``d``, runs
    one CD sweep, and applies step-halving on the penalised objective. SCAD
    and MCP fits (unless ``_lla_continuation`` is set) are warm-started along
    a geometric path of 100 alphas from an estimated ``lambda_max`` down to
    the target ``alpha``.

    Parameters
    ----------
    self : estimator
        Reads ``_effective_intercept``, ``alpha``, ``max_iter``, ``tol``,
        ``_loss`` and, when present, ``_penalty``.
    pen : penalty object
        Its ``name`` selects the thresholding rule: ``adaptive_l1`` /
        ``adaptive_lasso``, ``scad``, ``mcp``; anything else is plain L1.
    X_work : backend array, shape (n, pp)
        Design matrix. With an intercept, the last column (index ``p``)
        carries the unpenalised intercept.
    y_arr : backend array
        Response (assumed 1-D of length n -- TODO confirm).
    init : ndarray, backend array or None
        Warm start of length ``pp``; zeros when ``None``.
    backend_name : str
        Backend identifier (``"numpy"``, ``"cupy"`` or ``"torch"``).
    _lla_continuation : bool
        When true, skip the internal SCAD/MCP alpha-continuation path.

    Returns
    -------
    beta : backend array, shape (pp,)
        Fitted coefficients (intercept last when present).
    n_iter : int
        Outer iterations run in the final continuation stage.
    """
    from statgpu.backends._array_ops import _xp_copy, _xp_zeros, _xp_asarray
    from statgpu.backends._utils import _get_xp
    xp = _get_xp(backend_name)

    # Backend-agnostic helpers. The original code called ``_to_numpy``,
    # ``_exp`` and ``_clip``, which this legacy module neither defines nor
    # imports, so any GLM or SCAD/MCP fit failed with NameError.
    def _to_host(a):
        """Return ``a`` as a NumPy array, copying device memory to host if needed."""
        if isinstance(a, np.ndarray):
            return a
        if hasattr(a, "detach"):  # torch.Tensor
            return a.detach().cpu().numpy()
        if hasattr(a, "get"):  # cupy.ndarray
            return a.get()
        return np.asarray(a)

    def _dexp(a):
        """Element-wise exp on the active backend."""
        return xp.exp(a)

    def _dclip(a, lo, hi):
        """Element-wise clip on the active backend; either bound may be None."""
        if hasattr(a, "clamp"):  # torch.Tensor
            return a.clamp(lo, hi)
        return xp.clip(a, lo, hi)

    def _copy(a):
        """Independent copy of a backend array (torch uses ``clone``)."""
        return a.clone() if backend_name == "torch" else a.copy()

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp

    # Penalty weights and strength come from the wrapped penalty when present.
    _inner = getattr(self, '_penalty', pen)
    _w_np = np.asarray(getattr(_inner, '_weights', np.ones(p)), dtype=float)
    _w = _xp_asarray(_w_np, X_work.dtype, X_work)
    alpha = float(getattr(_inner, 'alpha', self.alpha))
    pen_name = getattr(pen, 'name', '') or getattr(_inner, 'name', '')

    # SCAD/MCP shape parameters, kept away from values that would divide by
    # zero in the thresholding rules ((a - 1), (a - 2), 1 / gamma).
    a_scad = 0.0
    if pen_name == "scad":
        a_scad = max(float(getattr(_inner, 'a', 3.7)), 1.0 + 1e-6)
        if abs(a_scad - 2.0) < 1e-6:
            a_scad = 2.0 + 1e-6
    gamma_mcp = 0.0
    if pen_name == "mcp":
        gamma_mcp = max(float(getattr(_inner, 'gamma', 3.0)), 1.0 + 1e-6)

    def _nonconvex_penalty_value(coef_slice, _pen_name, _alpha, _a_scad, _gamma_mcp):
        """SCAD/MCP penalty of a NumPy coefficient vector; 0.0 for other penalties."""
        _abs_b = np.abs(coef_slice)
        if _pen_name == "scad":
            return float(np.sum(np.where(
                _abs_b <= _alpha, _alpha * _abs_b,
                np.where(_abs_b <= _a_scad * _alpha,
                         (_a_scad * _alpha * _abs_b - 0.5 * (coef_slice**2 + _alpha**2)) / (_a_scad - 1.0),
                         0.5 * (_a_scad + 1.0) * _alpha**2))))
        if _pen_name == "mcp":
            return float(np.sum(np.where(
                _abs_b <= _gamma_mcp * _alpha,
                _alpha * _abs_b - 0.5 * coef_slice**2 / _gamma_mcp,
                0.5 * _gamma_mcp * _alpha**2)))
        return 0.0

    def _penalized_objective(b):
        """Summed per-sample loss plus the SCAD/MCP penalty at the current alpha."""
        # NOTE(review): the loss is summed (not averaged) and L1-type penalties
        # contribute 0 here; kept as in the original -- confirm against _irls_cd.
        loss_total = float(xp.sum(self._loss.per_sample_value(X_work, y_arr, b)))
        return loss_total + _nonconvex_penalty_value(
            _to_host(b[:p]), pen_name, alpha, a_scad, gamma_mcp)

    if init is None:
        beta = _xp_zeros(pp, X_work.dtype, X_work)
    elif isinstance(init, np.ndarray):
        beta = _xp_asarray(init, X_work.dtype, X_work)
    else:
        beta = _xp_copy(init)

    loss_name = self._loss.name
    _is_glm = loss_name != "squared_error"

    # SCAD/MCP continuation: warm-start along a geometric alpha path from an
    # estimate of lambda_max (max |X_j' r0| / ||X_j||^2) down to alpha.
    _cont_path = [alpha]
    if pen_name in ("scad", "mcp") and not _lla_continuation:
        _y_np = _to_host(y_arr)
        if loss_name == "logistic":
            _p0 = np.clip(np.mean(_y_np), 1e-3, 1 - 1e-3)
            _resid = _y_np - _p0
        elif loss_name == "poisson":
            _mu0 = max(float(np.mean(_y_np)), 1e-3)
            _resid = _y_np - _mu0
        elif loss_name == "gamma":
            _mu0 = max(float(np.mean(_y_np)), 1e-3)
            _resid = (_y_np - _mu0) / _mu0
        else:
            _resid = _y_np - np.mean(_y_np)
        _X_np = _to_host(X_work)
        _xty = np.abs(_X_np[:, :p].T @ _resid)
        _xnorm_sq = np.maximum(np.sum(_X_np[:, :p] ** 2, axis=0), 1e-20)
        _lam_max = float(np.max(_xty / _xnorm_sq))
        # np.geomspace cannot end at 0, so a path is only built for alpha > 0.
        if alpha > 0 and _lam_max > alpha * 1.1:
            _cont_path = np.geomspace(_lam_max, alpha, 100)

    _n_cd_sweeps_base = 1 if _is_glm else min(self.max_iter, 200)
    _n_outer_base = self.max_iter if _is_glm else 1

    # squared_error: unit weights, z = y and the X'DX diagonal never change.
    if not _is_glm:
        d = _xp_zeros((n,), X_work.dtype, X_work) + 1.0  # ones on the right device
        z = y_arr
        XDX_diag = xp.sum(d[:, None] * X_work ** 2, axis=0)

    it = -1
    for _cont_idx, _cont_alpha in enumerate(_cont_path):
        _is_last = _cont_idx == len(_cont_path) - 1
        if len(_cont_path) > 1:
            alpha = float(_cont_alpha)
            # Intermediate path points get a cheaper, truncated solve.
            _n_cd_sweeps = _n_cd_sweeps_base if _is_last else 20
            _n_outer = _n_outer_base if (_is_last or not _is_glm) else min(20, _n_outer_base)
        else:
            _n_cd_sweeps = _n_cd_sweeps_base
            _n_outer = _n_outer_base

        it = -1
        for it in range(_n_outer):
            beta_old = _copy(beta)

            # Penalised objective before the CD update, used for step-halving
            # (GLM only). Best-effort: if it cannot be evaluated, step-halving
            # is skipped for this iteration.
            _obj_before = None
            if _is_glm:
                try:
                    _obj_before = _penalized_objective(beta_old)
                except Exception:
                    _obj_before = None

            if _is_glm:
                # IRLS working response z and weights d for the current eta.
                eta = X_work @ beta
                if loss_name == "logistic":
                    mu = 1.0 / (1.0 + _dexp(-_dclip(eta, -500, 500)))
                    mu = _dclip(mu, 1e-15, 1.0 - 1e-15)
                    d = mu * (1.0 - mu)
                    z = eta + (y_arr - mu) / d
                elif loss_name == "poisson":
                    mu = _dclip(_dexp(_dclip(eta, -500, 500)), 1e-15, None)
                    d = mu
                    z = eta + (y_arr - mu) / d
                elif loss_name == "gamma":
                    mu = _dclip(_dexp(_dclip(eta, -500, 500)), 1e-15, None)
                    d = _xp_zeros((n,), X_work.dtype, X_work) + 1.0
                    z = eta + (y_arr - mu) / mu
                elif loss_name == "inverse_gaussian":
                    mu = _dclip(_dexp(_dclip(eta, -500, 500)), 1e-15, None)
                    d = 1.0 / mu
                    z = eta + (y_arr - mu) / mu
                elif loss_name == "negative_binomial":
                    mu = _dclip(_dexp(_dclip(eta, -500, 500)), 1e-15, None)
                    theta_nb = float(getattr(self._loss, 'alpha', 1.0))
                    d = mu / (1.0 + mu / theta_nb)
                    z = eta + (y_arr - mu) / d
                elif loss_name == "tweedie":
                    mu = _dclip(_dexp(_dclip(eta, -500, 500)), 1e-15, None)
                    tweedie_p = float(getattr(self._loss, 'power', 1.5))
                    d = _dclip(mu ** tweedie_p, 1e-15, None)
                    z = eta + (y_arr - mu) / (d * mu)
                else:
                    # Generic fallback: unit weights and a gradient step.
                    grad = self._loss.gradient(X_work, y_arr, beta)
                    d = _xp_zeros((n,), X_work.dtype, X_work) + 1.0
                    z = eta - grad * n
                XDX_diag = xp.sum(d[:, None] * X_work ** 2, axis=0)

            # Effective sample size for correct normalisation with weights d.
            n_eff = float(xp.sum(d))
            r = z - X_work @ beta

            # Active (non-degenerate) columns and division-safe curvatures.
            _active = XDX_diag >= 1e-20
            _v_all = XDX_diag / n_eff
            _v_safe = xp.where(_active, _v_all, 1.0)
            if pen_name in ("adaptive_l1", "adaptive_lasso"):
                _l1_all = alpha * _w  # per-feature L1 thresholds, shape (p,)

            for _cd in range(_n_cd_sweeps):
                # Vectorised (simultaneous) coordinate update for all features.
                rho_all = X_work.T @ (d * r) + XDX_diag * beta
                w_all = rho_all / (n_eff * _v_safe)  # unpenalised solution
                old_beta = beta

                w_feat = w_all[:p] if self._effective_intercept else w_all
                aw = xp.abs(w_feat)
                if pen_name in ("adaptive_l1", "adaptive_lasso"):
                    new_beta_feat = xp.sign(w_feat) * xp.maximum(aw - _l1_all, 0.0)
                elif pen_name == "scad":
                    new_beta_feat = xp.where(
                        aw > a_scad * alpha, w_feat,
                        xp.where(
                            aw > alpha,
                            xp.sign(w_feat) * ((a_scad - 1.0) * aw - a_scad * alpha) / (a_scad - 2.0),
                            0.0,
                        ),
                    )
                elif pen_name == "mcp":
                    new_beta_feat = xp.where(
                        aw > gamma_mcp * alpha, w_feat,
                        xp.where(
                            aw > alpha,
                            xp.sign(w_feat) * (aw - alpha) / (1.0 - 1.0 / gamma_mcp),
                            0.0,
                        ),
                    )
                else:
                    # lasso / elasticnet handled as pure L1 soft-thresholding.
                    new_beta_feat = xp.sign(w_feat) * xp.maximum(aw - alpha, 0.0)

                # Degenerate columns are zeroed; the intercept is unpenalised.
                if self._effective_intercept:
                    new_beta = xp.zeros_like(beta)
                    new_beta[:p] = new_beta_feat * _active[:p]
                    new_beta[p:] = w_all[p:]
                else:
                    new_beta = new_beta_feat * _active

                # Residual update with a single matvec.
                delta = new_beta - old_beta
                r = r - X_work @ delta
                beta = new_beta

                _max_cd_change = float(xp.max(xp.abs(delta)))
                if not _is_glm and _max_cd_change < self.tol:
                    break

            # Step-halving for GLM: ensure the penalised objective decreases.
            # Mirrors the CPU path (_irls_cd) to prevent IRLS overshooting.
            if _is_glm and _obj_before is not None:
                if _penalized_objective(beta) > _obj_before + 1e-10:
                    beta_new = _copy(beta)
                    for _sh in range(1, 11):
                        beta_sh = beta_old + (0.5 ** _sh) * (beta_new - beta_old)
                        if _penalized_objective(beta_sh) <= _obj_before + 1e-10:
                            beta = beta_sh
                            break
                    else:
                        # All step-halving attempts failed: keep the previous iterate.
                        beta = beta_old

            # IRLS-level convergence check. Previously guarded by
            # ``not _is_glm`` (dead code, since squared_error runs one outer
            # pass), so GLM fits never stopped early. The intercept is included
            # so an all-zero sparse fit still waits for the intercept to settle.
            _delta = float(xp.max(xp.abs(beta - beta_old)))
            if _delta < self.tol:
                break
            # Looser tolerance on intermediate continuation points.
            if _is_glm and not _is_last and _delta < self.tol * 10:
                break

    n_iter = it + 1
    return beta, n_iter
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _block_cd_group_lasso_gpu_batched(self, pen, X_work, y_arr, init, backend_name):
    """Batched GPU block coordinate descent for the group_lasso penalty.

    Groups are bucketed by size. Within one bucket every group is updated at
    once from the same ``coef`` (a Jacobi-style step, one batched linear
    solve). Buckets are processed one after another within each iteration.

    For each group ``g`` the update is::

        rho_g = Xty[g] - XtX[g, :] @ coef + XtX[g, g] @ coef[g]
        w_g   = solve(XtX[g, g], rho_g)
        coef[g] = w_g * max(0, 1 - alpha * sqrt_pg[g] / ||w_g||)

    with ``XtX = X'X / n`` and ``Xty = X'y / n``.

    NOTE(review): thresholding the block solution is exact only when the
    group's Gram block is a multiple of the identity; otherwise this is an
    approximation.

    Parameters
    ----------
    self : estimator
        Reads ``alpha``, ``max_iter``, ``tol``, ``_effective_intercept`` and,
        when present, ``_penalty``.
    pen : penalty object
        Used when ``self._penalty`` is absent. Must expose ``_group_indices``
        and ``_sqrt_pg``.
    X_work : backend array, shape (n, pp)
        Design matrix. With an intercept the last column is treated as the
        intercept column (presumably all ones -- the update below assumes it).
    y_arr : backend array
        Response; flattened internally.
    init : ndarray, backend array or None
        Warm start of length ``pp``; zeros when ``None``.
    backend_name : str
        Backend identifier passed to ``_get_xp``.

    Returns
    -------
    beta : backend array, shape (p,)
    intercept : float
    n_iter : int

    Raises
    ------
    ValueError
        If the penalty has no groups configured.
    """
    from statgpu.backends._array_ops import _xp_copy, _xp_zeros, _xp_asarray
    from statgpu.backends._utils import _get_xp
    xp = _get_xp(backend_name)

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha

    _inner = getattr(self, '_penalty', pen)
    _g_indices = getattr(_inner, '_group_indices', None)
    _sqrt_pg_np = getattr(_inner, '_sqrt_pg', None)
    if _g_indices is None or _sqrt_pg_np is None:
        raise ValueError(
            "group_lasso penalty must have groups set. "
            "Pass groups=... in penalty_kwargs."
        )
    _sqrt_pg = [float(s) for s in _sqrt_pg_np]

    # Flatten once. A column-vector y would otherwise broadcast against the
    # (n,) fitted values in the intercept update and build an (n, n) temporary.
    y_flat = y_arr.flatten()

    # Gram matrix and cross-products, computed once.
    XtX = X_work.T @ X_work / n
    Xty = (X_work.T @ y_flat) / n
    _XtX_blocks = [XtX[g_idx][:, g_idx] for g_idx in _g_indices]

    # Bucket groups by size so each bucket can be solved in one batched call.
    _size_groups = {}  # size -> list of (group_idx, indices)
    for g, g_idx in enumerate(_g_indices):
        _size_groups.setdefault(len(g_idx), []).append((g, g_idx))

    # Loop-invariant per-bucket data: gather indices, stacked Gram blocks and
    # group thresholds (previously rebuilt on every iteration).
    _buckets = []
    for sz, members in _size_groups.items():
        if not members:
            continue
        batch_g = [g for g, _ in members]
        all_indices = [i for _, g_idx in members for i in g_idx]
        idx_arr = _xp_asarray(all_indices, xp.int32 if backend_name == "cupy" else None, X_work)
        XtX_batch = xp.stack([_XtX_blocks[g] for g in batch_g])
        thresh = _xp_asarray([alpha * _sqrt_pg[g] for g in batch_g], X_work.dtype, X_work)
        _buckets.append((sz, members, idx_arr, XtX_batch, thresh))

    if init is None:
        coef = _xp_zeros(pp, X_work.dtype, X_work)
    elif isinstance(init, np.ndarray):
        coef = _xp_asarray(init, X_work.dtype, X_work)
    else:
        coef = _xp_copy(init)

    iteration = -1  # ensures n_iter == 0 when max_iter == 0
    for iteration in range(self.max_iter):
        coef_old = _xp_copy(coef)

        for sz, members, idx_arr, XtX_batch, thresh in _buckets:
            n_batch = len(members)

            # rho_g for every group of this size, via one gathered matvec.
            XtX_coef = XtX[idx_arr, :] @ coef  # shape (n_batch * sz,)
            Xty_all = Xty[idx_arr]
            block_contrib = _xp_zeros(Xty_all.shape, Xty_all.dtype, Xty_all)
            for i, (g, g_idx) in enumerate(members):
                block_contrib[i * sz:(i + 1) * sz] = _XtX_blocks[g] @ coef[g_idx]
            rho_all = Xty_all - XtX_coef + block_contrib

            # One batched solve for all groups; if it raises (e.g. a singular
            # block) every group in the bucket is set to zero this iteration.
            rho_mat = rho_all.reshape(n_batch, sz, 1)
            try:
                w_all = xp.linalg.solve(XtX_batch, rho_mat).reshape(n_batch, sz)
            except Exception:
                w_all = _xp_zeros((n_batch, sz), X_work.dtype, X_work)

            # Group soft-thresholding applied to all groups at once.
            norms = xp.linalg.norm(w_all, axis=1)  # (n_batch,)
            scale = xp.where(norms > thresh, 1.0 - thresh / (norms + 1e-12), 0.0)

            for i, (g, g_idx) in enumerate(members):
                coef[g_idx] = w_all[i] * scale[i]

        if self._effective_intercept:
            # Unpenalised intercept: exact minimiser given the current slopes.
            coef[pp - 1] = float(xp.mean(y_flat - X_work[:, :p] @ coef[:p]))

        if float(xp.max(xp.abs(coef - coef_old))) < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        return coef[:p], float(coef[p]), n_iter
    return coef, 0.0, n_iter
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _cd_elasticnet(self, pen, X_work, y_arr, init):
    """Cyclic coordinate descent for the elastic-net penalty (squared_error loss).

    Matches R glmnet's CD update on the unnormalised Gram matrix::

        beta_j = S(rho_j, alpha*l1_ratio*n) / (X_j'X_j + alpha*(1-l1_ratio)*n)

    where ``S`` is soft-thresholding and
    ``rho_j = X_j'y - X_j'X beta + X_j'X_j beta_j`` is the partial-residual
    correlation of coordinate ``j``.

    Parameters
    ----------
    self : estimator
        Reads ``alpha``, ``max_iter``, ``tol``, ``_effective_intercept`` and,
        as a fallback, ``l1_ratio``.
    pen : object
        ``pen.l1_ratio`` is used when present, else ``self.l1_ratio``, else 0.5.
    X_work : ndarray, shape (n, pp)
        Design matrix. With an intercept the last column is treated as the
        intercept column (presumably all ones -- the update below assumes it)
        and is not penalised.
    y_arr : ndarray, shape (n,) or (n, 1)
        Response; flattened internally.
    init : array-like of length pp or None
        Warm start; zeros when ``None``.

    Returns
    -------
    beta : ndarray, shape (p,)
    intercept : float
    n_iter : int
    """
    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha
    l1_ratio = getattr(pen, 'l1_ratio', getattr(self, 'l1_ratio', 0.5))

    # Flatten once. A column-vector y would otherwise broadcast against the
    # (n,) fitted values in the intercept update and build an (n, n) temporary.
    y_flat = np.ravel(y_arr)

    XtX = X_work.T @ X_work
    Xty = X_work.T @ y_flat
    X_sq_norms = np.diag(XtX)

    if init is not None:
        coef = np.array(init, dtype=np.float64)
    else:
        coef = np.zeros(pp, dtype=np.float64)

    thresh = alpha * l1_ratio * n  # L1 soft-threshold level
    ridge = alpha * (1 - l1_ratio) * n  # L2 term added to each denominator

    iteration = -1  # ensures n_iter == 0 when max_iter == 0
    for iteration in range(self.max_iter):
        coef_old = coef.copy()

        for j in range(p):
            # Partial-residual correlation: remove coordinate j's own contribution.
            rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
            if X_sq_norms[j] > 1e-10:
                st = np.sign(rho_j) * np.maximum(np.abs(rho_j) - thresh, 0)
                coef[j] = st / (X_sq_norms[j] + ridge)
            else:
                # Degenerate (near-zero) column: pin its coefficient to zero.
                coef[j] = 0.0

        if self._effective_intercept:
            # Unpenalised intercept: exact minimiser given the current slopes.
            coef[pp - 1] = np.mean(y_flat - X_work[:, :p] @ coef[:p])

        if np.max(np.abs(coef - coef_old)) < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _cd_l1(self, pen, X_work, y_arr, init):
    """Cyclic coordinate descent for the L1 (lasso) penalty (squared_error loss).

    Matches R glmnet's CD update on the unnormalised Gram matrix::

        beta_j = S(rho_j, alpha*n) / X_j'X_j

    where ``S`` is soft-thresholding and
    ``rho_j = X_j'y - X_j'X beta + X_j'X_j beta_j`` is the partial-residual
    correlation of coordinate ``j``.

    Parameters
    ----------
    self : estimator
        Reads ``alpha``, ``max_iter``, ``tol`` and ``_effective_intercept``.
    pen : object
        Unused; kept for signature parity with the other legacy CD solvers.
    X_work : ndarray, shape (n, pp)
        Design matrix. With an intercept the last column is treated as the
        intercept column (presumably all ones -- the update below assumes it)
        and is not penalised.
    y_arr : ndarray, shape (n,) or (n, 1)
        Response; flattened internally.
    init : array-like of length pp or None
        Warm start; zeros when ``None``.

    Returns
    -------
    beta : ndarray, shape (p,)
    intercept : float
    n_iter : int
    """
    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha

    # Flatten once. A column-vector y would otherwise broadcast against the
    # (n,) fitted values in the intercept update and build an (n, n) temporary.
    y_flat = np.ravel(y_arr)

    XtX = X_work.T @ X_work
    Xty = X_work.T @ y_flat
    X_sq_norms = np.diag(XtX)

    if init is not None:
        coef = np.array(init, dtype=np.float64)
    else:
        coef = np.zeros(pp, dtype=np.float64)

    thresh = alpha * n  # soft-threshold level on the unnormalised scale

    iteration = -1  # ensures n_iter == 0 when max_iter == 0
    for iteration in range(self.max_iter):
        coef_old = coef.copy()

        for j in range(p):
            # Partial-residual correlation: remove coordinate j's own contribution.
            rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
            if X_sq_norms[j] > 1e-10:
                coef[j] = np.sign(rho_j) * np.maximum(np.abs(rho_j) - thresh, 0) / X_sq_norms[j]
            else:
                # Degenerate (near-zero) column: pin its coefficient to zero.
                coef[j] = 0.0

        if self._effective_intercept:
            # Unpenalised intercept: exact minimiser given the current slopes.
            coef[pp - 1] = np.mean(y_flat - X_work[:, :p] @ coef[:p])

        if np.max(np.abs(coef - coef_old)) < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _fit_cpu_loss(self, X, y, sample_weight=None, solver="fista"):
    """Fit on CPU using the loss-aware FISTA solver.

    Three cases:

    * GLM loss (logistic, poisson, gamma, inverse_gaussian,
      negative_binomial, tweedie) with an intercept: ``X`` is augmented with
      a trailing column of ones. A ``SelectivePenalty`` configured over the
      first ``p`` coefficients (intended: no penalty on the intercept) is
      used, so slopes and intercept converge to the joint optimum. These
      losses must not take the centred path below: centring ``y`` is invalid
      for them (e.g. gamma requires a positive response).
    * Any other loss with an intercept: ``X`` and ``y`` are centred with
      means weighted by ``sample_weight`` when given, FISTA runs once, and
      the intercept is recovered as ``y_mean - x_mean @ coef``.
    * No intercept: FISTA on the raw data.

    Parameters
    ----------
    X : array-like, shape (n, p)
    y : array-like, shape (n,)
    sample_weight : array-like of length n or None
        Forwarded to ``fista_solver`` and used for the centring means.
    solver : str
        Accepted for signature compatibility; only FISTA is used.

    Sets ``coef_``, ``intercept_``, ``n_iter_`` and ``_df_resid`` on ``self``.
    """
    from statgpu.solvers import fista_solver

    X_arr = np.asarray(X)
    y_arr = np.asarray(y)

    _glm_losses = (
        "logistic", "poisson", "gamma",
        "inverse_gaussian", "negative_binomial", "tweedie",
    )

    if self.loss in _glm_losses and self._effective_intercept:
        # Augment X with an intercept column (last) and leave it unpenalised.
        X_aug = np.column_stack([X_arr, np.ones(X_arr.shape[0])])
        p = X_arr.shape[1]

        from statgpu.linear_model.penalized._base import SelectivePenalty
        selective = SelectivePenalty()
        selective.configure(self._penalty, p, "numpy")

        full_coef, n_iter = fista_solver(
            self._loss, selective, X_aug, y_arr,
            max_iter=self.max_iter, tol=self.tol,
            init_coef=None, sample_weight=sample_weight,
        )

        self.coef_ = full_coef[:p]
        self.intercept_ = float(full_coef[p])
        self.n_iter_ = n_iter
    elif self._effective_intercept:
        # Centre X and y, fit once. With sample weights the weighted means are
        # required, otherwise both the slopes and the intercept are biased.
        w = None if sample_weight is None else np.ravel(np.asarray(sample_weight, dtype=float))
        x_mean = np.average(X_arr, axis=0, weights=w)
        y_mean = float(np.average(np.ravel(y_arr), weights=w))

        coef, n_iter = fista_solver(
            self._loss, self._penalty, X_arr - x_mean, y_arr - y_mean,
            max_iter=self.max_iter, tol=self.tol,
            init_coef=None, sample_weight=sample_weight,
        )

        self.coef_ = coef
        self.n_iter_ = n_iter
        self.intercept_ = float(y_mean - x_mean @ coef)
    else:
        coef, n_iter = fista_solver(
            self._loss, self._penalty, X_arr, y_arr,
            max_iter=self.max_iter, tol=self.tol,
            init_coef=None, sample_weight=sample_weight,
        )

        self.coef_ = coef
        self.n_iter_ = n_iter
        self.intercept_ = 0.0

    self._df_resid = self._nobs - (X.shape[1] + (1 if self._effective_intercept else 0))
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def _fit_cpu_irls(self, X, y, sample_weight=None):
    """Fit a smooth loss with an L2 penalty on the CPU using IRLS.

    Intended for smooth penalty + smooth loss combinations (e.g. logistic or
    Poisson loss with an L2 penalty). Each IRLS iteration forms the working
    response ``z`` and weights ``W`` and solves the ridge system
    ``(X'WX + n*alpha*I) params = X'Wz``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Design matrix (without an intercept column).
    y : array-like of shape (n_samples,)
        Response.
    sample_weight : array-like or None
        Forwarded unchanged to ``IRLSSolver.fit``.

    Notes
    -----
    Sets ``n_iter_``, ``intercept_``, ``coef_``, ``_params`` and
    ``_df_resid`` on ``self``.
    """
    from statgpu.glm_core._irls import IRLSSolver
    from statgpu.glm_core._family import (
        Binomial, Poisson, Gaussian, Gamma,
        InverseGaussian, NegativeBinomial, Tweedie,
    )

    with_intercept = self._effective_intercept

    design = np.asarray(X)
    target = np.asarray(y)
    n_rows = design.shape[0]

    if with_intercept:
        # The intercept is carried by a leading column of ones (params[0]).
        design = np.column_stack([np.ones(n_rows), design])

    # The objective is loss/n + 0.5*alpha*||w||^2, but IRLS works with the
    # unnormalised X'WX, so the ridge strength is n * alpha. The intercept
    # column is left unpenalised (matches sklearn / the FISTA path).
    ridge_alpha = float(n_rows * self.alpha)

    # Loss name -> GLM family class; any other loss falls back to Gaussian.
    family_by_loss = {
        "logistic": Binomial,
        "poisson": Poisson,
        "gamma": Gamma,
        "inverse_gaussian": InverseGaussian,
        "negative_binomial": NegativeBinomial,
        "tweedie": Tweedie,
    }
    family = family_by_loss.get(self.loss, Gaussian)()

    irls = IRLSSolver(family, max_iter=self.max_iter, tol=self.tol)
    params, n_iter = irls.fit(
        design, target, sample_weight=sample_weight,
        ridge_alpha=ridge_alpha,
        ridge_penalize_intercept=not with_intercept,
        backend="numpy",
    )
    self.n_iter_ = n_iter

    if not with_intercept:
        self.intercept_ = 0.0
        self.coef_ = params.copy()
        self._params = np.asarray(self.coef_).copy()
    else:
        intercept = float(params[0])
        slopes = params[1:]
        self.intercept_ = intercept
        self.coef_ = slopes
        self._params = np.concatenate([[intercept], np.asarray(slopes)])

    n_params = X.shape[1] + (1 if with_intercept else 0)
    self._df_resid = self._nobs - n_params
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
# --- _irls_cd (dead code, moved from _penalized.py) ---
|
|
636
|
+
def _irls_cd(self, pen, X_work, y_arr, init, _lla_continuation=False):
    """IRLS with coordinate descent for GLM + non-smooth penalties.

    The outer IRLS loop computes the working response ``z`` and weights
    ``d``; the inner coordinate-descent (CD) loop solves the weighted
    penalized least-squares subproblem with per-coordinate thresholds
    (the glmnet / ncvreg scheme). For squared error the outer loop runs once
    (``d = 1``, ``z = y``).

    Supported penalty names: ``adaptive_l1`` / ``adaptive_lasso`` (threshold
    ``alpha * _weights[j]``), ``scad`` and ``mcp``; any other name is treated
    as a plain lasso with threshold ``alpha``.

    Parameters
    ----------
    pen : object
        Penalty passed by the caller (possibly a wrapper). Its ``name`` takes
        precedence; ``_weights``, ``alpha`` and the SCAD ``a`` / MCP ``gamma``
        parameters are read from ``self._penalty`` when that attribute
        exists, otherwise from ``pen``.
    X_work : ndarray of shape (n, pp)
        Design matrix. When ``self._effective_intercept`` is true the LAST
        column is the intercept, which is updated without a penalty.
    y_arr : ndarray of shape (n,)
        Response.
    init : array-like of shape (pp,) or None
        Warm start; zeros when None.
    _lla_continuation : bool, default False
        When True, skip the internal lambda_max -> alpha continuation path
        used for SCAD/MCP (an outer caller drives the path instead).

    Returns
    -------
    beta : ndarray of shape (pp,)
        Fitted coefficients (intercept last when present).
    n_iter : int
        Outer (IRLS) iterations used in the final continuation step.
    """
    import numpy as np

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp

    # Access weights from the original penalty (not the SelectivePenalty wrapper)
    _inner = getattr(self, '_penalty', pen)
    _w = np.asarray(getattr(_inner, '_weights', np.ones(p)), dtype=float)
    alpha = float(getattr(_inner, 'alpha', self.alpha))
    pen_name = getattr(pen, 'name', '') or getattr(_inner, 'name', '')

    # SCAD needs a > 2: its middle thresholding region divides by (a - 2).
    # MCP needs gamma > 1: its thresholding divides by (1 - 1/gamma).
    a_scad = 0.0
    if pen_name == "scad":
        a_scad = max(float(getattr(_inner, 'a', 3.7)), 2.0 + 1e-6)
    gamma_mcp = 0.0
    if pen_name == "mcp":
        gamma_mcp = max(float(getattr(_inner, 'gamma', 3.0)), 1.0 + 1e-6)

    if init is not None:
        beta = np.asarray(init, dtype=float).copy()
    else:
        beta = np.zeros(pp)

    loss_name = self._loss.name
    _is_glm = (loss_name != "squared_error")

    def _nonconvex_penalty_value(coef_slice, _pen_name, _alpha, _a_scad, _gamma_mcp):
        """Compute SCAD/MCP penalty value for a coefficient vector (0 otherwise)."""
        _abs_b = np.abs(coef_slice)
        if _pen_name == "scad":
            return float(np.sum(np.where(
                _abs_b <= _alpha, _alpha * _abs_b,
                np.where(_abs_b <= _a_scad * _alpha,
                         (_a_scad * _alpha * _abs_b - 0.5 * (coef_slice**2 + _alpha**2)) / (_a_scad - 1.0),
                         0.5 * (_a_scad + 1.0) * _alpha**2))))
        if _pen_name == "mcp":
            return float(np.sum(np.where(
                _abs_b <= _gamma_mcp * _alpha,
                _alpha * _abs_b - 0.5 * coef_slice**2 / _gamma_mcp,
                0.5 * _gamma_mcp * _alpha**2)))
        return 0.0

    def _penalized_objective(b):
        """Loss on the full design (incl. intercept) + SCAD/MCP penalty at the current alpha."""
        value = float(self._loss.value(X_work, y_arr, b))
        return value + _nonconvex_penalty_value(b[:p], pen_name, alpha, a_scad, gamma_mcp)

    def _soft_threshold(w, t):
        """Soft-threshold scalar ``w`` at level ``t``."""
        if w > t:
            return w - t
        if w < -t:
            return w + t
        return 0.0

    def _working_response(eta, b):
        """Return IRLS weights ``d`` and working response ``z`` at ``eta``.

        For the exp-link families the working response is
        z = eta + (y - mu) * g'(mu) = eta + (y - mu)/mu, and the weight is
        d = 1/(V(mu) * g'(mu)^2) = mu^2 / V(mu).
        """
        if loss_name == "logistic":
            mu = 1.0 / (1.0 + np.exp(-np.clip(eta, -500, 500)))
            mu = np.clip(mu, 1e-15, 1.0 - 1e-15)
            d = mu * (1.0 - mu)
            return d, eta + (y_arr - mu) / d
        if loss_name in ("poisson", "gamma", "inverse_gaussian",
                         "negative_binomial", "tweedie"):
            mu = np.exp(np.clip(eta, -500, 500))
            mu = np.maximum(mu, 1e-15)
            z = eta + (y_arr - mu) / mu
            if loss_name == "poisson":          # V = mu
                d = mu
            elif loss_name == "gamma":          # V = mu^2
                d = np.ones(n)
            elif loss_name == "inverse_gaussian":  # V = mu^3
                d = 1.0 / mu
            elif loss_name == "negative_binomial":
                # NOTE(review): assumes V = mu + mu^2/theta with theta read
                # from self._loss.alpha -- confirm the loss's parameterisation.
                theta_nb = float(getattr(self._loss, 'alpha', 1.0))
                d = mu / (1.0 + mu / theta_nb)
            else:                               # tweedie: V = mu^p
                tweedie_p = float(getattr(self._loss, 'power', 1.5))
                d = np.maximum(mu ** (2.0 - tweedie_p), 1e-15)
            return d, z
        # Unknown GLM loss: gradient-based surrogate with unit weights.
        # NOTE(review): assumes self._loss.gradient returns a per-sample
        # array of length n -- confirm.
        grad = self._loss.gradient(X_work, y_arr, b)
        return np.ones(n), eta - grad * n

    # Continuation path for SCAD/MCP: trace the solution from lambda_max
    # down to the target alpha, matching R ncvreg's pathwise approach.
    # Solving directly at the target alpha can land in a different local
    # minimum (non-convex penalties have several, depending on the start).
    # Skipped when _lla_continuation=True (the outer caller handles the path).
    _cont_path = [alpha]
    if pen_name in ("scad", "mcp") and not _lla_continuation:
        # lambda_max = max_j |X_j^T resid| / ||X_j||^2 at the null model,
        # with resid = y - mean(y) (squared error), y - mu0 (logistic,
        # poisson) or (y - mu0)/mu0 (gamma).
        if loss_name == "logistic":
            _resid = y_arr - np.clip(np.mean(y_arr), 1e-3, 1 - 1e-3)
        elif loss_name == "poisson":
            _resid = y_arr - max(float(np.mean(y_arr)), 1e-3)
        elif loss_name == "gamma":
            _mu0 = max(float(np.mean(y_arr)), 1e-3)
            _resid = (y_arr - _mu0) / _mu0
        else:
            _resid = y_arr - np.mean(y_arr)
        _Xp = X_work[:, :p]
        _xty = np.abs(_Xp.T @ _resid)
        _xnorm_sq = np.maximum(np.sum(_Xp ** 2, axis=0), 1e-20)
        # `initial` keeps this defined for an intercept-only design (p == 0).
        _lam_max = float(np.max(_xty / _xnorm_sq, initial=0.0))
        if _lam_max > alpha * 1.1:
            _cont_path = np.geomspace(_lam_max, alpha, 100)  # ncvreg's default nlambda

    # GLM: one CD sweep per IRLS iteration (as in ncvreg/glmnet); the outer
    # IRLS loop handles convergence. Squared error: d and z are constant, so
    # a single outer pass with a convergence-checked CD loop suffices.
    _n_cd_sweeps_base = 1 if _is_glm else min(self.max_iter, 200)
    _n_outer_base = self.max_iter if _is_glm else 1

    if not _is_glm:
        # Constant across continuation steps -- compute once.
        d = np.ones(n)
        z = y_arr
        XDX_diag = np.sum(d[:, None] * X_work ** 2, axis=0)

    _n_steps = len(_cont_path)
    it = -1
    for _cont_idx, _cont_alpha in enumerate(_cont_path):
        _is_last = (_cont_idx == _n_steps - 1)
        if _n_steps > 1:
            alpha = float(_cont_alpha)
            _n_cd_sweeps = _n_cd_sweeps_base if _is_last else 20
            # GLM: cap IRLS iterations on intermediate steps (ncvreg does ~10
            # per lambda value).
            if _is_glm and not _is_last:
                _n_outer = min(20, _n_outer_base)
            else:
                _n_outer = _n_outer_base
        else:
            _n_cd_sweeps = _n_cd_sweeps_base
            _n_outer = _n_outer_base

        it = -1
        for it in range(_n_outer):
            beta_old = beta.copy()

            if _is_glm:
                d, z = _working_response(X_work @ beta, beta)
                XDX_diag = np.sum(d[:, None] * X_work ** 2, axis=0)

            r = z - X_work @ beta

            # Penalized objective before CD (reference for step-halving).
            if _is_glm:
                _obj_before = _penalized_objective(beta)

            for _cd in range(_n_cd_sweeps):
                _max_cd_change = 0.0
                for j in range(pp):
                    xdx_j = XDX_diag[j]
                    if xdx_j < 1e-20:
                        beta[j] = 0.0
                        continue

                    old_bj = beta[j]
                    rho_j = np.dot(d * X_work[:, j], r) + xdx_j * old_bj
                    # Unpenalized coordinate minimiser of the weighted LS
                    # subproblem (the thresholds below act on this scale).
                    w_j = rho_j / xdx_j

                    if j >= p:
                        new_bj = w_j  # intercept: unpenalized
                    elif pen_name in ("adaptive_l1", "adaptive_lasso"):
                        new_bj = _soft_threshold(w_j, alpha * _w[j])
                    elif pen_name == "scad":
                        aw = abs(w_j)
                        if aw > a_scad * alpha:
                            new_bj = w_j
                        elif aw > 2.0 * alpha:
                            new_bj = np.sign(w_j) * ((a_scad - 1.0) * aw - a_scad * alpha) / (a_scad - 2.0)
                        else:
                            # |w| <= 2*alpha: SCAD coincides with the lasso.
                            new_bj = _soft_threshold(w_j, alpha)
                    elif pen_name == "mcp":
                        aw = abs(w_j)
                        if aw > gamma_mcp * alpha:
                            new_bj = w_j
                        elif aw > alpha:
                            new_bj = np.sign(w_j) * (aw - alpha) / (1.0 - 1.0 / gamma_mcp)
                        else:
                            new_bj = 0.0
                    else:
                        new_bj = _soft_threshold(w_j, alpha)

                    beta[j] = new_bj
                    step = old_bj - beta[j]
                    if step != 0.0:
                        r += X_work[:, j] * step
                        _max_cd_change = max(_max_cd_change, abs(step))

                # Inner CD convergence check (only for squared_error)
                if not _is_glm and _max_cd_change < self.tol:
                    break

            # Step-halving for GLM (as in ncvreg): if the penalized objective
            # went up, move back towards beta_old by factors of 1/2.
            if _is_glm and _penalized_objective(beta) > _obj_before + 1e-10:
                beta_new = beta.copy()
                for _sh in range(1, 11):
                    beta[:] = beta_old + 0.5 ** _sh * (beta_new - beta_old)
                    if _penalized_objective(beta) <= _obj_before + 1e-10:
                        break

            # IRLS-level convergence on the full parameter vector (intercept
            # included so a still-moving intercept is not declared converged).
            _delta = float(np.max(np.abs(beta - beta_old), initial=0.0))
            if _is_glm:
                # Intermediate continuation steps only need a warm start.
                _tol_outer = self.tol if _is_last else self.tol * 10
                if _delta < _tol_outer:
                    break
            elif _delta < self.tol:
                break

    return beta, it + 1
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
# --- _fit_lla (dead code, moved from _penalized.py) ---
|
|
910
|
+
def _fit_lla(self, X, y, sample_weight, backend_name, init_coef=None):
    """Fit a non-convex penalty via Local Linear Approximation (LLA).

    The outer loop reweights the non-convex penalty as a per-coordinate
    weighted L1 penalty: for each continuation alpha, ``self._penalty`` is
    temporarily replaced by ``AdaptiveL1Penalty(alpha=1.0)`` whose
    ``_weights`` are ``penalty.lla_weights(coef)``, and the convex problem is
    solved by ``_fit_torch`` / ``_fit_gpu`` / ``_fit_cpu`` with a warm start.
    For squared error + SCAD/MCP the whole continuation + LLA loop is instead
    delegated to ``statgpu.solvers.fista_lla_path``.

    A continuation path is used for all losses: alpha goes from an estimate
    of lambda_max (smallest penalty giving all-zero coefficients at the null
    model, computed on columns rescaled to mean square 1) down to
    ``self.alpha``, geometrically, in 20 steps for SCAD/MCP and 10
    otherwise. Unknown losses use ``15 * alpha`` as the start. Stepping down
    forces coefficients to cross the SCAD/MCP transition region
    (alpha .. a*alpha), the same strategy R's ncvreg uses; starting directly
    at the target lets small initial coefficients survive.

    Parameters
    ----------
    X, y : array-like
        Training data.
    sample_weight : array-like or None
        Forwarded to the inner solvers; also used for the intercept in the
        fallback at the end.
    backend_name : str
        ``"torch"`` / ``"cupy"`` select ``_fit_torch`` / ``_fit_gpu``; any
        other value uses ``_fit_cpu``.
    init_coef : array-like or None
        Initial coefficients. Ignored for SCAD/MCP, which start from zeros.

    Notes
    -----
    Sets ``coef_``, ``intercept_`` and ``_lla_n_iters_`` (plus ``n_iter_``
    on the fused path). ``cpu_solver``, ``_selected_solver``, ``max_iter``
    and ``_penalty`` are changed temporarily and restored before returning,
    also when an exception is raised; ``_init_coef`` is reset to None.
    """
    # Imported up front: a function-local ``import`` binds the name for the
    # whole function body, so importing ``copy`` further down (as before)
    # made the first ``copy.copy`` below raise UnboundLocalError.
    import copy

    n_features = X.shape[1]

    if init_coef is not None:
        coef_lla = np.asarray(init_coef, dtype=float).copy()
    elif self._penalty.requires_init:
        coef_lla = np.zeros(n_features)
    else:
        coef_lla = self._fit_initial(X, y, backend_name=backend_name)

    # SCAD/MCP start from all-zero coefficients: R's ncvreg starts at
    # lambda_max with zeros and warm-starts down the path, and an
    # L2-penalized GLM init can produce coefficients large enough to
    # overflow the exp-link working response.
    _pen_name_init = str(getattr(self._penalty, 'name', '')).lower()
    _is_scad_mcp = _pen_name_init in ("scad", "mcp")
    _is_glm = (self.loss != "squared_error")
    _is_glm_scad_mcp = _is_glm and _is_scad_mcp
    if _is_scad_mcp:
        coef_lla = np.zeros(n_features)

    from statgpu.penalties._adaptive_l1 import AdaptiveL1Penalty

    saved_cpu_solver = self.cpu_solver
    saved_selected_solver = self._selected_solver
    saved_max_iter = self.max_iter
    _glm_bb_safe = _is_glm and self.loss in ("poisson", "logistic")

    try:
        # Inner solver choice:
        # - squared error on numpy (or any SCAD/MCP): fista_bb, which
        #   precomputes XtX (O(p^2) per iteration); ADMM on CPU recomputes
        #   X@w and X.T@g per CG iteration and is far slower there.
        # - squared error on GPU backends: admm (cuBLAS matmuls are fast).
        # - poisson/logistic: fista_bb, set per continuation step below
        #   (early steps have small coefficients, so BB steps are safe).
        # - other GLMs (e.g. gamma, whose 1/mu gradient scale makes BB step
        #   estimates unreliable): backtracking fista.
        if _is_glm and not _glm_bb_safe:
            self.cpu_solver = "fista"
            self._selected_solver = "fista"
        elif not _is_glm:
            _solver = "fista_bb" if (_is_scad_mcp or backend_name == "numpy") else "admm"
            self.cpu_solver = _solver
            self._selected_solver = _solver

        # lambda_max, matching R ncvreg: max_j |sum(x_s_j * resid)| / n on X
        # rescaled so that ||X_j|| = sqrt(n), i.e. mean(x_j^2) = 1.
        X_np = np.asarray(X, dtype=float)
        y_np = np.asarray(y, dtype=float)
        n_obs = X_np.shape[0]
        col_norms = np.maximum(np.sqrt(np.sum(X_np ** 2, axis=0)), 1e-20)
        X_s = X_np * (np.sqrt(n_obs) / col_norms)
        if self.loss == "logistic":
            p0 = np.clip(np.mean(y_np), 1e-3, 1 - 1e-3)
            lam_max = float(np.max(np.abs(X_s.T @ (y_np - p0) / n_obs)))
        elif self.loss == "poisson":
            mu0 = max(float(np.mean(y_np)), 1e-3)
            lam_max = float(np.max(np.abs(X_s.T @ (y_np - mu0) / n_obs)))
        elif self.loss == "gamma":
            mu0 = max(float(np.mean(y_np)), 1e-3)
            lam_max = float(np.max(np.abs(X_s.T @ ((y_np - mu0) / mu0) / n_obs)))
        elif self.loss == "squared_error":
            y_centered = y_np - np.mean(y_np)
            lam_max = float(np.max(np.abs(X_s.T @ y_centered / n_obs)))
        else:
            lam_max = self.alpha * 15.0  # fallback

        _n_cont = 20 if _is_scad_mcp else 10
        _alpha_start = float(lam_max)
        _alpha_end = float(self.alpha)
        if _alpha_start <= 0.0 or _alpha_end <= 0.0:
            # Degenerate end point: linear path on a 1e-12 floor.
            _lo = max(min(_alpha_start, _alpha_end), 1e-12)
            _hi = max(_alpha_start, _alpha_end, 1e-12)
            if _hi <= _lo:
                _alpha_path = np.full(_n_cont, _hi, dtype=float)
            else:
                _alpha_path = np.linspace(_hi, _lo, _n_cont, dtype=float)
            _alpha_path[-1] = max(_alpha_end, 1e-12)
        else:
            _alpha_path = np.geomspace(_alpha_start, _alpha_end, _n_cont)
        _max_lla_per_step = max(6, self._max_lla_iters // _n_cont)

        if _is_scad_mcp and not _is_glm:
            # squared_error + SCAD/MCP: fused continuation + LLA + FISTA loop
            # in one function, avoiding per-call overhead of hundreds of
            # fista_solver calls.
            from statgpu.solvers import fista_lla_path
            X_cached = self._to_array(X, backend=backend_name)
            y_cached = self._to_array(y, backend=backend_name)

            # Early path steps need fewer iterations than the final one.
            _mi_path = [
                saved_max_iter if _i == _n_cont - 1 else max(100, saved_max_iter // 10)
                for _i in range(_n_cont)
            ]

            coef_np, intercept, n_iter = fista_lla_path(
                self._loss, self._penalty,
                X_cached, y_cached,
                alpha_path=_alpha_path,
                max_lla_per_step=_max_lla_per_step,
                lla_tol=self._lla_tol,
                max_iter=_mi_path,
                tol=self.tol,
                fit_intercept=self._effective_intercept,
                sample_weight=sample_weight,
            )
            coef_lla = coef_np
            self.coef_ = coef_np
            self.intercept_ = intercept
            self.n_iter_ = n_iter
            self._lla_n_iters_ = _n_cont * _max_lla_per_step  # upper bound
        else:
            # Convert (and possibly upload) the data once for all inner fits.
            X_cached = self._to_array(X, backend=backend_name)
            y_cached = self._to_array(y, backend=backend_name)
            # Count LLA iterations of this fit only.
            self._lla_n_iters_ = 0

            for _cont_step, _cont_alpha in enumerate(_alpha_path):
                # Per-step copy carrying the continuation alpha, so the
                # estimator's penalty object itself is never mutated.
                _pen_step = copy.copy(self._penalty)
                _pen_step.alpha = float(_cont_alpha)

                _is_last_cont = (_cont_step == _n_cont - 1)
                if _is_glm_scad_mcp:
                    self.max_iter = 500 if _is_last_cont else 100
                elif _is_last_cont:
                    self.max_iter = saved_max_iter
                else:
                    self.max_iter = max(200, saved_max_iter // 3)
                if self.loss == "gamma":
                    self.max_iter = max(300, self.max_iter // 2)
                if _glm_bb_safe:
                    # NOTE(review): the original design notes said the final
                    # step should switch to backtracking "fista" in case
                    # coefficients grow large enough to blow up the exp link,
                    # but fista_bb is used on every step -- confirm intent.
                    self.cpu_solver = "fista_bb"
                    self._selected_solver = "fista_bb"

                for _lla_local in range(_max_lla_per_step):
                    # lla_weights() returns alpha-scaled derivative weights
                    # (e.g. SCAD: alpha for |coef| <= alpha), and
                    # AdaptiveL1Penalty applies alpha_inner * w_j * |coef_j|,
                    # so alpha_inner = 1 yields sum_j lla_w_j * |coef_j|.
                    # The weights have one entry per feature; intercept
                    # handling is left to the inner fit.
                    lla_w = _pen_step.lla_weights(coef_lla)
                    inner_pen = AdaptiveL1Penalty(alpha=1.0)
                    inner_pen._weights = lla_w

                    # Swap in the LLA penalty; restore the SAME original
                    # object afterwards (even if the inner fit raises).
                    orig_penalty = self._penalty
                    self._penalty = inner_pen
                    try:
                        # Warm start from the previous LLA estimate.
                        self._init_coef = coef_lla.copy()
                        if backend_name == "torch":
                            self._fit_torch(X_cached, y_cached, sample_weight)
                        elif backend_name == "cupy":
                            self._fit_gpu(X_cached, y_cached, sample_weight)
                        else:
                            self._fit_cpu(X_cached, y_cached, sample_weight)
                    finally:
                        self._penalty = orig_penalty
                        self._init_coef = None

                    # LLA convergence: L1 change of the coefficients.
                    coef_new = self.coef_.copy()
                    delta = float(np.sum(np.abs(coef_new - coef_lla)))
                    self._lla_n_iters_ += 1
                    coef_lla = coef_new
                    if delta < self._lla_tol:
                        break

        # Fallback: if no inner fit stored coefficients, derive them (and
        # the intercept from the, optionally weighted, means) from coef_lla.
        if getattr(self, "coef_", None) is None and coef_lla is not None:
            self.coef_ = np.asarray(coef_lla[:X.shape[1]], dtype=float)
            if self._effective_intercept:
                X_np = np.asarray(X, dtype=float)
                y_np = np.asarray(y, dtype=float)
                if sample_weight is not None:
                    sw_np = np.asarray(sample_weight, dtype=float).ravel()
                    sw_sum = max(float(np.sum(sw_np)), 1e-15)
                    X_wmean = np.sum(X_np * sw_np[:, None], axis=0) / sw_sum
                    y_wmean = float(np.sum(y_np * sw_np)) / sw_sum
                    self.intercept_ = float(y_wmean - X_wmean @ self.coef_)
                else:
                    self.intercept_ = float(np.mean(y_np) - np.mean(X_np, axis=0) @ self.coef_)
            else:
                self.intercept_ = 0.0
            self._params = np.concatenate([[self.intercept_], self.coef_])
            self._df_resid = X.shape[0] - (X.shape[1] + (1 if self._effective_intercept else 0))
    finally:
        self.cpu_solver = saved_cpu_solver
        self._selected_solver = saved_selected_solver
        self.max_iter = saved_max_iter
|
|
1174
|
+
|