statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
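"""FISTA (accelerated proximal gradient) solver for penalized GLMs.

Solves

    minimize_w  loss(X, y, w) + penalty(w)

using a backtracking line search on the general path, or a fixed step size
1/L (no backtracking) on the asynchronous GPU path used for non-smooth
penalties.

Supports numpy / cupy / torch backends via auto-detection.
"""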
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
# Public API of this module: only the solver entry point is exported.
__all__: list[str] = ["fista_solver"]
|
|
11
|
+
|
|
12
|
+
import warnings
|
|
13
|
+
import numpy as np
|
|
14
|
+
from statgpu.backends import _resolve_backend, _to_numpy
|
|
15
|
+
from statgpu.backends._utils import _to_float_scalar, _get_xp
|
|
16
|
+
from statgpu.backends._array_ops import (
|
|
17
|
+
_abs_sum_dev,
|
|
18
|
+
_clip_grad_on_device,
|
|
19
|
+
_copy_arr,
|
|
20
|
+
_dot_dev,
|
|
21
|
+
_norm2_dev,
|
|
22
|
+
_sum_sq_dev,
|
|
23
|
+
_sync_scalars,
|
|
24
|
+
_zeros,
|
|
25
|
+
)
|
|
26
|
+
from ._convergence import ConvergenceWarning
|
|
27
|
+
from ._constants import (
|
|
28
|
+
_SLACK_TOLERANCE,
|
|
29
|
+
_DIVERGE_COEF_NORM_CAP,
|
|
30
|
+
_LIPSCHITZ_SAFETY_LOGISTIC_CV,
|
|
31
|
+
_GRAD_CLIP_COEF_FACTOR,
|
|
32
|
+
_GRAD_CLIP_ABS_FLOOR,
|
|
33
|
+
_GRAD_CLIP_MAX,
|
|
34
|
+
)
|
|
35
|
+
from ._utils import (
|
|
36
|
+
_validate_sample_weight,
|
|
37
|
+
_as_backend_vector,
|
|
38
|
+
_call_with_weight,
|
|
39
|
+
_nesterov_update,
|
|
40
|
+
_penalty_name,
|
|
41
|
+
_smooth_penalty_lipschitz,
|
|
42
|
+
_abs_mean_max,
|
|
43
|
+
_tracking_penalty_value,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _identity_proximal(w, step, backend=None):
    """Proximal operator of the zero penalty: return ``w`` unchanged.

    Used when ``fista_solver`` is called with ``penalty=None`` so the
    proximal step degenerates to a plain gradient step.
    """
    return w


def _fallback_coef(best_coef, n_features, backend, ref_tensor):
    """Return a restart point after a non-finite or diverging iterate.

    Returns a copy of the best iterate tracked so far when one exists,
    otherwise a zero vector created with ``_zeros`` on the backend of
    ``ref_tensor``.
    """
    if best_coef is not None:
        return _copy_arr(best_coef)
    return _zeros(n_features, backend, ref_tensor=ref_tensor)


def fista_solver(
    loss: "GLMLoss",
    penalty: "Penalty | None",
    X,
    y,
    max_iter: int = 1000,
    tol: float = 1e-4,
    init_coef=None,
    sample_weight=None,
    lipschitz_L: float | None = None,
    cv_mode: bool = False,
) -> tuple:
    """General FISTA solver for ``loss(X, y, w) + penalty(w)``.

    Supports numpy / cupy / torch backends via auto-detection of X.

    Two update paths are used:

    * Backtracking path (default): each step starts at ``1/L`` and grows
      ``L`` by 1.5x (at most 20 times) until the quadratic upper-bound
      condition holds.
    * Async GPU path (torch/cupy backend, non-smooth penalty, and either
      ``cv_mode`` or a quadratic loss): one proximal step of fixed size
      ``1/L``. Finiteness, best-objective tracking, and Lipschitz refreshes
      run only periodically, to limit device->host syncs.

    Parameters
    ----------
    loss : GLMLoss
        GLM loss function with gradient(), lipschitz(), preprocess(), value().
    penalty : Penalty or None
        Penalty with proximal(). ``None`` is treated as "no penalty"
        (identity proximal operator).
    X : array
        Design matrix (numpy/cupy/torch).
    y : array
        Target (numpy/cupy/torch).
    max_iter : int
        Maximum iterations.
    tol : float
        Convergence tolerance on the L1 norm of the coefficient change
        between consecutive iterations.
    init_coef : array, optional
        Initial coefficient vector.
    sample_weight : array, optional
        Per-sample weights. Non-uniform weights are currently rejected in this
        solver path to avoid silently running an incorrect unweighted update.
    lipschitz_L : float, optional
        Starting Lipschitz constant, used only when positive; otherwise it is
        estimated from the loss. In both cases the smooth-penalty, y-scale
        and safety-factor adjustments are applied on top.
    cv_mode : bool, default=False
        Private CV fast path: keeps the same update rule but checks objective
        and convergence less often on GPU non-smooth GLM paths.

    Returns
    -------
    coef : array
        Fitted coefficients (same backend as X). When the solver tracked a
        best-objective iterate, the better of that iterate and the final one
        is returned.
    n_iter : int
        Number of iterations.

    Warns
    -----
    ConvergenceWarning
        If the tolerance was not met within ``max_iter`` iterations.
    """
    backend = _resolve_backend("auto", X)
    X_proc, y_proc = loss.preprocess(X, y)
    n_samples, n_features = X_proc.shape[0], X_proc.shape[1]
    # Validate the weights before anything consumes them (the weighted
    # Lipschitz estimate below previously ran before this check).
    _validate_sample_weight(sample_weight, n_samples)

    # penalty=None means "no penalty": the proximal step is the identity.
    _proximal = _identity_proximal if penalty is None else penalty.proximal

    _is_quadratic = getattr(loss, '_is_quadratic', False)
    # Momentum control via loss class attributes:
    #   _momentum_beta_cap: if set, cap Nesterov beta at this value
    #   _skip_momentum: if True, disable momentum entirely
    # Conservative momentum (cap beta at 0.5) for exp-link families and
    # for logistic/gamma with non-smooth penalties. Logistic/gamma with
    # smooth penalties (none, l2) benefit from full Nesterov acceleration.
    _momentum_beta_cap = getattr(loss, '_momentum_beta_cap', None)
    _skip_momentum = getattr(loss, '_skip_momentum', False)

    if init_coef is not None:
        coef = _as_backend_vector(init_coef, backend, X)
    else:
        coef = _zeros(n_features, backend, ref_tensor=X)
    y_k = _copy_arr(coef)
    t_k = 1.0

    # Divergence detection / recovery: best (objective, iterate) seen so far.
    _obj_best_fista = float('inf')
    _coef_best_fista = None

    # ---- Initial Lipschitz estimate -------------------------------------
    # _weighted_lip holds the raw max-diagonal of X' diag(w) X / sum(w). It is
    # only set on the weighted path, and is reused by the periodic refresh in
    # the async GPU loop: X and the weights are constant, so recomputing it
    # would only add a device sync.
    _weighted_lip = None
    if lipschitz_L is not None and lipschitz_L > 0:
        L = lipschitz_L
    elif sample_weight is not None:
        _xp_mod = _get_xp(backend)
        # Bring sample_weight onto the same backend/dtype as X_proc.
        _sw = _xp_mod.asarray(_to_numpy(sample_weight), dtype=X_proc.dtype)
        if backend == "torch":
            # Converting a host array defaults to the CPU device; move the
            # weights next to X_proc so the product below does not mix devices.
            _sw = _sw.to(X_proc.device)
        sw_sum = _to_float_scalar(_xp_mod.sum(_sw))
        sw_col = _sw[:, None] if _sw.ndim == 1 else _sw
        XtWX = X_proc.T @ (X_proc * sw_col) / sw_sum
        # NOTE(review): the largest diagonal entry is a cheap estimate. It is
        # a LOWER bound on the largest eigenvalue, not a conservative upper
        # bound, so the fixed-step GPU path relies on the safety factors and
        # divergence recovery below.
        _weighted_lip = _to_float_scalar(_xp_mod.max(_xp_mod.diag(XtWX)))
        L = _weighted_lip if _weighted_lip > 0 else 1.0
    else:
        # Default: evaluate at zero (safe for exp-link warm starts). Losses
        # may request evaluation at the initial coef instead, to avoid
        # degenerate curvature from eta=0 clipping.
        if getattr(loss, '_lipschitz_at_init', False):
            _lip_coef = _copy_arr(coef)
        else:
            _lip_coef = _zeros(n_features, backend, ref_tensor=X)
        L = loss.lipschitz(X_proc, _lip_coef, y=y_proc)
        if L <= 0:
            L = 1.0

    # Add the smooth penalty's Lipschitz contribution (e.g. the l2 penalty
    # gradient alpha*coef has Lipschitz constant alpha). Without it the step
    # 1/L is too large and the iterates oscillate near the optimum.
    _smooth_lip = _smooth_penalty_lipschitz(penalty)
    if _smooth_lip > 0:
        L = L + _smooth_lip

    # For exp-link GLM losses (Poisson, etc.) mu at coef=0 is ~1, but near
    # the optimum mu ~ y. Scale L up by a geometric-mean factor of |y| to
    # avoid oversized first steps that diverge on non-smooth penalties.
    # Losses whose lipschitz() already accounts for y opt out through
    # _lipschitz_uses_y.
    _skip_y_scaling = getattr(loss, '_lipschitz_uses_y', False)
    _y_scale = 1.0  # default; overridden below for families that need it
    if not _is_quadratic and not _skip_y_scaling:
        _y_mean, _y_max = _abs_mean_max(y_proc, backend)
        _y_scale = max(1.0, _y_mean, np.sqrt(_y_mean * _y_max))
        if _y_scale > 1.0:
            L = L * _y_scale

    # Loss-specific Lipschitz safety factors (from loss class attributes).
    _lip_safety = getattr(loss, '_lipschitz_safety', 1.0)
    if _lip_safety > 1.0:
        L = L * _lip_safety
    # Additional safety for CV mode (from loss class attribute).
    _lip_safety_cv = getattr(
        loss, '_lipschitz_safety_cv', _LIPSCHITZ_SAFETY_LOGISTIC_CV if cv_mode else 1.0
    )
    if cv_mode and _lip_safety_cv > 1.0:
        L = L * _lip_safety_cv

    # ---- Path selection --------------------------------------------------
    # Async GPU loop (fixed step, no backtracking, deferred checks) is used
    # only for non-smooth penalties and only when the loss is quadratic
    # (Lipschitz is exact) or in CV mode (safety-scaled L). Smooth penalties
    # (l2, none) always take the backtracking path.
    _pen_name_lower = _penalty_name(penalty)
    _non_smooth = _pen_name_lower not in ("none", "null", "l2", "")
    _gpu_excluded = getattr(loss, '_gpu_loop_excluded', False) and not cv_mode
    _is_gpu = backend in ("torch", "cupy")
    _use_gpu_loop = (
        _is_gpu
        and _non_smooth
        and (cv_mode or _is_quadratic)
        and not _gpu_excluded
    )
    if _use_gpu_loop:
        _conv_interval, _div_interval, _lip_interval = 10, 25, 25
    else:
        _conv_interval, _div_interval, _lip_interval = 3, 5, 5

    # Gram-matrix optimization for squared_error on the async GPU path only:
    # precompute X'X/n and X'y/n so each gradient is a single p x p matmul.
    _use_xtx = _is_quadratic and sample_weight is None and _use_gpu_loop
    if _use_xtx:
        XtX = X_proc.T @ X_proc / n_samples
        Xty = X_proc.T @ y_proc / n_samples
    else:
        XtX = Xty = None

    iteration = -1  # default if max_iter=0
    converged = False

    for iteration in range(max_iter):
        coef_old = _copy_arr(coef)

        # ---- Gradient (and objective at y_k for the line search) ----
        if _use_xtx:
            # grad = X'X/n @ w - X'y/n = X'(Xw - y)/n. Only the async GPU loop
            # gets here, and it never reads q_yk_dev, so skip the X @ y_k
            # objective evaluation entirely.
            grad = XtX @ y_k - Xty
            q_yk_dev = None
        elif sample_weight is not None:
            q_yk_dev, grad = loss.fused_value_and_gradient(
                X_proc, y_proc, y_k, sample_weight=sample_weight
            )
        else:
            q_yk_dev, grad = loss.fused_value_and_gradient(X_proc, y_proc, y_k)

        if _use_gpu_loop:
            # -- GPU async path: all ops stay on device --
            grad = _clip_grad_on_device(grad, coef_old, backend)
            step = 1.0 / L
            # Single proximal step -- no backtracking (L is conservative enough).
            coef = _proximal(y_k - step * grad, step, backend=backend)

            # All safety checks are deferred: finiteness, divergence recovery
            # and best-objective tracking share one device->host transfer.
            if iteration > 0 and (iteration < 20 or iteration % _div_interval == 0):
                _obj_val_f = float(_to_numpy(loss.value(X_proc, y_proc, coef)))
                if not np.isfinite(_obj_val_f):
                    coef = _fallback_coef(_coef_best_fista, n_features, backend, X_proc)
                    y_k = _copy_arr(coef)
                    t_k = 1.0
                    L = L * 2.0
                    continue
                _obj_val_f += _tracking_penalty_value(penalty, coef)
                if _obj_val_f < _obj_best_fista:
                    _obj_best_fista = _obj_val_f
                    _coef_best_fista = _copy_arr(coef)
                # Periodic Lipschitz refresh (piggybacks on the same sync).
                # Skipped for quadratic losses, whose Lipschitz is constant.
                # Runs every _lip_interval (25) iterations here; the
                # backtracking path refreshes every 5.
                if not _is_quadratic and iteration % _lip_interval == 0:
                    if _weighted_lip is not None:
                        L_new = _weighted_lip
                    else:
                        L_new = loss.lipschitz(X_proc, coef, y=y_proc)
                    if L_new > 0:
                        # Re-apply y-scaling: the Lipschitz at the current coef
                        # may not capture the y-dependent curvature scaling
                        # applied at init.
                        if _y_scale > 1.0:
                            L_new = L_new * _y_scale
                        L_new *= getattr(loss, '_lipschitz_safety', 1.0)
                        if _smooth_lip > 0:
                            L_new = L_new + _smooth_lip
                        # Grow immediately; shrink by at most 20% per refresh.
                        if L_new > L:
                            L = L_new
                        else:
                            L = max(L * 0.8, L_new)
        else:
            # -- CPU/GPU path with backtracking --
            # Sync-based clipping on every backend: backtracking already syncs
            # each iteration for the slack check, so on-device clipping would
            # not save anything here.
            _gn_f, _coef_abs_f = _sync_scalars(
                _norm2_dev(grad), _abs_sum_dev(coef_old), backend=backend)
            _gmax = max(_coef_abs_f * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)
            if _gn_f > _gmax:
                grad = grad * (_gmax / _gn_f)

            step = 1.0 / L
            _q_new_dev_last = None
            for _bt in range(20):
                coef_new = _proximal(y_k - step * grad, step, backend=backend)
                diff = coef_new - y_k
                if sample_weight is not None:
                    q_new_dev, _ = loss.fused_value_and_gradient(
                        X_proc, y_proc, coef_new, sample_weight=sample_weight
                    )
                else:
                    q_new_dev = loss.value(X_proc, y_proc, coef_new)
                _q_new_dev_last = q_new_dev
                # Accept when f(coef_new) <= f(y_k) + <g, d> + L/2 ||d||^2 (+ slack).
                bound_dev = q_yk_dev + _dot_dev(grad, diff) + 0.5 * L * _sum_sq_dev(diff)
                slack_dev = bound_dev + _SLACK_TOLERANCE - q_new_dev
                if _to_float_scalar(slack_dev) >= 0:
                    break
                L *= 1.5
                step = 1.0 / L

            coef = coef_new

            # Finiteness check. If no best iterate exists yet, restart from
            # zeros (as the GPU path and divergence recovery do) instead of
            # carrying a NaN/Inf iterate into every later iteration.
            if not _is_quadratic:
                if not np.isfinite(float(_norm2_dev(coef))):
                    coef = _fallback_coef(_coef_best_fista, n_features, backend, X_proc)
                    y_k = _copy_arr(coef)
                    t_k = 1.0
                    L = L * 2.0
                    continue

            # Divergence detection and best-iterate tracking.
            if not _is_quadratic and iteration > 0:
                _need_norm_check = iteration > 10
                if _q_new_dev_last is not None:
                    # Reuse the objective already evaluated by the line search.
                    _obj_dev = _q_new_dev_last
                    _q_new_dev_last = None
                elif sample_weight is not None:
                    _obj_dev, _ = loss.fused_value_and_gradient(
                        X_proc, y_proc, coef, sample_weight=sample_weight
                    )
                else:
                    _obj_dev = loss.value(X_proc, y_proc, coef)
                # Batched sync: objective + coef norm in one transfer.
                if _need_norm_check:
                    _obj_val_f, _coef_norm_f = _sync_scalars(
                        _obj_dev, _norm2_dev(coef), backend=backend
                    )
                else:
                    _obj_val_f = float(_to_numpy(_obj_dev))
                    _coef_norm_f = 0.0
                _obj_val_f += _tracking_penalty_value(penalty, coef)
                if not np.isfinite(_obj_val_f):
                    _diverged_f = True
                elif _obj_best_fista > 1e-8:
                    _diverged_f = _obj_val_f > _obj_best_fista * 10.0 + 1e-8
                else:
                    _diverged_f = _obj_val_f > _obj_best_fista + max(abs(_obj_best_fista) * 10.0, 1.0)
                if not _diverged_f and _need_norm_check and _coef_norm_f > _DIVERGE_COEF_NORM_CAP:
                    _diverged_f = True
                if _diverged_f:
                    coef = _fallback_coef(_coef_best_fista, n_features, backend, X_proc)
                    y_k = _copy_arr(coef)
                    t_k = 1.0
                    L = L * 2.0
                    continue
                if _obj_val_f < _obj_best_fista:
                    _obj_best_fista = _obj_val_f
                    _coef_best_fista = _copy_arr(coef)

            # Periodic Lipschitz refresh, skipped when the coefficients barely
            # moved. Unlike the async GPU refresh, _y_scale is not re-applied
            # here; backtracking raises L again if the refreshed value is too
            # small.
            if not _is_quadratic and iteration > 0 and iteration % _lip_interval == 0:
                # Batch both norms into a single device->host transfer.
                _coef_change, _coef_norm = _sync_scalars(
                    _norm2_dev(coef - coef_old), _norm2_dev(coef), backend=backend)
                if _coef_change / max(_coef_norm, 1e-10) > 1e-3:
                    L_new = _call_with_weight(
                        loss.lipschitz, X_proc, coef, y=y_proc, sample_weight=sample_weight
                    )
                    _lip_safety_recomp = getattr(loss, '_lipschitz_safety', 1.0)
                    if _lip_safety_recomp > 1.0:
                        L_new = L_new * _lip_safety_recomp
                    if _smooth_lip > 0:
                        L_new = L_new + _smooth_lip
                    if L_new > L:
                        L = L_new
                    else:
                        L = max(L * 0.8, L_new)

        # Momentum update -- all backends.
        if _skip_momentum:
            # No momentum (e.g. inverse_gaussian): just copy coef.
            y_k = _copy_arr(coef)
        elif _momentum_beta_cap is not None:
            # Conservative momentum with capped beta.
            y_k, t_k = _nesterov_update(coef, coef_old, t_k, beta_cap=_momentum_beta_cap)
        else:
            y_k, t_k = _nesterov_update(coef, coef_old, t_k)

        # Convergence check: deferred on GPU backends, every iteration on CPU.
        if _is_gpu:
            if iteration < 20 or iteration % _conv_interval == 0:
                if _to_float_scalar(_abs_sum_dev(coef - coef_old)) < tol:
                    converged = True
                    break
        elif float(_abs_sum_dev(coef - coef_old)) < tol:
            converged = True
            break

    # The async GPU loop only samples the objective periodically, so its
    # tracked best can be many iterations older than the final iterate.
    # Compare once more and keep the final iterate if it is at least as good.
    if _use_gpu_loop and _coef_best_fista is not None:
        _final_obj = float(_to_numpy(loss.value(X_proc, y_proc, coef)))
        if np.isfinite(_final_obj):
            _final_obj += _tracking_penalty_value(penalty, coef)
            if _final_obj <= _obj_best_fista:
                _coef_best_fista = None

    # Return best iterate if available.
    if _coef_best_fista is not None:
        coef = _copy_arr(_coef_best_fista)

    n_iter = iteration + 1
    # Warn only on a genuine non-convergence. Meeting tol on the final allowed
    # iteration is convergence, even though n_iter == max_iter.
    if not converged:
        warnings.warn(
            f"fista_solver did not converge within {max_iter} iterations "
            f"(loss={getattr(loss, 'name', '?')}, penalty={getattr(penalty, 'name', '?')}). "
            f"Consider increasing max_iter or using a different solver (newton, lbfgs, irls).",
            ConvergenceWarning,
            stacklevel=2,
        )
    return coef, n_iter
|