statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,541 @@
|
|
|
1
|
+
"""Fused LLA+FISTA solver for SCAD/MCP over a continuation path.
|
|
2
|
+
|
|
3
|
+
Runs the entire continuation -> LLA -> FISTA loop in one tight function,
|
|
4
|
+
eliminating per-call overhead (backend detect, preprocess, Lipschitz
|
|
5
|
+
recompute, array allocation) that accumulates over 300+ fista_solver calls.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
__all__ = ["fista_lla_path"]
|
|
9
|
+
|
|
10
|
+
import copy
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from statgpu.backends import _resolve_backend, _to_numpy
|
|
14
|
+
from statgpu.backends._utils import _to_float_scalar, xp_ones
|
|
15
|
+
from statgpu.backends._array_ops import (
|
|
16
|
+
_abs_sum_dev,
|
|
17
|
+
_clip_grad_on_device,
|
|
18
|
+
_copy_arr,
|
|
19
|
+
_norm2_dev,
|
|
20
|
+
_sync_scalars,
|
|
21
|
+
_zeros,
|
|
22
|
+
)
|
|
23
|
+
from statgpu.penalties._categories import NONSMOOTH as _NONSMOOTH_ALL
|
|
24
|
+
from statgpu.penalties._adaptive_l1 import AdaptiveL1Penalty
|
|
25
|
+
from ._constants import (
|
|
26
|
+
_DIVERGE_COEF_NORM_CAP,
|
|
27
|
+
_GRAD_CLIP_COEF_FACTOR,
|
|
28
|
+
_GRAD_CLIP_ABS_FLOOR,
|
|
29
|
+
_GRAD_CLIP_MAX,
|
|
30
|
+
)
|
|
31
|
+
from ._utils import (
|
|
32
|
+
_nesterov_momentum,
|
|
33
|
+
_validate_sample_weight,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# ---------------------------------------------------------------------------
# Fused proximal kernels for squared_error + AdaptiveL1 (SCAD/MCP via LLA)
# ---------------------------------------------------------------------------
# The squared-error fast path pre-computes XtX and Xty to avoid redundant
# matmuls, fuses the element-wise proximal + momentum update into a single
# kernel call, and defers GPU->CPU syncs used by convergence checks.
#
# Lazily-built kernel caches; both start empty and are filled on first use by
# _get_sqerr_proximal_torch() and _get_sqerr_proximal_cupy() respectively.
_SQERR_PROXIMAL_TORCH = _SQERR_PROXIMAL_CUPY = None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _get_sqerr_proximal_torch():
    """Return (and cache) the fused soft-threshold + momentum update for torch.

    The returned callable has the signature
    ``(y_current, grad, step, thresh, coef_old, beta) -> (coef_new, y_k)``
    and computes::

        w        = y_current - step * grad
        coef_new = sign(w) * max(|w| - thresh, 0)
        y_k      = coef_new + beta * (coef_new - coef_old)

    On CUDA devices with compute capability >= 7.0 the update is wrapped in
    ``torch.compile`` (inductor/Triton).  Older GPUs (e.g. P100 = 6.0) and
    CPU-only installs use the eager implementation.  Because ``torch.compile``
    only compiles on the first *call*, failures at that point (missing Triton,
    missing C++ compiler, ...) are caught as well and the kernel permanently
    falls back to eager mode instead of crashing the solver.

    The result is cached in the module-level ``_SQERR_PROXIMAL_TORCH``.
    """
    global _SQERR_PROXIMAL_TORCH
    if _SQERR_PROXIMAL_TORCH is not None:
        return _SQERR_PROXIMAL_TORCH

    import torch

    def _fused_update_eager(y_current, grad, step, thresh, coef_old, beta):
        w = y_current - step * grad
        coef_new = w.sign() * (w.abs() - thresh).clamp(min=0.0)
        y_k = coef_new + beta * (coef_new - coef_old)
        return coef_new, y_k

    fused = _fused_update_eager

    # torch.compile requires CUDA capability >= 7.0 (Triton).
    cap = torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else 0
    if cap >= 7:
        try:
            # Some torch versions/platforms reject torch.compile eagerly
            # (e.g. unsupported Python version) -- handled here.
            compiled = torch.compile(
                _fused_update_eager, mode='reduce-overhead', backend='inductor',
            )
        except (RuntimeError, TypeError):
            compiled = None

        if compiled is not None:
            # Single-slot holder so a failed lazy compilation is remembered
            # and not retried on every subsequent call.
            active = [compiled]

            def _fused_update(y_current, grad, step, thresh, coef_old, beta):
                args = (y_current, grad, step, thresh, coef_old, beta)
                try:
                    return active[0](*args)
                except Exception:
                    # Deliberate best-effort: if even the eager version fails,
                    # the error is genuine and must propagate.
                    if active[0] is _fused_update_eager:
                        raise
                    active[0] = _fused_update_eager
                    return _fused_update_eager(*args)

            fused = _fused_update

    _SQERR_PROXIMAL_TORCH = fused
    return _SQERR_PROXIMAL_TORCH
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_sqerr_proximal_cupy():
    """Return the cached CuPy kernel that fuses soft-thresholding with momentum.

    The kernel is built on first use and stored in the module-level
    ``_SQERR_PROXIMAL_CUPY`` cache.  Element-wise it evaluates::

        w        = y_current - step * grad
        coef_new = sign(w) * (|w| - thresh)   if |w| > thresh else 0
        y_k      = coef_new + beta * (coef_new - coef_old)

    Note that a NaN ``w`` fails both comparisons, so ``coef_new`` becomes 0
    for it.

    Returns
    -------
    cupy.ElementwiseKernel
        Invoked as ``kernel(y_current, grad, step, thresh, coef_old, beta)``;
        returns the pair ``(coef_new, y_k)``.
    """
    global _SQERR_PROXIMAL_CUPY
    if _SQERR_PROXIMAL_CUPY is not None:
        return _SQERR_PROXIMAL_CUPY

    import cupy as cp

    in_params = 'T y_current, T grad, T step, T thresh, T coef_old, T beta'
    out_params = 'T coef_new, T y_k'
    kernel_src = '''
    T w = y_current - step * grad;
    T abs_w = abs(w);
    T sign_w = (w > 0) ? 1 : ((w < 0) ? -1 : 0);
    coef_new = (abs_w > thresh) ? sign_w * (abs_w - thresh) : 0;
    y_k = coef_new + beta * (coef_new - coef_old);
    '''
    _SQERR_PROXIMAL_CUPY = cp.ElementwiseKernel(
        in_params,
        out_params,
        kernel_src,
        'sqerr_proximal_fused',
    )
    return _SQERR_PROXIMAL_CUPY
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# ---------------------------------------------------------------------------
|
|
98
|
+
# Main solver
|
|
99
|
+
# ---------------------------------------------------------------------------
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def fista_lla_path(
    loss,
    scad_penalty,
    X,
    y,
    alpha_path,
    max_lla_per_step=6,
    lla_tol=1e-6,
    max_iter=1000,
    tol=1e-4,
    fit_intercept=True,
    sample_weight=None,
    lla_penalty_factory=None,
    init_coef=None,
    init_intercept=None,
    return_path=False,
):
    """Fused LLA+FISTA solver for SCAD/MCP over a continuation path.

    Runs the entire continuation -> LLA -> FISTA loop in one tight function,
    eliminating per-call overhead (backend detect, preprocess, Lipschitz
    recompute, array allocation) that accumulates over 300+ fista_solver calls.

    For every alpha in ``alpha_path`` the non-convex penalty is linearised
    (local linear approximation) up to ``max_lla_per_step`` times.  Each round
    takes per-coefficient weights from ``penalty.lla_weights(coef)`` and solves
    the resulting weighted-L1 problem with fixed-step FISTA.

    Parameters
    ----------
    loss : GLMLoss
    scad_penalty : SCADPenalty or MCPPenalty
        Penalty object.  A shallow copy with ``.alpha`` set to each path value
        is used, so the object passed in is not mutated.
    X, y : array
        Design matrix and response.  When ``fit_intercept=True`` and the loss
        is quadratic, X and y are centred here (using ``sample_weight``-weighted
        means when weights are given); for other losses a column of ones is
        appended instead.
    alpha_path : array of alpha values (descending, geomspace)
    max_lla_per_step : int
        Maximum LLA re-weighting rounds per continuation alpha.
    lla_tol : float
        LLA stops when the L1 change of coef over one round is below this.
    max_iter : int or list[int]
        FISTA iteration limit. If a list, one value per continuation step.
    tol : float
        FISTA stops when the L1 change of coef over one iteration is below this.
    fit_intercept : bool
    sample_weight : array or None
    lla_penalty_factory : callable or None
        Called with the LLA weights as a NumPy array; must return a weighted-L1
        style penalty (``proximal``; on GPU backends ``_weights`` and
        ``alpha`` are read to build the fused soft-threshold).
    init_coef : array or None
        Warm-start coefficients (without intercept). If provided, they are
        injected only at the final target-alpha continuation step.
    init_intercept : float or None
        Warm-start intercept value.
    return_path : bool, default=False
        When True, also return coefficients/intercepts after each continuation
        alpha. The default keeps the historical 3-tuple return value.

    Returns
    -------
    coef : ndarray of shape (n_features,), float64
    intercept : float
    total_iter : int
        FISTA iterations summed over all continuation and LLA steps.
    path : dict, only when ``return_path=True``
        ``"alpha"``, ``"coef"`` (n_alphas, n_features), ``"intercept"`` and
        ``"n_iter"`` (cumulative ``total_iter`` after each alpha).
    """
    # --- Backend resolution (torch inputs promoted to a common float dtype) ---
    backend = _resolve_backend("auto", X)
    if backend == "torch":
        import torch as xp
        torch = xp
        x_dtype = X.dtype if getattr(X, "is_floating_point", lambda: False)() else torch.float64
        y_dtype = y.dtype if getattr(y, "is_floating_point", lambda: False)() else torch.float64
        common_dtype = torch.promote_types(x_dtype, y_dtype)
        X = X.to(dtype=common_dtype)
        y = torch.as_tensor(y, device=X.device, dtype=common_dtype)
    elif backend == "cupy":
        import cupy as xp
    else:
        xp = np

    # NOTE(review): only X_proc's shape is used; the solver below runs on X / y
    # directly -- confirm loss.preprocess has nothing the solver must apply.
    X_proc, _ = loss.preprocess(X, y)
    _is_quadratic = getattr(loss, '_is_quadratic', False)
    _no_momentum = getattr(loss, '_skip_momentum', False)
    _non_smooth_pen_lla = getattr(scad_penalty, 'name', '') in _NONSMOOTH_ALL
    _momentum_beta_cap = getattr(loss, '_momentum_beta_cap', None)
    # NOTE(review): _momentum_beta_cap only switches conservative momentum on;
    # the cap actually applied is the constant 0.5 (see _next_momentum).
    _conservative_momentum_lla = (
        _momentum_beta_cap is not None
        or (getattr(loss, '_conservative_momentum_with_nonsmooth', False)
            and _non_smooth_pen_lla)
    )

    n_samples, n_features = X_proc.shape
    _validate_sample_weight(sample_weight, n_samples)

    # --- Intercept handling ---
    # For squared_error (identity link): centering X, y is exact.
    # For GLM losses (log/logit link): centering is WRONG -- it changes
    # the objective. Instead, augment X with a ones column so the
    # intercept is part of the coefficient vector.
    _augment_intercept = fit_intercept and not _is_quadratic
    if _augment_intercept:
        ones_col = xp_ones((X.shape[0], 1), dtype=X.dtype, xp=xp, ref_arr=X)
        X_c = xp.concatenate([X, ones_col], axis=1)
        y_c = y
        n_aug = n_features + 1
    elif fit_intercept:
        if sample_weight is None:
            X_mean = xp.mean(X, axis=0)
            y_mean = xp.mean(y)
        else:
            # Weighted least squares: for fixed coef the optimal intercept is
            # ybar_w - xbar_w @ coef, so centring must use *weighted* means.
            if backend == "torch":
                sw = torch.as_tensor(sample_weight, device=X.device, dtype=X.dtype)
            else:
                sw = xp.asarray(sample_weight, dtype=np.result_type(X.dtype, np.float32))
            sw = sw.reshape(-1)
            sw_total = sw.sum()
            X_mean = (X * sw.reshape(-1, 1)).sum(0) / sw_total
            y_mean = (y * sw).sum() / sw_total
        X_c = X - X_mean
        y_c = y - y_mean
        n_aug = n_features
    else:
        X_c = X
        y_c = y
        n_aug = n_features

    # --- Lipschitz constant ---
    # Pass zero coef (global bound) -- not all losses handle coef=None.
    L_raw = loss.lipschitz(X_c, _zeros(n_aug, backend, ref_tensor=X_c), y=y_c)
    if L_raw <= 0:
        L_raw = 1.0

    # Loss-specific Lipschitz safety factor (e.g. NB=2x, gamma=3x).
    _lipschitz_safety = getattr(loss, '_lipschitz_safety', 1.0)

    # Y-scaling for exp-link families (Poisson, Gamma, etc.).
    # At coef=0, mu~1, but near the optimum mu~y. The Hessian scales
    # with mu, so the coef=0 bound underestimates by up to max(y).
    # Cap at 10x -- periodic Lipschitz recomputation corrects any remaining
    # underestimate during the FISTA inner loop.
    _y_lipschitz_scale = 1.0
    if not _is_quadratic and not getattr(loss, '_lipschitz_uses_y', False):
        _y_abs = np.abs(_to_numpy(y_c))
        _y_mean = float(np.mean(_y_abs))
        _y_max = float(np.max(_y_abs))
        _y_lipschitz_scale = min(10.0, max(1.0, np.sqrt(_y_mean * _y_max)))

    def _scaled_lipschitz(L_value):
        """Apply the safety and y-scaling factors to a raw Lipschitz estimate.

        Used for both the initial bound and periodic recomputation so the
        safety factor is never silently dropped.
        """
        if _lipschitz_safety > 1.0:
            L_value = L_value * _lipschitz_safety
        if _y_lipschitz_scale > 1.0:
            L_value = L_value * _y_lipschitz_scale
        return L_value

    L_base = _scaled_lipschitz(L_raw)

    # Gram matrix for the squared-error fast paths; those paths are only taken
    # without sample weights because XtX / Xty are unweighted.
    XtX = X_c.T @ X_c if (_is_quadratic and sample_weight is None) else None

    # --- Small local helpers ---
    def _warm_start_coef():
        """Return ``init_coef`` on the solver backend (intercept appended if augmented)."""
        if init_coef is None:
            return None
        if backend == "torch":
            import torch
            _init = torch.as_tensor(init_coef, device=X_c.device, dtype=X_c.dtype)
            if _augment_intercept and _init.shape[0] == n_features:
                return torch.cat([
                    _init,
                    torch.tensor(
                        [0.0 if init_intercept is None else init_intercept],
                        device=X_c.device,
                        dtype=X_c.dtype,
                    ),
                ])
            return _init.clone()
        if backend == "cupy":
            import cupy as cp
            _init = cp.asarray(init_coef, dtype=X_c.dtype)
            if _augment_intercept and _init.shape[0] == n_features:
                return cp.concatenate([
                    _init,
                    cp.array([0.0 if init_intercept is None else init_intercept], dtype=X_c.dtype),
                ])
            return _init.copy()
        _init = np.asarray(init_coef, dtype=np.float64)
        if _augment_intercept and _init.shape[0] == n_features:
            return np.concatenate([
                _init,
                [0.0 if init_intercept is None else float(init_intercept)],
            ])
        return _init.copy()

    def _split_current_coef(current_coef):
        """Return (feature coefs as float64 NumPy, intercept as float)."""
        coef_all = np.asarray(_to_numpy(current_coef), dtype=np.float64).ravel()
        if _augment_intercept:
            return coef_all[:n_features].copy(), float(coef_all[n_features])
        if fit_intercept:
            X_mean_np = np.asarray(_to_numpy(X_mean), dtype=np.float64).ravel()
            y_mean_np = float(_to_numpy(y_mean))
            return coef_all.copy(), float(y_mean_np - X_mean_np @ coef_all)
        return coef_all.copy(), 0.0

    def _record_path_alpha(alpha_value):
        """Append the current solution to the path (no-op unless return_path)."""
        if path_records is None:
            return
        coef_rec, intercept_rec = _split_current_coef(coef)
        path_records.append({
            "alpha": float(alpha_value),
            "coef": coef_rec,
            "intercept": float(intercept_rec),
            "n_iter": int(total_iter),
        })

    def _penalty_at(alpha_value):
        """Shallow copy of the penalty at ``alpha_value`` (caller's object untouched)."""
        pen = copy.copy(scad_penalty)
        pen.alpha = float(alpha_value)
        return pen

    def _max_iter_at(step_index):
        """FISTA iteration budget for continuation step ``step_index``."""
        return max_iter[step_index] if isinstance(max_iter, (list, tuple)) else max_iter

    def _next_momentum(t_value):
        """Return ``(beta, t_next)`` honouring the loss's momentum settings."""
        if _no_momentum:
            return 0.0, t_value
        if _conservative_momentum_lla:
            return _nesterov_momentum(t_value, beta_cap=0.5)
        return _nesterov_momentum(t_value)

    def _build_inner_penalty(lla_w):
        """Weighted-L1 penalty for one LLA round (factory output if provided)."""
        if lla_penalty_factory is not None:
            # lla_penalty_factory expects numpy; convert only if needed
            lla_w_np = _to_numpy(lla_w) if type(lla_w).__module__ != "numpy" else lla_w
            return lla_penalty_factory(lla_w_np)
        inner_pen._weights = lla_w
        return inner_pen

    def _threshold_vector(pen, ref, step_size):
        """Per-coefficient soft-threshold ``weights * alpha * step`` on ref's device."""
        weights = pen._weights
        if isinstance(weights, np.ndarray):
            # Host weights (e.g. from lla_penalty_factory) must live on the
            # same device as the iterate; torch.asarray alone would give CPU.
            if backend == "torch":
                weights = torch.as_tensor(weights, device=ref.device, dtype=ref.dtype)
            else:
                weights = xp.asarray(weights, dtype=ref.dtype)
        return weights * pen.alpha * step_size

    def _clip_gradient(grad, coef_ref):
        """Rescale ``grad`` if its 2-norm exceeds a coef-dependent bound."""
        _gn_dev = _norm2_dev(grad)
        _gsum = _abs_sum_dev(coef_ref) * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR
        if backend == "torch":
            _gmax_dev = xp.clamp(_gsum, min=_GRAD_CLIP_MAX)
        else:
            _gmax_dev = xp.maximum(_gsum, _GRAD_CLIP_MAX)
        _gn_f, _gmax_f = _sync_scalars(_gn_dev, _gmax_dev, backend=backend)
        if _gn_f > _gmax_f:
            grad = grad * (_gmax_dev / _gn_dev)
        return grad

    # Keep the continuation path deterministic from zero. CV warm-starts are
    # injected only at the target-alpha step, otherwise SCAD/MCP LLA weights can
    # follow a different local trajectory for NB/Tweedie-like losses.
    coef = _zeros(n_aug, backend, ref_tensor=X_c)
    warm_coef = _warm_start_coef()
    last_step = len(alpha_path) - 1

    total_iter = 0
    inner_pen = AdaptiveL1Penalty(alpha=1.0)
    path_records = [] if return_path else None

    # Fused proximal+momentum kernel for GPU-capable backends (None on numpy).
    if backend == "torch":
        fused_kernel = _get_sqerr_proximal_torch()
    elif backend == "cupy":
        fused_kernel = _get_sqerr_proximal_cupy()
    else:
        fused_kernel = None

    if _is_quadratic and backend in ("torch", "cupy") and sample_weight is None:
        # --- squared_error + GPU: fully inlined fused loop on XtX / Xty ---
        Xty = X_c.T @ y_c
        step = 1.0 / L_base
        _conv_interval = 20  # check convergence every N iters (reduced GPU sync)

        for _cont_i, cont_alpha in enumerate(alpha_path):
            _pen_step = _penalty_at(cont_alpha)
            _mi = _max_iter_at(_cont_i)
            if warm_coef is not None and _cont_i == last_step:
                coef = _copy_arr(warm_coef)

            for _lla_i in range(max_lla_per_step):
                # lla_weights() is backend-aware -- stays on device
                lla_w = _pen_step.lla_weights(coef)
                if lla_penalty_factory is None:
                    thresh = lla_w * step
                else:
                    thresh = _threshold_vector(_build_inner_penalty(lla_w), coef, step)

                coef_before_lla = _copy_arr(coef)
                # Reset momentum for the new LLA round
                t_k = 1.0
                y_k = _copy_arr(coef)

                iteration = -1  # guard against _mi=0
                for iteration in range(_mi):
                    coef_old = _copy_arr(coef)
                    grad = (XtX @ y_k - Xty) / n_samples
                    if iteration % 10 == 0:
                        grad = _clip_grad_on_device(grad, coef_old, backend)

                    beta_mom, t_k = _next_momentum(t_k)
                    # Gradient is evaluated at y_k, so y_k is the proximal center.
                    coef, y_k = fused_kernel(y_k, grad, step, thresh, coef_old, beta_mom)

                    if iteration < 20 or iteration % _conv_interval == 0:
                        _cdf = _to_float_scalar(_abs_sum_dev(coef - coef_old))
                        if _cdf < tol:
                            break
                        if not np.isfinite(_cdf):
                            coef = _copy_arr(coef_old)
                            break

                total_iter += iteration + 1

                if _to_float_scalar(_abs_sum_dev(coef - coef_before_lla)) < lla_tol:
                    break
            _record_path_alpha(cont_alpha)
    else:
        # --- General path (any loss, any backend) ---
        _use_xtx = _is_quadratic and backend == "numpy" and sample_weight is None
        if _use_xtx:
            Xty = X_c.T @ y_c
        # Device-side tolerance so GPU convergence checks sync a single bool.
        _tol_dev = xp.asarray(tol) if backend != "numpy" else None

        for _cont_i, cont_alpha in enumerate(alpha_path):
            _pen_step = _penalty_at(cont_alpha)
            _mi = _max_iter_at(_cont_i)
            if warm_coef is not None and _cont_i == last_step:
                coef = _copy_arr(warm_coef)

            for _lla_i in range(max_lla_per_step):
                if _augment_intercept:
                    lla_w_feat = _pen_step.lla_weights(coef[:n_features])
                    # Intercept column is never penalised.
                    lla_w = xp.concatenate([lla_w_feat, _zeros(1, backend, ref_tensor=coef)])
                else:
                    lla_w = _pen_step.lla_weights(coef)
                pen_k = _build_inner_penalty(lla_w)

                coef_before_lla = _copy_arr(coef)

                # --- FISTA inner solve (fixed step 1/L, no backtracking) ---
                y_k = _copy_arr(coef)
                t_k = 1.0
                L = L_base
                step = 1.0 / L
                thresh = (
                    _threshold_vector(pen_k, coef, step)
                    if fused_kernel is not None else None
                )

                iteration = -1  # guard against _mi=0
                for iteration in range(_mi):
                    coef_old = _copy_arr(coef)

                    if _use_xtx:
                        grad = (XtX @ y_k - Xty) / n_samples
                    elif sample_weight is not None:
                        _, grad = loss.fused_value_and_gradient(
                            X_c, y_c, y_k, sample_weight=sample_weight,
                        )
                    else:
                        _, grad = loss.fused_value_and_gradient(X_c, y_c, y_k)

                    # Gradient clipping: every iteration on CPU, every 10th on GPU.
                    if backend == "numpy" or iteration % 10 == 0:
                        grad = _clip_gradient(grad, coef_old)

                    beta_mom, t_k = _next_momentum(t_k)

                    if fused_kernel is not None:
                        # w = y_k - step*grad; coef = soft-threshold(w);
                        # y_k = coef + beta*(coef - coef_old) -- one kernel launch.
                        coef, y_k = fused_kernel(y_k, grad, step, thresh, coef_old, beta_mom)
                    else:
                        w_tilde = y_k - step * grad
                        coef = pen_k.proximal(w_tilde, step, backend=backend)
                        y_k = coef + beta_mom * (coef - coef_old)

                    if backend == "numpy" or iteration < 20 or iteration % 5 == 0:
                        _conv_dev = _abs_sum_dev(coef - coef_old)
                        if backend != "numpy":
                            if bool(_to_numpy(_conv_dev < _tol_dev)):
                                break
                        elif float(_to_numpy(_conv_dev)) < tol:
                            break

                    # Periodic Lipschitz recomputation -- corrects stale L as
                    # coef moves away from zero (same scaling as L_base).
                    if not _is_quadratic and iteration > 0 and iteration % 20 == 0:
                        L_new = _scaled_lipschitz(loss.lipschitz(X_c, coef, y=y_c))
                        if L_new > L * 1.5 or L_new < L / 1.5:
                            L = max(L_new, L_base * 0.1)
                            step = 1.0 / L
                            if fused_kernel is not None:
                                thresh = _threshold_vector(pen_k, coef, step)

                # Count the converging iteration too (matches the fused path).
                total_iter += iteration + 1
                # --- end FISTA ---

                delta = float(_to_numpy(_abs_sum_dev(coef - coef_before_lla)))
                if delta < lla_tol:
                    break
            _record_path_alpha(cont_alpha)

    # --- Extract coef and intercept ---
    coef_np, intercept = _split_current_coef(coef)

    if return_path:
        if path_records:
            path = {
                "alpha": np.asarray([r["alpha"] for r in path_records], dtype=np.float64),
                "coef": np.vstack([r["coef"] for r in path_records]).astype(np.float64, copy=False),
                "intercept": np.asarray([r["intercept"] for r in path_records], dtype=np.float64),
                "n_iter": np.asarray([r["n_iter"] for r in path_records], dtype=np.int64),
            }
        else:
            path = {
                "alpha": np.empty(0, dtype=np.float64),
                "coef": np.empty((0, n_features), dtype=np.float64),
                "intercept": np.empty(0, dtype=np.float64),
                "n_iter": np.empty(0, dtype=np.int64),
            }
        return coef_np, intercept, total_iter, path
    return coef_np, intercept, total_iter
|