statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
"""FISTA with Barzilai-Borwein step sizes and adaptive restart.

Uses alternating BB1/BB2 steps (Barzilai & Borwein 1988) that adapt to
local curvature in place of a per-iteration Armijo line search (only a
coarse, periodic objective safeguard with step halving remains), while
keeping the proximal step and hence the sparsity it induces.
BB1 = <dw,dw>/<dw,dg> (long step), BB2 = <dw,dg>/<dg,dg> (short step).
Adaptive restart (O'Donoghue & Candes 2015) resets momentum when it
opposes the descent direction.

Supports numpy / cupy / torch backends via auto-detection of X.
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
# Public API of this module; helpers prefixed with "_" stay private.
__all__ = ["fista_bb_solver"]
|
|
15
|
+
|
|
16
|
+
import warnings
|
|
17
|
+
import numpy as np
|
|
18
|
+
from statgpu.backends import _resolve_backend, _to_numpy
|
|
19
|
+
from statgpu.backends._utils import _to_float_scalar
|
|
20
|
+
from statgpu.backends._array_ops import (
|
|
21
|
+
_abs_sum_dev, _clip_grad_on_device, _copy_arr, _dot_dev,
|
|
22
|
+
_norm2_dev, _sync_scalars, _zeros,
|
|
23
|
+
)
|
|
24
|
+
from statgpu.penalties._categories import BB_DISABLED as _BB_DISABLED
|
|
25
|
+
from ._convergence import ConvergenceWarning
|
|
26
|
+
from ._constants import (
|
|
27
|
+
_DIVERGE_COEF_NORM_CAP,
|
|
28
|
+
_BB_RESTART_DOT_TOL,
|
|
29
|
+
_DIVERGE_OBJ_RATIO,
|
|
30
|
+
_DIVERGE_OBJ_ABS,
|
|
31
|
+
_GRAD_CLIP_COEF_FACTOR,
|
|
32
|
+
_GRAD_CLIP_ABS_FLOOR,
|
|
33
|
+
_GRAD_CLIP_MAX,
|
|
34
|
+
)
|
|
35
|
+
from ._fista import fista_solver
|
|
36
|
+
from ._utils import (
|
|
37
|
+
_validate_sample_weight,
|
|
38
|
+
_as_backend_vector,
|
|
39
|
+
_call_with_weight,
|
|
40
|
+
_nesterov_update,
|
|
41
|
+
_penalty_name,
|
|
42
|
+
_smooth_penalty_lipschitz,
|
|
43
|
+
_tracking_penalty_value,
|
|
44
|
+
_abs_mean_max,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _scaled_lipschitz(base_L, y_scale, safety, smooth_lip):
    """Apply fista_bb's adjustments to a raw loss Lipschitz bound.

    Shared by the initial step-size computation and the periodic burn-in
    refresh so that both estimates always agree.

    Parameters
    ----------
    base_L : float
        Raw Lipschitz bound of the smooth loss.
    y_scale : float
        Response-magnitude multiplier; applied only when > 1.
    safety : float
        Per-family safety multiplier (``loss._lipschitz_safety``); applied
        only when > 1.
    smooth_lip : float
        Lipschitz constant of the smooth part of the penalty (e.g. the l2
        gradient ``alpha * coef`` has constant ``alpha``); added only when
        > 0. Without it the step 1/L is too large.

    Returns
    -------
    float
        The adjusted Lipschitz constant.
    """
    scaled = base_L
    if y_scale > 1.0:
        scaled = scaled * y_scale
    if safety > 1.0:
        scaled = scaled * safety
    if smooth_lip > 0:
        scaled = scaled + smooth_lip
    return scaled


def fista_bb_solver(
    loss,
    penalty: "Penalty | None",
    X,
    y,
    max_iter: int = 1000,
    tol: float = 1e-4,
    init_coef=None,
    sample_weight=None,
    use_restart: bool = True,
    step_max_factor: float = 1e3,
    step_min_factor: float = 1e-3,
    bb_burn_in: int = 20,
    cv_mode: bool = False,
    lipschitz_L: float | None = None,
) -> tuple:
    """FISTA with Barzilai-Borwein step sizes and adaptive restart.

    Uses alternating BB1/BB2 steps (Barzilai & Borwein 1988) that adapt to
    local curvature in place of a per-iteration line search.
    BB1 = <dw,dw>/<dw,dg> (long step), BB2 = <dw,dg>/<dg,dg> (short step).
    Adaptive restart (O'Donoghue & Candes 2015) resets momentum when it
    opposes the descent direction. For non-quadratic losses a periodic
    objective safeguard (step halving) and divergence recovery (restore the
    best recorded iterate, halve the step bounds) protect against blow-up.

    Supports numpy / cupy / torch backends via auto-detection of X.

    Losses flagged with ``_prefer_fista_over_bb`` combined with an
    l2/none penalty are delegated to ``fista_solver`` unchanged.

    Parameters
    ----------
    loss :
        Smooth loss object providing ``preprocess(X, y)``, ``gradient``,
        ``value`` and ``lipschitz``. Optional attributes read here tune the
        solver: ``_prefer_fista_over_bb``, ``_is_quadratic``,
        ``_lipschitz_uses_y``, ``_inverse_gaussian``, ``_tweedie``,
        ``_lipschitz_safety``, ``_poisson_like``, ``_gamma_like``,
        ``_momentum_disabled``, ``_momentum_beta_cap``, ``_steep_loss``.
    penalty :
        Penalty providing ``proximal(w, step, backend=...)`` and ``name``.
        ``None`` uses the identity proximal step (no penalty).
    X, y :
        Design matrix and response; the backend is inferred from ``X``.
    max_iter : int
        Maximum number of outer iterations.
    tol : float
        Stop once the summed absolute coefficient change is below ``tol``.
    init_coef : array-like, optional
        Warm start; zeros when omitted.
    sample_weight : array-like, optional
        Per-sample weights forwarded to every loss / gradient / Lipschitz
        evaluation.
    use_restart : bool
        Enable adaptive momentum restart in the Nesterov branch.
    step_max_factor, step_min_factor : float
        BB steps are clipped to ``[step_min_factor, step_max_factor] * (1/L)``.
    bb_burn_in : int
        Iterations using the fixed 1/L step before BB steps are allowed.
        Raised or disabled (never reached) for some penalties and losses.
    cv_mode : bool
        On GPU backends, check divergence / convergence less often to
        reduce host-device synchronisation.
    lipschitz_L : float, optional
        Precomputed raw loss Lipschitz bound, used instead of calling
        ``loss.lipschitz`` when it converts to a positive float.

    Returns
    -------
    coef :
        Final coefficients. If divergence recovery fired, the best recorded
        iterate is returned unless the final iterate is at least as good.
    n_iter : int
        Number of iterations executed.

    Warns
    -----
    ConvergenceWarning
        If the convergence criterion was not met within ``max_iter``.
    """
    backend = _resolve_backend("auto", X)
    _is_gpu = backend in ("torch", "cupy")
    X_proc, y_proc = loss.preprocess(X, y)
    n_features = X_proc.shape[1]
    _pen_name = _penalty_name(penalty)

    # Smooth logistic objectives are better handled by the Armijo-backed FISTA
    # path. This keeps explicit fista_bb numerically aligned across CPU/CuPy/
    # Torch for logistic+none/l2 Section A checks.
    if getattr(loss, '_prefer_fista_over_bb', False) and _pen_name in ("l2", "none", "null", ""):
        return fista_solver(
            loss,
            penalty,
            X,
            y,
            max_iter=max_iter,
            tol=tol,
            init_coef=init_coef,
            sample_weight=sample_weight,
            cv_mode=cv_mode,
        )

    # Validate weights before their first use (the Lipschitz estimate below).
    _validate_sample_weight(sample_weight, X_proc.shape[0])

    def _prox(w, step):
        # The signature allows penalty=None, which has no ``proximal``;
        # treat it as "no penalty" (identity prox).
        # NOTE(review): _tracking_penalty_value / _smooth_penalty_lipschitz
        # are assumed to accept None as well -- confirm.
        if penalty is None:
            return w
        return penalty.proximal(w, step, backend=backend)

    # --- Initialize coefficients ---
    if init_coef is not None:
        coef = _as_backend_vector(init_coef, backend, X)
    else:
        coef = _zeros(n_features, backend, ref_tensor=X)
    y_k = _copy_arr(coef)
    t_k = 1.0

    # Divergence detection: best objective / iterate seen at a check.
    _obj_best = float('inf')
    _coef_best = None
    _diverge_count = 0

    _bb_use_long = True  # alternate BB1 / BB2
    dot_dw_dg = 0.0      # BB numerator state (valid even for bb_burn_in=0)
    dot_dw_dw = 1.0
    _div_check_interval = 25 if cv_mode and _is_gpu else 5
    _lip_check_interval = 25 if cv_mode and _is_gpu else 5
    _conv_check_interval = 10 if cv_mode and _is_gpu else 3

    # For quadratic losses the gradient is linear in coef, so dg = H @ dw and
    # BB1 = BB2 = 1 / Rayleigh_quotient(H, dw): BB adds no adaptation and the
    # method degenerates to ISTA (O(1/k)). Use the fixed Lipschitz step with
    # Nesterov momentum (O(1/k^2)) instead.
    _is_quadratic = getattr(loss, '_is_quadratic', False)

    # Non-smooth penalties make BB gradient differences noisy. Convex ones
    # (L1, elasticnet, adaptive_l1) have bounded subgradients, so BB is
    # valid after a longer burn-in. Non-convex ones (SCAD, MCP, group_*)
    # can change abruptly and cause divergence, so BB is disabled for them.
    _raw_pen_name = getattr(penalty, "name", _pen_name)
    if hasattr(_raw_pen_name, "lower"):
        _pen_name = _raw_pen_name.lower()
    if _pen_name in _BB_DISABLED:
        bb_burn_in = max_iter + 1  # never switch to BB
    elif _pen_name in {"l1", "elasticnet", "en", "adaptive_l1", "adaptive_lasso"}:
        bb_burn_in = max(bb_burn_in, 50)  # longer burn-in for non-smooth

    # Raw Lipschitz bound at zero (safe for all losses). Evaluating it at
    # init_coef can explode for exp-link families with warm starts.
    _zero_coef_bb = _zeros(n_features, backend, ref_tensor=X)
    _cached_lipschitz_L = None
    if lipschitz_L is not None:
        try:
            _cached_lipschitz_L = float(_to_numpy(lipschitz_L))
        except (ValueError, TypeError):
            _cached_lipschitz_L = None
    if _cached_lipschitz_L is not None and _cached_lipschitz_L > 0:
        _L_raw = _cached_lipschitz_L
    else:
        _L_raw = _call_with_weight(loss.lipschitz, X_proc, _zero_coef_bb, y=y_proc, sample_weight=sample_weight)
    # `not (x > 0)` also catches NaN, which would otherwise poison every step.
    _L_raw_valid = _L_raw > 0

    # At coef=0 the exp-link mean is ~1 while near the optimum it is ~y, so
    # the zero-coef bound can underestimate curvature badly. Scale by a
    # geometric-mean heuristic of |y| unless the loss's bound already uses y.
    _skip_y_scaling_bb = getattr(loss, '_lipschitz_uses_y', False)
    _y_scale = 1.0
    if not _is_quadratic and not _skip_y_scaling_bb:
        _y_mean, _y_max = _abs_mean_max(y_proc, backend)
        _y_scale = max(1.0, _y_mean, np.sqrt(_y_mean * _y_max))

    _invgauss_like = getattr(loss, '_inverse_gaussian', False)
    _tweedie_like = getattr(loss, '_tweedie', False)
    _lip_safety_bb = getattr(loss, '_lipschitz_safety', 1.0)
    _smooth_lip_bb = _smooth_penalty_lipschitz(penalty)

    L = _scaled_lipschitz(_L_raw if _L_raw_valid else 1.0, _y_scale, _lip_safety_bb, _smooth_lip_bb)
    # The zero-coef bound does not depend on the iterate, so the burn-in
    # refresh target is fixed; it only relaxes L back toward this baseline
    # after backtracking / divergence recovery has inflated it.
    _L_refresh = L if _L_raw_valid else None

    step_L = 1.0 / L
    step_k = step_L
    step_max = step_L * step_max_factor
    step_min = step_L * step_min_factor

    # Gradient at the initial point for the first BB difference.
    grad_old = _call_with_weight(loss.gradient, X_proc, y_proc, coef, sample_weight=sample_weight)
    dg = _zeros(n_features, backend, ref_tensor=X_proc)
    iteration = -1  # value reported if max_iter == 0
    _converged = False

    # Loop-invariant loss flags.
    _poisson_like = getattr(loss, '_poisson_like', False)
    _gamma_like = getattr(loss, '_gamma_like', False)
    _steep_loss = getattr(loss, '_steep_loss', False)

    # --- Burn-in and momentum parameters (depend only on loss/penalty) ---
    if _invgauss_like:
        bb_burn_in = max_iter + 1  # never switch to BB
    elif _tweedie_like:
        bb_burn_in = max(200, max_iter // 2)
    elif _gamma_like:
        bb_burn_in = max(50, max_iter // 8)

    if getattr(loss, '_momentum_disabled', False):
        _momentum_burn_in = max_iter + 1  # never use momentum
    elif _tweedie_like:
        _momentum_burn_in = max(100, max_iter // 4)
    elif _gamma_like:
        _momentum_burn_in = max(30, max_iter // 10)
    else:
        _momentum_burn_in = 0  # momentum from the start

    # Conservative momentum for specific loss+penalty combos.
    _momentum_beta_cap = getattr(loss, '_momentum_beta_cap', None)
    if _momentum_beta_cap is not None and _poisson_like and not _invgauss_like:
        if getattr(penalty, 'name', '') in ("l2", "none", "", None):
            _momentum_burn_in = min(100, max_iter)
    if (_tweedie_like or _gamma_like) and _momentum_beta_cap is None:
        _momentum_beta_cap = 0.2

    for iteration in range(max_iter):
        coef_old = _copy_arr(coef)

        # Gradient at the extrapolated point.
        grad = _call_with_weight(loss.gradient, X_proc, y_proc, y_k, sample_weight=sample_weight)

        # Clip extreme gradients. Skipped for inverse Gaussian: its 1/mu^3
        # gradients are large but valid and clipping blocks convergence.
        if not _invgauss_like:
            if cv_mode and _is_gpu:
                grad = _clip_grad_on_device(grad, coef_old, backend)
            else:
                _gn_f, _coef_abs_f = _sync_scalars(
                    _norm2_dev(grad), _abs_sum_dev(coef_old), backend=backend)
                _gmax = max(_coef_abs_f * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)
                if _gn_f > _gmax:
                    grad = grad * (_gmax / _gn_f)

        # --- Divergence detection on the current iterate ---
        # CPU checks every iteration; GPU only every _div_check_interval
        # iterations (and the first few) to avoid per-iteration syncs.
        _do_full_div_check = iteration % _div_check_interval == 0 or iteration <= 5
        _do_div_check = (
            not _is_quadratic
            and iteration > 0
            and (not _is_gpu or _do_full_div_check)
        )
        if _do_div_check:
            _diverged = False
            # Cheap coef-norm check catches catastrophic explosion.
            if iteration > 10 and _to_float_scalar(_norm2_dev(coef)) > _DIVERGE_COEF_NORM_CAP:
                _diverged = True
            if not _diverged:
                _obj_val = float(_to_numpy(_call_with_weight(loss.value, X_proc, y_proc, coef, sample_weight=sample_weight)))
                _obj_total = _obj_val + _tracking_penalty_value(penalty, coef)
                if not np.isfinite(_obj_total):
                    _diverged = True
                elif not np.isfinite(_obj_best):
                    # No finite best yet: rely on the norm check only.
                    pass
                elif _obj_best > 1e-8:
                    if _invgauss_like or _tweedie_like:
                        _diverge_threshold = _obj_best * _DIVERGE_OBJ_RATIO + _DIVERGE_OBJ_ABS
                    else:
                        _diverge_threshold = _obj_best * 10.0 + 1e-8
                    _diverged = _obj_total > _diverge_threshold
                else:
                    if _invgauss_like or _tweedie_like:
                        _diverge_threshold = _obj_best + max(abs(_obj_best) * _DIVERGE_OBJ_RATIO, _DIVERGE_OBJ_ABS)
                    else:
                        _diverge_threshold = _obj_best + max(abs(_obj_best) * 10.0, 1.0)
                    _diverged = _obj_total > _diverge_threshold
            if _diverged:
                # Reset to the best known iterate (or zeros), halve steps.
                _diverge_count += 1
                if _coef_best is not None:
                    coef = _copy_arr(_coef_best)
                else:
                    coef = _zeros(n_features, backend, ref_tensor=X_proc)
                y_k = _copy_arr(coef)
                t_k = 1.0
                grad_old = _call_with_weight(loss.gradient, X_proc, y_proc, coef, sample_weight=sample_weight)
                step_L = step_L * 0.5
                step_k = step_L
                step_max = step_max * 0.5
                step_min = step_min * 0.5
                L = L * 2.0
                dot_dw_dg = 0.0
                dot_dw_dw = 1.0
                continue
            if _obj_total < _obj_best:
                _obj_best = _obj_total
                _coef_best = _copy_arr(coef)

        # --- Step size selection ---
        if _is_quadratic or iteration < bb_burn_in:
            # Fixed Lipschitz step. During GLM burn-in, early (dw, dg) pairs
            # track the trajectory from zero rather than local curvature,
            # so BB would amplify oscillations.
            step_k = step_L
            if (
                _L_refresh is not None
                and not _is_quadratic
                and iteration > 0
                and iteration % _lip_check_interval == 0
            ):
                # Full increase, gradual (20%) decrease toward the baseline.
                if _L_refresh > L:
                    L = _L_refresh
                else:
                    L = max(L * 0.8, _L_refresh)
                step_L = 1.0 / L
                step_k = step_L
                step_max = step_L * step_max_factor
                step_min = step_L * step_min_factor
        elif dot_dw_dg > _BB_RESTART_DOT_TOL:
            # Post-burn-in GLM: alternate BB1 (long) / BB2 (short).
            if _bb_use_long:
                step_k = dot_dw_dw / dot_dw_dg
            else:
                dot_dg_dg = float(_to_numpy(_dot_dev(dg, dg)))
                step_k = dot_dw_dg / max(dot_dg_dg, 1e-14)
            _bb_use_long = not _bb_use_long
            if _tweedie_like:
                # Tweedie: cap BB step more aggressively to prevent overshoot.
                step_k = min(step_k, step_L * 2.0)
            step_k = min(max(step_k, step_min), step_max)
        # else: keep previous step_k (step_L or last valid BB step)

        # Gradient step + proximal.
        w_tilde = y_k - step_k * grad
        coef = _prox(w_tilde, step_k)

        # Safeguarded backtracking for GLM losses (every 5 iterations and
        # the first few): if the objective exploded, halve the step and
        # recompute. Uses the same weighted objective as _obj_best.
        _last_coef_norm_f = None
        if not _is_quadratic and (iteration % 5 == 0 or iteration <= 5):
            # Relative cap instead of absolute 1e6: NB/Tweedie with large
            # counts can have legitimate loss > 1e6.
            _steep_cap = max(_obj_best * 10.0, 1e6) if np.isfinite(_obj_best) else 1e6
            _soft_cap = max(_obj_best * 1.5 + 1.0, 1e3)
            _obj_cap = _steep_cap if _steep_loss else _soft_cap
            for _bt in range(15):
                # Batch obj + coef-norm into a single sync.
                _new_obj, _new_norm = _sync_scalars(
                    _call_with_weight(loss.value, X_proc, y_proc, coef, sample_weight=sample_weight),
                    _norm2_dev(coef), backend=backend)
                _new_total = _new_obj + _tracking_penalty_value(penalty, coef)
                if (np.isfinite(_new_total) and _new_norm < _DIVERGE_COEF_NORM_CAP
                        and _new_total < _obj_cap):
                    _last_coef_norm_f = _new_norm
                    break
                # Step too large -- halve and retry.
                step_k = step_k * 0.5
                L = L * 2.0
                w_tilde = y_k - step_k * grad
                coef = _prox(w_tilde, step_k)
                _last_coef_norm_f = None

        # Finiteness check, reusing the norm synced by backtracking if any.
        if not _is_quadratic:
            if _last_coef_norm_f is not None:
                _finite_ok = np.isfinite(_last_coef_norm_f)
            else:
                _finite_ok = np.isfinite(_to_float_scalar(_norm2_dev(coef)))
            if not _finite_ok:
                _diverge_count += 1
                if _coef_best is not None:
                    coef = _copy_arr(_coef_best)
                else:
                    # No valid iterate yet -- reset to zeros rather than
                    # carrying the non-finite coef into y_k.
                    coef = _zeros(n_features, backend, ref_tensor=X_proc)
                y_k = _copy_arr(coef)
                t_k = 1.0
                grad_old = _call_with_weight(loss.gradient, X_proc, y_proc, coef, sample_weight=sample_weight)
                step_L = step_L * 0.5
                step_k = step_L
                step_max = step_max * 0.5
                step_min = step_min * 0.5
                L = L * 2.0
                dot_dw_dg = 0.0
                dot_dw_dw = 1.0
                continue

            # Store BB information from the accepted iterate.
            grad_new = _call_with_weight(loss.gradient, X_proc, y_proc, coef, sample_weight=sample_weight)
            dw = coef - coef_old
            dg = grad_new - grad_old
            # Batch two dot products into a single GPU->CPU sync.
            dot_dw_dw, dot_dw_dg = _sync_scalars(
                _dot_dev(dw, dw), _dot_dev(dw, dg), backend=backend)
            grad_old = grad_new

        # --- Nesterov momentum with adaptive restart ---
        if iteration < _momentum_burn_in:
            # No momentum: next gradient at the current point.
            t_k = 1.0
            y_k = _copy_arr(coef)
        elif _momentum_beta_cap is not None:
            # Conservative fixed-beta momentum to avoid explosion.
            y_k = coef + _momentum_beta_cap * (coef - coef_old)
            t_k = 1.0
        else:
            y_k, t_new = _nesterov_update(coef, coef_old, t_k)
            if use_restart and iteration > 0:
                # Restart when the momentum direction opposes the last step.
                _mc_dev = _dot_dev(y_k - coef, coef - coef_old)
                if _to_float_scalar(_mc_dev) > 0:
                    t_new = 1.0
                    beta = 0.0
                    y_k = coef + beta * (coef - coef_old)
            t_k = t_new

        # --- Convergence check: every iteration on CPU, deferred on GPU ---
        if not _is_gpu or iteration < 20 or iteration % _conv_check_interval == 0:
            if _to_float_scalar(_abs_sum_dev(coef - coef_old)) < tol:
                _converged = True
                break

    # After divergence recovery, prefer the best recorded iterate unless
    # the final iterate (which may be newer than the last check) is at
    # least as good.
    if _diverge_count > 0 and _coef_best is not None:
        _final_obj = float(_to_numpy(_call_with_weight(loss.value, X_proc, y_proc, coef, sample_weight=sample_weight)))
        _final_total = _final_obj + _tracking_penalty_value(penalty, coef)
        if not (np.isfinite(_final_total) and _final_total <= _obj_best):
            coef = _copy_arr(_coef_best)

    n_iter = iteration + 1
    if not _converged:
        warnings.warn(
            f"fista_bb_solver did not converge within {max_iter} iterations "
            f"(loss={getattr(loss, 'name', '?')}, penalty={getattr(penalty, 'name', '?')}). "
            f"Consider increasing max_iter or using a different solver (newton, lbfgs, irls).",
            ConvergenceWarning,
            stacklevel=2,
        )
    return coef, n_iter
|